mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-16 05:02:39 +00:00
Compare commits
11 Commits
v2.12.7
...
litellm_co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5115b621c8 | ||
|
|
a924b49405 | ||
|
|
2d2d998811 | ||
|
|
0925b5fbd4 | ||
|
|
a02d8414ee | ||
|
|
c8abc4a115 | ||
|
|
cec37bff6a | ||
|
|
06d5d3971b | ||
|
|
ed287a2fc0 | ||
|
|
60857d1e73 | ||
|
|
bb5c22104e |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
151
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
151
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
env:
|
||||
REPORT: ${{ steps.license_check_report.outputs.report }}
|
||||
run: echo "$REPORT"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
@@ -1,79 +0,0 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR body and check whether the helper checkbox is checked.
|
||||
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
Normal file
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
3
.github/workflows/pr-helm-chart-testing.yml
vendored
3
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,7 +41,8 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
|
||||
48
.github/workflows/pr-playwright-tests.yml
vendored
48
.github/workflows/pr-playwright-tests.yml
vendored
@@ -22,9 +22,6 @@ env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
|
||||
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
|
||||
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
@@ -470,3 +467,48 @@ jobs:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
|
||||
# chromatic-tests:
|
||||
# name: Chromatic Tests
|
||||
|
||||
# needs: playwright-tests
|
||||
# runs-on:
|
||||
# [
|
||||
# runs-on,
|
||||
# runner=32cpu-linux-x64,
|
||||
# disk=large,
|
||||
# "run-id=${{ github.run_id }}",
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
|
||||
# - name: Install node dependencies
|
||||
# working-directory: ./web
|
||||
# run: npm ci
|
||||
|
||||
# - name: Download Playwright test results
|
||||
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
|
||||
# with:
|
||||
# name: test-results
|
||||
# path: ./web/test-results
|
||||
|
||||
# - name: Run Chromatic
|
||||
# uses: chromaui/action@latest
|
||||
# with:
|
||||
# playwright: true
|
||||
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
# workingDir: ./web
|
||||
# env:
|
||||
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
|
||||
5
LICENSE
5
LICENSE
@@ -2,10 +2,7 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
|
||||
- backend/ee/LICENSE
|
||||
- web/src/app/ee/LICENSE
|
||||
- web/src/ee/LICENSE
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
|
||||
@@ -134,7 +134,6 @@ COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
|
||||
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY --chown=onyx:onyx ./static /app/static
|
||||
COPY --chown=onyx:onyx ./keys /app/keys
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add_user_preferences
|
||||
|
||||
Revision ID: 175ea04c7087
|
||||
Revises: d56ffa94ca32
|
||||
Create Date: 2026-02-04 18:16:24.830873
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "175ea04c7087"
|
||||
down_revision = "d56ffa94ca32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("user_preferences", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "user_preferences")
|
||||
@@ -1,20 +1,20 @@
|
||||
The Onyx Enterprise License (the "Enterprise License")
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
With regard to the Onyx Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the Onyx Subscription Terms of Service, available
|
||||
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://onyx.app/terms (the “Enterprise Terms”), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Onyx Enterprise License for the
|
||||
and otherwise have a valid Onyx Enterprise license for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Onyx Enterprise License for the correct
|
||||
exploited with a valid Onyx Enterprise license for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
|
||||
@@ -263,15 +263,9 @@ def refresh_license_cache(
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
# Derive source from payload: manual licenses lack stripe_customer_id
|
||||
source: LicenseSource = (
|
||||
LicenseSource.AUTO_FETCH
|
||||
if payload.stripe_customer_id
|
||||
else LicenseSource.MANUAL_UPLOAD
|
||||
)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=source,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -50,7 +50,12 @@ def github_doc_sync(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
github_connector.load_credentials(credential_json)
|
||||
logger.info("GitHub connector credentials loaded successfully")
|
||||
|
||||
if not github_connector.github_client:
|
||||
|
||||
@@ -18,7 +18,12 @@ def github_group_sync(
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
github_connector.load_credentials(credential_json)
|
||||
if not github_connector.github_client:
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
|
||||
@@ -50,7 +50,12 @@ def gmail_doc_sync(
|
||||
already populated.
|
||||
"""
|
||||
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
|
||||
gmail_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
gmail_connector.load_credentials(credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
|
||||
@@ -295,7 +295,12 @@ def gdrive_doc_sync(
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
google_drive_connector.load_credentials(credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
|
||||
@@ -391,7 +391,12 @@ def gdrive_group_sync(
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
google_drive_connector.load_credentials(credential_json)
|
||||
admin_service = get_admin_service(
|
||||
google_drive_connector.creds, google_drive_connector.primary_admin_email
|
||||
)
|
||||
|
||||
@@ -24,7 +24,12 @@ def jira_doc_sync(
|
||||
jira_connector = JiraConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
jira_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
jira_connector.load_credentials(credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -119,8 +119,13 @@ def jira_group_sync(
|
||||
if not jira_base_url:
|
||||
raise ValueError("No jira_base_url found in connector config")
|
||||
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
jira_client = build_jira_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
credentials=credential_json,
|
||||
jira_base=jira_base_url,
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,11 @@ def get_any_salesforce_client_for_doc_id(
|
||||
if _ANY_SALESFORCE_CLIENT is None:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
credential_json = first_cc_pair.credential.credential_json
|
||||
credential_json = (
|
||||
first_cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if first_cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
_ANY_SALESFORCE_CLIENT = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
@@ -158,7 +162,11 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
|
||||
@@ -24,7 +24,12 @@ def sharepoint_doc_sync(
|
||||
sharepoint_connector = SharepointConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
sharepoint_connector.load_credentials(credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -25,7 +25,12 @@ def sharepoint_group_sync(
|
||||
|
||||
# Create SharePoint connector instance and load credentials
|
||||
connector = SharepointConnector(**connector_config)
|
||||
connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
connector.load_credentials(credential_json)
|
||||
|
||||
if not connector.msal_app:
|
||||
raise RuntimeError("MSAL app not initialized in connector")
|
||||
|
||||
@@ -151,9 +151,14 @@ def slack_doc_sync(
|
||||
tenant_id = get_current_tenant_id()
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
slack_client = SlackConnector.make_slack_web_client(
|
||||
provider.get_provider_key(),
|
||||
cc_pair.credential.credential_json["slack_bot_token"],
|
||||
credential_json["slack_bot_token"],
|
||||
SlackConnector.MAX_RETRIES,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -63,9 +63,14 @@ def slack_group_sync(
|
||||
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
slack_client = SlackConnector.make_slack_web_client(
|
||||
provider.get_provider_key(),
|
||||
cc_pair.credential.credential_json["slack_bot_token"],
|
||||
credential_json["slack_bot_token"],
|
||||
SlackConnector.MAX_RETRIES,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,12 @@ def teams_doc_sync(
|
||||
teams_connector = TeamsConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
teams_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
teams_connector.load_credentials(credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
@@ -150,9 +151,12 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -109,9 +109,7 @@ async def _make_billing_request(
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT, follow_redirects=True
|
||||
) as client:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
|
||||
@@ -270,7 +270,11 @@ def confluence_oauth_accessible_resources(
|
||||
if not credential:
|
||||
raise HTTPException(400, f"Credential {credential_id} not found.")
|
||||
|
||||
credential_dict = credential.credential_json
|
||||
credential_dict = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
access_token = credential_dict["confluence_access_token"]
|
||||
|
||||
try:
|
||||
@@ -337,7 +341,12 @@ def confluence_oauth_finalize(
|
||||
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
|
||||
)
|
||||
|
||||
new_credential_json: dict[str, Any] = dict(credential.credential_json)
|
||||
existing_credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
new_credential_json: dict[str, Any] = dict(existing_credential_json)
|
||||
new_credential_json["cloud_id"] = cloud_id
|
||||
new_credential_json["cloud_name"] = cloud_name
|
||||
new_credential_json["wiki_base"] = cloud_url
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -44,14 +40,6 @@ def check_ee_features_enabled() -> bool:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss — warm from DB so cold-start doesn't block EE features
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(f"Failed to load license from DB: {db_error}")
|
||||
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
@@ -93,18 +81,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss (e.g. after TTL expiry). Fall back to DB so
|
||||
# the /settings request doesn't falsely return GATED_ACCESS
|
||||
# while the cache is cold.
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"Failed to load license from DB for settings: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
@@ -113,11 +89,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# No license found in cache or DB.
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
# Legacy EE flag is set → prior EE usage (e.g. permission
|
||||
# syncing) means indexed data may need protection.
|
||||
settings.application_status = _BLOCKING_STATUS
|
||||
# No license = community edition, disable EE features
|
||||
settings.ee_features_enabled = False
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
@@ -177,7 +177,7 @@ async def forward_to_control_plane(
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
|
||||
@@ -12,14 +12,12 @@ from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
@@ -47,23 +45,6 @@ def list_user_groups(
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
|
||||
@@ -76,18 +76,6 @@ class UserGroup(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
)
|
||||
|
||||
|
||||
class UserGroupCreate(BaseModel):
|
||||
name: str
|
||||
user_ids: list[UUID]
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.db.models import OAuthUserToken
|
||||
from onyx.db.oauth_config import get_user_oauth_token
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,7 +34,10 @@ class OAuthTokenManager:
|
||||
if not user_token:
|
||||
return None
|
||||
|
||||
token_data = user_token.token_data
|
||||
if not user_token.token_data:
|
||||
return None
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
# Check if token is expired
|
||||
if OAuthTokenManager.is_token_expired(token_data):
|
||||
@@ -51,7 +55,10 @@ class OAuthTokenManager:
|
||||
|
||||
def refresh_token(self, user_token: OAuthUserToken) -> str:
|
||||
"""Refresh access token using refresh token"""
|
||||
token_data = user_token.token_data
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
@@ -153,3 +160,11 @@ class OAuthTokenManager:
|
||||
separator = "&" if "?" in oauth_config.authorization_url else "?"
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(token_data, SensitiveValue):
|
||||
return token_data.get_value(apply_mask=False)
|
||||
return token_data
|
||||
|
||||
@@ -60,7 +60,6 @@ from sqlalchemy import nulls_last
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.disposable_email_validator import is_disposable_email
|
||||
@@ -111,7 +110,6 @@ from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
@@ -274,22 +272,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
"""Raise HTTPException(402) if adding users would exceed the seat limit.
|
||||
|
||||
No-op for multi-tenant or CE deployments.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=seats_needed)
|
||||
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -419,12 +401,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
|
||||
# Check seat availability for new users (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
existing = get_user_by_email(user_create.email, sync_db)
|
||||
if existing is None:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request)
|
||||
@@ -634,10 +610,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
# Check seat availability before creating (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
"email": account_email,
|
||||
|
||||
@@ -12,7 +12,6 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -20,14 +19,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -57,17 +54,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -131,24 +117,7 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -163,21 +132,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -190,35 +145,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -226,8 +158,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -244,12 +175,6 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -27,8 +27,8 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
@@ -60,28 +60,6 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -393,7 +371,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
persona: Persona | None,
|
||||
memories: list[str] | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
@@ -477,12 +455,7 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
final_tools = []
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
@@ -511,7 +484,7 @@ def run_llm_loop(
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
datetime_aware=persona.datetime_aware if persona else True,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
@@ -665,7 +638,7 @@ def run_llm_loop(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
|
||||
@@ -471,7 +471,7 @@ def handle_stream_message_objects(
|
||||
# Filter chat_history to only messages after the cutoff
|
||||
chat_history = [m for m in chat_history if m.id > cutoff_id]
|
||||
|
||||
memories = get_memories(user, db_session)
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
@@ -480,7 +480,7 @@ def handle_stream_message_objects(
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
@@ -667,7 +667,7 @@ def handle_stream_message_objects(
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.persona import get_default_behavior_persona
|
||||
from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
@@ -12,7 +13,6 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import USER_INFO_HEADER
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
@@ -25,6 +25,7 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.prompts.user_info import USER_INFORMATION_HEADER
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -52,7 +53,7 @@ def calculate_reserved_tokens(
|
||||
persona_system_prompt: str,
|
||||
token_counter: Callable[[str], int],
|
||||
files: list[FileDescriptor] | None = None,
|
||||
memories: list[str] | None = None,
|
||||
user_memory_context: UserMemoryContext | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate reserved token count for system prompt and user files.
|
||||
@@ -66,7 +67,7 @@ def calculate_reserved_tokens(
|
||||
persona_system_prompt: Custom agent system prompt (can be empty string)
|
||||
token_counter: Function that counts tokens in text
|
||||
files: List of file descriptors from the chat message (optional)
|
||||
memories: List of memory strings (optional)
|
||||
user_memory_context: User memory context (optional)
|
||||
|
||||
Returns:
|
||||
Total reserved token count
|
||||
@@ -77,7 +78,7 @@ def calculate_reserved_tokens(
|
||||
fake_system_prompt = build_system_prompt(
|
||||
base_system_prompt=base_system_prompt,
|
||||
datetime_aware=True,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
tools=None,
|
||||
should_cite_documents=True,
|
||||
include_all_guidance=True,
|
||||
@@ -133,7 +134,7 @@ def build_reminder_message(
|
||||
def build_system_prompt(
|
||||
base_system_prompt: str,
|
||||
datetime_aware: bool = False,
|
||||
memories: list[str] | None = None,
|
||||
user_memory_context: UserMemoryContext | None = None,
|
||||
tools: Sequence[Tool] | None = None,
|
||||
should_cite_documents: bool = False,
|
||||
include_all_guidance: bool = False,
|
||||
@@ -157,14 +158,15 @@ def build_system_prompt(
|
||||
)
|
||||
|
||||
company_context = get_company_context()
|
||||
if company_context or memories:
|
||||
system_prompt += USER_INFO_HEADER
|
||||
formatted_user_context = (
|
||||
user_memory_context.as_formatted_prompt() if user_memory_context else ""
|
||||
)
|
||||
if company_context or formatted_user_context:
|
||||
system_prompt += USER_INFORMATION_HEADER
|
||||
if company_context:
|
||||
system_prompt += company_context
|
||||
if memories:
|
||||
system_prompt += "\n".join(
|
||||
"- " + memory.strip() for memory in memories if memory.strip()
|
||||
)
|
||||
if formatted_user_context:
|
||||
system_prompt += formatted_user_context
|
||||
|
||||
# Append citation guidance after company context if placeholder was not present
|
||||
# This maintains backward compatibility and ensures citations are always enforced when needed
|
||||
|
||||
@@ -75,7 +75,7 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
# Auth Configs
|
||||
#####
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
@@ -900,9 +900,6 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||
|
||||
# Limit on number of users a free trial tenant can invite (cloud only)
|
||||
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))
|
||||
|
||||
# Security and authentication
|
||||
DATA_PLANE_SECRET = os.environ.get(
|
||||
"DATA_PLANE_SECRET", ""
|
||||
|
||||
@@ -158,17 +158,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -446,9 +435,6 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -65,7 +65,9 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
return credential.credential_json
|
||||
if credential.credential_json is None:
|
||||
return {}
|
||||
return credential.credential_json.get_value(apply_mask=False)
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
@@ -81,7 +83,7 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
credential.credential_json = credential_json
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -118,7 +118,12 @@ def instantiate_connector(
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
new_credentials = connector.load_credentials(credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
@@ -32,8 +32,6 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -47,13 +45,9 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=10,
|
||||
total=5,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -67,24 +61,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -128,8 +106,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -164,8 +142,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -216,8 +194,7 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
# in ~2 minutes
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -236,14 +213,11 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -168,14 +167,6 @@ class DocumentBase(BaseModel):
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
key: [str(item) for item in val] if isinstance(val, list) else str(val)
|
||||
for key, val in v.items()
|
||||
}
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
@@ -270,6 +270,8 @@ def create_credential(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -297,14 +299,21 @@ def alter_credential(
|
||||
|
||||
credential.name = name
|
||||
|
||||
# Assign a new dictionary to credential.credential_json
|
||||
credential.credential_json = {
|
||||
**credential.credential_json,
|
||||
# Get existing credential_json and merge with new values
|
||||
existing_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
credential.credential_json = { # type: ignore[assignment]
|
||||
**existing_json,
|
||||
**credential_json,
|
||||
}
|
||||
|
||||
credential.user_id = user.id
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -318,10 +327,12 @@ def update_credential(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.credential_json = credential_data.credential_json
|
||||
credential.user_id = user.id
|
||||
credential.credential_json = credential_data.credential_json # type: ignore[assignment]
|
||||
credential.user_id = user.id if user is not None else None
|
||||
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -335,8 +346,10 @@ def update_credential_json(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.credential_json = credential_json
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -346,7 +359,7 @@ def backend_update_credential_json(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""This should not be used in any flows involving the frontend or users"""
|
||||
credential.credential_json = credential_json
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -441,7 +454,12 @@ def create_initial_public_credential(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if first_credential is not None:
|
||||
if first_credential.credential_json != {} or first_credential.user is not None:
|
||||
credential_json_value = (
|
||||
first_credential.credential_json.get_value(apply_mask=False)
|
||||
if first_credential.credential_json
|
||||
else {}
|
||||
)
|
||||
if credential_json_value != {} or first_credential.user is not None:
|
||||
raise ValueError(error_msg)
|
||||
return
|
||||
|
||||
@@ -477,8 +495,13 @@ def delete_service_account_credentials(
|
||||
) -> None:
|
||||
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
|
||||
for credential in credentials:
|
||||
credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
if (
|
||||
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
and credential.source == source
|
||||
):
|
||||
db_session.delete(credential)
|
||||
|
||||
@@ -111,7 +111,7 @@ def update_federated_connector_oauth_token(
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
existing_token.token = token
|
||||
existing_token.token = token # type: ignore[assignment]
|
||||
existing_token.expires_at = expires_at
|
||||
db_session.commit()
|
||||
return existing_token
|
||||
@@ -267,7 +267,13 @@ def update_federated_connector(
|
||||
# Use provided credentials if updating them, otherwise use existing credentials
|
||||
# This is needed to instantiate the connector for config validation when only config is being updated
|
||||
creds_to_use = (
|
||||
credentials if credentials is not None else federated_connector.credentials
|
||||
credentials
|
||||
if credentials is not None
|
||||
else (
|
||||
federated_connector.credentials.get_value(apply_mask=False)
|
||||
if federated_connector.credentials
|
||||
else {}
|
||||
)
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
@@ -278,7 +284,7 @@ def update_federated_connector(
|
||||
raise ValueError(
|
||||
f"Invalid credentials for federated connector source: {federated_connector.source}"
|
||||
)
|
||||
federated_connector.credentials = credentials
|
||||
federated_connector.credentials = credentials # type: ignore[assignment]
|
||||
|
||||
if config is not None:
|
||||
# Validate config using connector-specific validation
|
||||
|
||||
@@ -109,38 +109,45 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
"""
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
@@ -225,7 +232,8 @@ def upsert_llm_provider(
|
||||
custom_config = custom_config or None
|
||||
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
|
||||
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
@@ -421,7 +429,7 @@ def fetch_existing_models(
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
flow_type_filter: list[LLMModelFlowType],
|
||||
flow_types: list[LLMModelFlowType],
|
||||
only_public: bool = False,
|
||||
exclude_image_generation_providers: bool = True,
|
||||
) -> list[LLMProviderModel]:
|
||||
@@ -429,27 +437,30 @@ def fetch_existing_llm_providers(
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
flow_type_filter: List of flow types to filter by, empty list for no filter
|
||||
flow_types: List of flow types to filter by
|
||||
only_public: If True, only return public providers
|
||||
exclude_image_generation_providers: If True, exclude providers that are
|
||||
used for image generation configs
|
||||
"""
|
||||
stmt = select(LLMProviderModel)
|
||||
|
||||
if flow_type_filter:
|
||||
providers_with_flows = (
|
||||
select(ModelConfiguration.llm_provider_id)
|
||||
.join(LLMModelFlow)
|
||||
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
|
||||
.distinct()
|
||||
)
|
||||
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
|
||||
providers_with_flows = (
|
||||
select(ModelConfiguration.llm_provider_id)
|
||||
.join(LLMModelFlow)
|
||||
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
|
||||
.distinct()
|
||||
)
|
||||
|
||||
if exclude_image_generation_providers:
|
||||
stmt = select(LLMProviderModel).where(
|
||||
LLMProviderModel.id.in_(providers_with_flows)
|
||||
)
|
||||
else:
|
||||
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
|
||||
ImageGenerationConfig
|
||||
)
|
||||
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
|
||||
stmt = select(LLMProviderModel).where(
|
||||
LLMProviderModel.id.in_(providers_with_flows)
|
||||
| LLMProviderModel.id.in_(image_gen_provider_ids)
|
||||
)
|
||||
|
||||
stmt = stmt.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
@@ -712,15 +723,13 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
else:
|
||||
# Add new model - all models from GitHub config are visible
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
new_model = ModelConfiguration(
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_config.name,
|
||||
supported_flows=[LLMModelFlowType.CHAT],
|
||||
is_visible=True,
|
||||
max_input_tokens=None,
|
||||
name=model_config.name,
|
||||
display_name=model_config.display_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(new_model)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -204,6 +205,21 @@ def remove_user_from_mcp_server(
|
||||
|
||||
|
||||
# MCPConnectionConfig operations
|
||||
def extract_connection_data(
|
||||
config: MCPConnectionConfig | None, apply_mask: bool = False
|
||||
) -> MCPConnectionData:
|
||||
"""Extract MCPConnectionData from a connection config, with proper typing.
|
||||
|
||||
This helper encapsulates the cast from the JSON column's dict[str, Any]
|
||||
to the typed MCPConnectionData structure.
|
||||
"""
|
||||
if config is None or config.config is None:
|
||||
return MCPConnectionData(headers={})
|
||||
if isinstance(config.config, SensitiveValue):
|
||||
return cast(MCPConnectionData, config.config.get_value(apply_mask=apply_mask))
|
||||
return cast(MCPConnectionData, config.config)
|
||||
|
||||
|
||||
def get_connection_config_by_id(
|
||||
config_id: int, db_session: Session
|
||||
) -> MCPConnectionConfig:
|
||||
@@ -269,7 +285,7 @@ def update_connection_config(
|
||||
config = get_connection_config_by_id(config_id, db_session)
|
||||
|
||||
if config_data is not None:
|
||||
config.config = config_data
|
||||
config.config = config_data # type: ignore[assignment]
|
||||
# Force SQLAlchemy to detect the change by marking the field as modified
|
||||
flag_modified(config, "config")
|
||||
|
||||
@@ -287,7 +303,7 @@ def upsert_user_connection_config(
|
||||
existing_config = get_user_connection_config(server_id, user_email, db_session)
|
||||
|
||||
if existing_config:
|
||||
existing_config.config = config_data
|
||||
existing_config.config = config_data # type: ignore[assignment]
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
return existing_config
|
||||
else:
|
||||
|
||||
@@ -1,22 +1,111 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> list[str]:
|
||||
class UserInfo(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"name": self.name,
|
||||
"role": self.role,
|
||||
"email": self.email,
|
||||
}
|
||||
|
||||
|
||||
class UserMemoryContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
user_info: UserInfo
|
||||
user_preferences: str | None = None
|
||||
memories: tuple[str, ...] = ()
|
||||
|
||||
def as_formatted_list(self) -> list[str]:
|
||||
"""Returns combined list of user info, preferences, and memories."""
|
||||
result = []
|
||||
if self.user_info.name:
|
||||
result.append(f"User's name: {self.user_info.name}")
|
||||
if self.user_info.role:
|
||||
result.append(f"User's role: {self.user_info.role}")
|
||||
if self.user_info.email:
|
||||
result.append(f"User's email: {self.user_info.email}")
|
||||
if self.user_preferences:
|
||||
result.append(f"User preferences: {self.user_preferences}")
|
||||
result.extend(self.memories)
|
||||
return result
|
||||
|
||||
def as_formatted_prompt(self) -> str:
|
||||
"""Returns structured prompt sections for the system prompt."""
|
||||
has_basic_info = (
|
||||
self.user_info.name or self.user_info.email or self.user_info.role
|
||||
)
|
||||
if not has_basic_info and not self.user_preferences and not self.memories:
|
||||
return ""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
|
||||
if self.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=self.user_info.name or "",
|
||||
user_email=self.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
|
||||
)
|
||||
|
||||
if self.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
return "".join(sections)
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
if not user.use_memories:
|
||||
return []
|
||||
return UserMemoryContext(user_info=UserInfo())
|
||||
|
||||
user_info = [
|
||||
f"User's name: {user.personal_name}" if user.personal_name else "",
|
||||
f"User's role: {user.personal_role}" if user.personal_role else "",
|
||||
f"User's email: {user.email}" if user.email else "",
|
||||
]
|
||||
user_info = UserInfo(
|
||||
name=user.personal_name,
|
||||
role=user.personal_role,
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
user_preferences = None
|
||||
if user.user_preferences:
|
||||
user_preferences = user.user_preferences
|
||||
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user.id)
|
||||
select(Memory).where(Memory.user_id == user.id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
|
||||
return user_info + memories
|
||||
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
|
||||
|
||||
return UserMemoryContext(
|
||||
user_info=user_info,
|
||||
user_preferences=user_preferences,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
@@ -95,10 +95,10 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
@@ -122,18 +122,35 @@ class EncryptedString(TypeDecorator):
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | None, dialect: Dialect # noqa: ARG002
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw strings and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
return encrypt_string_to_bytes(value)
|
||||
return value
|
||||
|
||||
def process_result_value(
|
||||
self, value: bytes | None, dialect: Dialect # noqa: ARG002
|
||||
) -> str | None:
|
||||
) -> SensitiveValue[str] | None:
|
||||
if value is not None:
|
||||
return decrypt_bytes_to_string(value)
|
||||
return value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=value,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=False,
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
@@ -142,20 +159,38 @@ class EncryptedJson(TypeDecorator):
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self, value: dict | None, dialect: Dialect # noqa: ARG002
|
||||
self,
|
||||
value: dict[str, Any] | SensitiveValue[dict[str, Any]] | None,
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
return value
|
||||
|
||||
def process_result_value(
|
||||
self, value: bytes | None, dialect: Dialect # noqa: ARG002
|
||||
) -> dict | None:
|
||||
) -> SensitiveValue[dict[str, Any]] | None:
|
||||
if value is not None:
|
||||
json_str = decrypt_bytes_to_string(value)
|
||||
return json.loads(json_str)
|
||||
return value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=value,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -216,6 +251,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
@@ -1755,7 +1791,9 @@ class Credential(Base):
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
credential_json: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson()
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
@@ -1793,7 +1831,9 @@ class FederatedConnector(Base):
|
||||
source: Mapped[FederatedConnectorSource] = mapped_column(
|
||||
Enum(FederatedConnectorSource, native_enum=False)
|
||||
)
|
||||
credentials: Mapped[dict[str, str]] = mapped_column(EncryptedJson(), nullable=False)
|
||||
credentials: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False
|
||||
)
|
||||
config: Mapped[dict[str, Any]] = mapped_column(
|
||||
postgresql.JSONB(), default=dict, nullable=False, server_default="{}"
|
||||
)
|
||||
@@ -1820,7 +1860,9 @@ class FederatedConnectorOAuthToken(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
@@ -1964,7 +2006,9 @@ class SearchSettings(Base):
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
if self.cloud_provider is None or self.cloud_provider.api_key is None:
|
||||
return None
|
||||
return self.cloud_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
@property
|
||||
def large_chunks_enabled(self) -> bool:
|
||||
@@ -2726,7 +2770,9 @@ class LLMProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider: Mapped[str] = mapped_column(String)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# custom configs that should be passed to the LLM provider at inference time
|
||||
@@ -2879,7 +2925,7 @@ class CloudEmbeddingProvider(Base):
|
||||
Enum(EmbeddingProvider), primary_key=True
|
||||
)
|
||||
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(EncryptedString())
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
@@ -2898,7 +2944,9 @@ class InternetSearchProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
@@ -2920,7 +2968,9 @@ class InternetContentProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
config: Mapped[WebContentProviderConfig | None] = mapped_column(
|
||||
PydanticType(WebContentProviderConfig), nullable=True
|
||||
)
|
||||
@@ -3064,8 +3114,12 @@ class OAuthConfig(Base):
|
||||
token_url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
# Client credentials (encrypted)
|
||||
client_id: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
client_id: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
client_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
|
||||
# Optional configurations
|
||||
scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
@@ -3112,7 +3166,9 @@ class OAuthUserToken(Base):
|
||||
# "expires_at": 1234567890, # Unix timestamp, optional
|
||||
# "scope": "repo user" # Optional
|
||||
# }
|
||||
token_data: Mapped[dict[str, Any]] = mapped_column(EncryptedJson(), nullable=False)
|
||||
token_data: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False
|
||||
)
|
||||
|
||||
# Metadata
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -3445,9 +3501,15 @@ class SlackBot(Base):
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), unique=True
|
||||
)
|
||||
app_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), unique=True
|
||||
)
|
||||
user_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
|
||||
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
|
||||
"SlackChannelConfig",
|
||||
@@ -3468,7 +3530,9 @@ class DiscordBotConfig(Base):
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'SINGLETON'")
|
||||
)
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
@@ -3624,7 +3688,9 @@ class KVStore(Base):
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
encrypted_value: Mapped[JSON_ro] = mapped_column(EncryptedJson(), nullable=True)
|
||||
encrypted_value: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
@@ -4344,7 +4410,7 @@ class MCPConnectionConfig(Base):
|
||||
# "registration_access_token": "<token>", # For managing registration
|
||||
# "registration_client_uri": "<uri>", # For managing registration
|
||||
# }
|
||||
config: Mapped[MCPConnectionData] = mapped_column(
|
||||
config: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False, default=dict
|
||||
)
|
||||
|
||||
|
||||
@@ -87,13 +87,13 @@ def update_oauth_config(
|
||||
if token_url is not None:
|
||||
oauth_config.token_url = token_url
|
||||
if clear_client_id:
|
||||
oauth_config.client_id = ""
|
||||
oauth_config.client_id = "" # type: ignore[assignment]
|
||||
elif client_id is not None:
|
||||
oauth_config.client_id = client_id
|
||||
oauth_config.client_id = client_id # type: ignore[assignment]
|
||||
if clear_client_secret:
|
||||
oauth_config.client_secret = ""
|
||||
oauth_config.client_secret = "" # type: ignore[assignment]
|
||||
elif client_secret is not None:
|
||||
oauth_config.client_secret = client_secret
|
||||
oauth_config.client_secret = client_secret # type: ignore[assignment]
|
||||
if scopes is not None:
|
||||
oauth_config.scopes = scopes
|
||||
if additional_params is not None:
|
||||
@@ -154,7 +154,7 @@ def upsert_user_oauth_token(
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
existing_token.token_data = token_data
|
||||
existing_token.token_data = token_data # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
return existing_token
|
||||
else:
|
||||
|
||||
@@ -43,9 +43,9 @@ def update_slack_bot(
|
||||
# update the app
|
||||
slack_bot.name = name
|
||||
slack_bot.enabled = enabled
|
||||
slack_bot.bot_token = bot_token
|
||||
slack_bot.app_token = app_token
|
||||
slack_bot.user_token = user_token
|
||||
slack_bot.bot_token = bot_token # type: ignore[assignment]
|
||||
slack_bot.app_token = app_token # type: ignore[assignment]
|
||||
slack_bot.user_token = user_token # type: ignore[assignment]
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ def update_user_personalization(
|
||||
personal_role: str | None,
|
||||
use_memories: bool,
|
||||
memories: list[str],
|
||||
user_preferences: str | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
@@ -169,6 +170,7 @@ def update_user_personalization(
|
||||
personal_name=personal_name,
|
||||
personal_role=personal_role,
|
||||
use_memories=use_memories,
|
||||
user_preferences=user_preferences,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -73,7 +73,8 @@ def _apply_search_provider_updates(
|
||||
provider.provider_type = provider_type.value
|
||||
provider.config = config
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
|
||||
def upsert_web_search_provider(
|
||||
@@ -228,7 +229,8 @@ def _apply_content_provider_updates(
|
||||
provider.provider_type = provider_type.value
|
||||
provider.config = config
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
|
||||
def upsert_web_content_provider(
|
||||
|
||||
@@ -119,7 +119,16 @@ def get_federated_retrieval_functions(
|
||||
federated_retrieval_infos_slack = []
|
||||
|
||||
# Use user_token if available, otherwise fall back to bot_token
|
||||
access_token = tenant_slack_bot.user_token or tenant_slack_bot.bot_token
|
||||
# Unwrap SensitiveValue for backend API calls
|
||||
access_token = (
|
||||
tenant_slack_bot.user_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.user_token
|
||||
else (
|
||||
tenant_slack_bot.bot_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.bot_token
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if not tenant_slack_bot.user_token:
|
||||
logger.warning(
|
||||
f"Using bot_token for Slack search (limited functionality): {tenant_slack_bot.name}"
|
||||
@@ -138,7 +147,12 @@ def get_federated_retrieval_functions(
|
||||
)
|
||||
|
||||
# Capture variables by value to avoid lambda closure issues
|
||||
bot_token = tenant_slack_bot.bot_token
|
||||
# Unwrap SensitiveValue for backend API calls
|
||||
bot_token = (
|
||||
tenant_slack_bot.bot_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.bot_token
|
||||
else ""
|
||||
)
|
||||
|
||||
# Use connector config for channel filtering (guaranteed to exist at this point)
|
||||
connector_entities = slack_federated_connector_config
|
||||
@@ -252,11 +266,11 @@ def get_federated_retrieval_functions(
|
||||
|
||||
connector = get_federated_connector(
|
||||
oauth_token.federated_connector.source,
|
||||
oauth_token.federated_connector.credentials,
|
||||
oauth_token.federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
|
||||
# Capture variables by value to avoid lambda closure issues
|
||||
access_token = oauth_token.token
|
||||
access_token = oauth_token.token.get_value(apply_mask=False)
|
||||
|
||||
def create_retrieval_function(
|
||||
conn: FederatedConnector,
|
||||
|
||||
@@ -43,7 +43,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
obj = db_session.query(KVStore).filter_by(key=key).first()
|
||||
if obj:
|
||||
obj.value = plain_val
|
||||
obj.encrypted_value = encrypted_val
|
||||
obj.encrypted_value = encrypted_val # type: ignore[assignment]
|
||||
else:
|
||||
obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val)
|
||||
db_session.query(KVStore).filter_by(key=key).delete() # just in case
|
||||
@@ -73,7 +73,8 @@ class PgRedisKVStore(KeyValueStore):
|
||||
if obj.value is not None:
|
||||
value = obj.value
|
||||
elif obj.encrypted_value is not None:
|
||||
value = obj.encrypted_value
|
||||
# Unwrap SensitiveValue - this is internal backend use
|
||||
value = obj.encrypted_value.get_value(apply_mask=False)
|
||||
else:
|
||||
value = None
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -44,11 +48,13 @@ from onyx.llm.well_known_providers.constants import (
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
|
||||
from onyx.server.utils import mask_string
|
||||
from onyx.utils.encryption import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_env_lock = threading.Lock()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper
|
||||
from litellm import HTTPHandler
|
||||
@@ -378,23 +384,29 @@ class LitellmLLM(LLM):
|
||||
if "api_key" not in passthrough_kwargs:
|
||||
passthrough_kwargs["api_key"] = self._api_key or None
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
env_ctx = (
|
||||
temporary_env_and_lock(self._custom_config)
|
||||
if self._custom_config
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
# for break pointing
|
||||
@@ -475,22 +487,53 @@ class LitellmLLM(LLM):
|
||||
client = HTTPHandler(timeout=timeout_override or self._timeout)
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
if self._custom_config:
|
||||
# When custom_config is set, env vars are temporarily injected
|
||||
# under a global lock. Using stream=True here means the lock is
|
||||
# only held during connection setup (not the full inference).
|
||||
# The chunks are then collected outside the lock and reassembled
|
||||
# into a single ModelResponse via stream_chunk_builder.
|
||||
from litellm import stream_chunk_builder
|
||||
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
|
||||
|
||||
stream_response = cast(
|
||||
LiteLLMCustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
chunks = list(stream_response)
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
stream_chunk_builder(chunks),
|
||||
)
|
||||
else:
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
|
||||
model_response = from_litellm_model_response(response)
|
||||
|
||||
@@ -581,3 +624,29 @@ class LitellmLLM(LLM):
|
||||
finally:
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
|
||||
"""
|
||||
Temporarily sets the environment variables to the given values.
|
||||
Code path is locked while the environment variables are set.
|
||||
Then cleans up the environment and frees the lock.
|
||||
"""
|
||||
with _env_lock:
|
||||
logger.debug("Acquired lock in temporary_env_and_lock")
|
||||
# Store original values (None if key didn't exist)
|
||||
original_values: dict[str, str | None] = {
|
||||
key: os.environ.get(key) for key in env_variables
|
||||
}
|
||||
try:
|
||||
os.environ.update(env_variables)
|
||||
yield
|
||||
finally:
|
||||
for key, original_value in original_values.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None) # Remove if it didn't exist before
|
||||
else:
|
||||
os.environ[key] = original_value # Restore original value
|
||||
|
||||
logger.debug("Released lock in temporary_env_and_lock")
|
||||
|
||||
@@ -4,6 +4,7 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.db.discord_bot import get_discord_bot_config
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -36,4 +37,8 @@ def get_bot_token() -> str | None:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get bot token from database: {e}")
|
||||
return None
|
||||
return config.bot_token if config else None
|
||||
if config and config.bot_token:
|
||||
if isinstance(config.bot_token, SensitiveValue):
|
||||
return config.bot_token.get_value(apply_mask=False)
|
||||
return config.bot_token
|
||||
return None
|
||||
|
||||
@@ -592,8 +592,11 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -605,6 +608,7 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
|
||||
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
|
||||
@@ -1,163 +1,29 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
result = md(normalized_message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
return result
|
||||
|
||||
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
return f"*{text}*\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
@@ -176,7 +42,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines) + "\n"
|
||||
return "\n".join(lines)
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
@@ -198,73 +64,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
return f"```\n{code}\n```\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n\n"
|
||||
return f"{text}\n"
|
||||
|
||||
@@ -18,18 +18,15 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.models import ThreadMessage
|
||||
from onyx.onyxbot.slack.utils import get_channel_from_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
@@ -44,51 +41,6 @@ srl = SlackRateLimiter()
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def resolve_channel_references(
|
||||
message: str,
|
||||
client: WebClient,
|
||||
logger: OnyxLoggingAdapter,
|
||||
) -> tuple[str, list[Tag]]:
|
||||
"""Parse Slack channel references from a message, resolve IDs to names,
|
||||
replace the raw markup with readable #channel-name, and return channel tags
|
||||
for search filtering."""
|
||||
tags: list[Tag] = []
|
||||
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
|
||||
seen_channel_ids: set[str] = set()
|
||||
|
||||
for channel_id, channel_name_from_markup in channel_matches:
|
||||
if channel_id in seen_channel_ids:
|
||||
continue
|
||||
seen_channel_ids.add(channel_id)
|
||||
|
||||
channel_name = channel_name_from_markup or None
|
||||
|
||||
if not channel_name:
|
||||
try:
|
||||
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
|
||||
channel_name = channel_info.get("name") or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
|
||||
|
||||
if not channel_name:
|
||||
continue
|
||||
|
||||
# Replace raw Slack markup with readable channel name
|
||||
if channel_name_from_markup:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name_from_markup}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
|
||||
|
||||
return message, tags
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
@@ -205,20 +157,6 @@ def handle_regular_answer(
|
||||
user_message = messages[-1]
|
||||
history_messages = messages[:-1]
|
||||
|
||||
# Resolve any <#CHANNEL_ID> references in the user message to readable
|
||||
# channel names and extract channel tags for search filtering
|
||||
resolved_message, channel_tags = resolve_channel_references(
|
||||
message=user_message.message,
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
user_message = ThreadMessage(
|
||||
message=resolved_message,
|
||||
sender=user_message.sender,
|
||||
role=user_message.role,
|
||||
)
|
||||
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client,
|
||||
channel_id=channel,
|
||||
@@ -269,7 +207,6 @@ def handle_regular_answer(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
tags=channel_tags if channel_tags else None,
|
||||
)
|
||||
|
||||
new_message_request = SendMessageRequest(
|
||||
@@ -294,16 +231,6 @@ def handle_regular_answer(
|
||||
slack_context_str=slack_context_str,
|
||||
)
|
||||
|
||||
# If a channel filter was applied but no results were found, override
|
||||
# the LLM response to avoid hallucinated answers about unindexed channels
|
||||
if channel_tags and not answer.citation_info and not answer.top_documents:
|
||||
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
|
||||
answer.answer = (
|
||||
f"No indexed data found for {channel_names}. "
|
||||
"This channel may not be indexed, or there may be no messages "
|
||||
"matching your query within it."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -358,7 +285,6 @@ def handle_regular_answer(
|
||||
only_respond_if_citations
|
||||
and not answer.citation_info
|
||||
and not message_info.bypass_filters
|
||||
and not channel_tags
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
|
||||
@@ -216,14 +216,10 @@ class SlackbotHandler:
|
||||
- If the tokens have changed, close the existing socket client and reconnect.
|
||||
- If the tokens are new, warm up the model and start a new socket client.
|
||||
"""
|
||||
slack_bot_tokens = SlackBotTokens(
|
||||
bot_token=bot.bot_token,
|
||||
app_token=bot.app_token,
|
||||
)
|
||||
tenant_bot_pair = (tenant_id, bot.id)
|
||||
|
||||
# If the tokens are missing or empty, close the socket client and remove them.
|
||||
if not slack_bot_tokens:
|
||||
if not bot.bot_token or not bot.app_token:
|
||||
logger.debug(
|
||||
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
|
||||
)
|
||||
@@ -233,6 +229,11 @@ class SlackbotHandler:
|
||||
del self.slack_bot_tokens[tenant_bot_pair]
|
||||
return
|
||||
|
||||
slack_bot_tokens = SlackBotTokens(
|
||||
bot_token=bot.bot_token.get_value(apply_mask=False),
|
||||
app_token=bot.app_token.get_value(apply_mask=False),
|
||||
)
|
||||
|
||||
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
|
||||
tokens_changed = (
|
||||
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
|
||||
|
||||
@@ -25,9 +25,6 @@ You can use Markdown tables to format your responses for data, lists, and other
|
||||
""".lstrip()
|
||||
|
||||
|
||||
# Section for information about the user if provided such as their name, role, memories, etc.
|
||||
USER_INFO_HEADER = "\n\n# User Information\n"
|
||||
|
||||
COMPANY_NAME_BLOCK = """
|
||||
The user is at an organization called `{company_name}`.
|
||||
"""
|
||||
|
||||
@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -92,7 +92,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
@@ -404,12 +403,13 @@ def check_drive_tokens(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AuthStatus:
|
||||
db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if (
|
||||
not db_credentials
|
||||
or DB_CREDENTIALS_DICT_TOKEN_KEY not in db_credentials.credential_json
|
||||
):
|
||||
if not db_credentials or not db_credentials.credential_json:
|
||||
return AuthStatus(authenticated=False)
|
||||
token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
|
||||
credential_json = db_credentials.credential_json.get_value(apply_mask=False)
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY not in credential_json:
|
||||
return AuthStatus(authenticated=False)
|
||||
token_json_str = str(credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
google_drive_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
@@ -557,43 +557,6 @@ def _normalize_file_names_for_backwards_compatibility(
|
||||
return file_names + file_locations[len(file_names) :]
|
||||
|
||||
|
||||
def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
require_editable: bool,
|
||||
) -> ConnectorCredentialPair:
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
has_requested_access = verify_user_has_access_to_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=require_editable,
|
||||
)
|
||||
if has_requested_access:
|
||||
return cc_pair
|
||||
|
||||
# Special case: global curators should be able to manage files
|
||||
# for public file connectors even when they are not the creator.
|
||||
if (
|
||||
require_editable
|
||||
and user.role == UserRole.GLOBAL_CURATOR
|
||||
and cc_pair.access_type == AccessType.PUBLIC
|
||||
):
|
||||
return cc_pair
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied. User cannot manage files for this connector.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
@@ -605,7 +568,7 @@ def upload_files_api(
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
def list_connector_files(
|
||||
connector_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorFilesResponse:
|
||||
"""List all files in a file connector."""
|
||||
@@ -618,13 +581,6 @@ def list_connector_files(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
_ = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=False,
|
||||
)
|
||||
|
||||
file_locations = connector.connector_specific_config.get("file_locations", [])
|
||||
file_names = connector.connector_specific_config.get("file_names", [])
|
||||
|
||||
@@ -674,7 +630,7 @@ def update_connector_files(
|
||||
connector_id: int,
|
||||
files: list[UploadFile] | None = File(None),
|
||||
file_ids_to_remove: str = Form("[]"),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
@@ -692,13 +648,12 @@ def update_connector_files(
|
||||
)
|
||||
|
||||
# Get the connector-credential pair for indexing/pruning triggers
|
||||
# and validate user permissions for file management.
|
||||
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=True,
|
||||
)
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
# Parse file IDs to remove
|
||||
try:
|
||||
|
||||
@@ -346,10 +346,17 @@ def update_credential_from_model(
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
# Get credential_json value - use masking for API responses
|
||||
credential_json_value = (
|
||||
updated_credential.credential_json.get_value(apply_mask=True)
|
||||
if updated_credential.credential_json
|
||||
else {}
|
||||
)
|
||||
|
||||
return CredentialSnapshot(
|
||||
source=updated_credential.source,
|
||||
id=updated_credential.id,
|
||||
credential_json=updated_credential.credential_json,
|
||||
credential_json=credential_json_value,
|
||||
user_id=updated_credential.user_id,
|
||||
name=updated_credential.name,
|
||||
admin_public=updated_credential.admin_public,
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import TaskStatus
|
||||
from onyx.server.federated.models import FederatedConnectorStatus
|
||||
from onyx.server.utils import mask_credential_dict
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
@@ -145,13 +144,21 @@ class CredentialSnapshot(CredentialBase):
|
||||
|
||||
@classmethod
|
||||
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
|
||||
# Get the credential_json value with appropriate masking
|
||||
if credential.credential_json is None:
|
||||
credential_json_value: dict[str, Any] = {}
|
||||
elif MASK_CREDENTIAL_PREFIX:
|
||||
credential_json_value = credential.credential_json.get_value(
|
||||
apply_mask=True
|
||||
)
|
||||
else:
|
||||
credential_json_value = credential.credential_json.get_value(
|
||||
apply_mask=False
|
||||
)
|
||||
|
||||
return CredentialSnapshot(
|
||||
id=credential.id,
|
||||
credential_json=(
|
||||
mask_credential_dict(credential.credential_json)
|
||||
if MASK_CREDENTIAL_PREFIX and credential.credential_json
|
||||
else credential.credential_json
|
||||
),
|
||||
credential_json=credential_json_value,
|
||||
user_id=credential.user_id,
|
||||
user_email=credential.user.email if credential.user else None,
|
||||
admin_public=credential.admin_public,
|
||||
|
||||
@@ -88,7 +88,7 @@ SANDBOX_NAMESPACE = os.environ.get("SANDBOX_NAMESPACE", "onyx-sandboxes")
|
||||
# Container image for sandbox pods
|
||||
# Should include Next.js template, opencode CLI, and demo_data zip
|
||||
SANDBOX_CONTAINER_IMAGE = os.environ.get(
|
||||
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.2"
|
||||
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.3"
|
||||
)
|
||||
|
||||
# S3 bucket for sandbox file storage (snapshots, knowledge files, uploads)
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
# Sandbox Container Image
|
||||
|
||||
This directory contains the Dockerfile and resources for building the Onyx Craft sandbox container image.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
docker/
|
||||
├── Dockerfile # Main container image definition
|
||||
├── demo_data.zip # Demo data (extracted to /workspace/demo_data)
|
||||
├── templates/
|
||||
│ └── outputs/ # Web app scaffold template (Next.js)
|
||||
├── initial-requirements.txt # Python packages pre-installed in sandbox
|
||||
├── generate_agents_md.py # Script to generate AGENTS.md for sessions
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Building the Image
|
||||
|
||||
The sandbox image must be built for **amd64** architecture since our Kubernetes cluster runs on x86_64 nodes.
|
||||
|
||||
### Build for amd64 only (fastest)
|
||||
|
||||
```bash
|
||||
cd backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
docker build --platform linux/amd64 -t onyxdotapp/sandbox:v0.1.x .
|
||||
docker push onyxdotapp/sandbox:v0.1.x
|
||||
```
|
||||
|
||||
### Build multi-arch (recommended for flexibility)
|
||||
|
||||
```bash
|
||||
docker buildx build --platform linux/amd64,linux/arm64 \
|
||||
-t onyxdotapp/sandbox:v0.1.x \
|
||||
--push .
|
||||
```
|
||||
|
||||
### Update the `latest` tag
|
||||
|
||||
After pushing a versioned tag, update `latest`:
|
||||
|
||||
```bash
|
||||
docker tag onyxdotapp/sandbox:v0.1.x onyxdotapp/sandbox:latest
|
||||
docker push onyxdotapp/sandbox:latest
|
||||
```
|
||||
|
||||
Or with buildx:
|
||||
|
||||
```bash
|
||||
docker buildx build --platform linux/amd64,linux/arm64 \
|
||||
-t onyxdotapp/sandbox:v0.1.x \
|
||||
-t onyxdotapp/sandbox:latest \
|
||||
--push .
|
||||
```
|
||||
|
||||
## Deploying a New Version
|
||||
|
||||
1. **Build and push** the new image (see above)
|
||||
|
||||
2. **Update the ConfigMap** in `cloud-deployment-yamls/danswer/configmap/env-configmap.yaml`:
|
||||
```yaml
|
||||
SANDBOX_CONTAINER_IMAGE: "onyxdotapp/sandbox:v0.1.x"
|
||||
```
|
||||
|
||||
3. **Apply the ConfigMap**:
|
||||
```bash
|
||||
kubectl apply -f configmap/env-configmap.yaml
|
||||
```
|
||||
|
||||
4. **Restart the API server** to pick up the new config:
|
||||
```bash
|
||||
kubectl rollout restart deployment/api-server -n danswer
|
||||
```
|
||||
|
||||
5. **Delete existing sandbox pods** (they will be recreated with the new image):
|
||||
```bash
|
||||
kubectl delete pods -n onyx-sandboxes -l app.kubernetes.io/component=sandbox
|
||||
```
|
||||
|
||||
## What's Baked Into the Image
|
||||
|
||||
- **Base**: `node:20-slim` (Debian-based)
|
||||
- **Demo data**: `/workspace/demo_data/` - sample files for demo sessions
|
||||
- **Templates**: `/workspace/templates/outputs/` - Next.js web app scaffold
|
||||
- **Python venv**: `/workspace/.venv/` with packages from `initial-requirements.txt`
|
||||
- **OpenCode CLI**: Installed in `/home/sandbox/.opencode/bin/`
|
||||
|
||||
## Runtime Directory Structure
|
||||
|
||||
When a session is created, the following structure is set up in the pod:
|
||||
|
||||
```
|
||||
/workspace/
|
||||
├── demo_data/ # Baked into image
|
||||
├── files/ # Mounted volume, synced from S3
|
||||
├── templates/ # Baked into image
|
||||
└── sessions/
|
||||
└── $session_id/
|
||||
├── files/ # Symlink to /workspace/demo_data or /workspace/files
|
||||
├── outputs/ # Copied from templates, contains web app
|
||||
├── attachments/ # User-uploaded files
|
||||
├── org_info/ # Demo persona info (if demo mode)
|
||||
├── AGENTS.md # Instructions for the AI agent
|
||||
└── opencode.json # OpenCode configuration
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Verify image exists on Docker Hub
|
||||
|
||||
```bash
|
||||
curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags" | jq '.results[].name'
|
||||
```
|
||||
|
||||
### Check what image a pod is using
|
||||
|
||||
```bash
|
||||
kubectl get pod <pod-name> -n onyx-sandboxes -o jsonpath='{.spec.containers[?(@.name=="sandbox")].image}'
|
||||
```
|
||||
Binary file not shown.
@@ -349,7 +349,11 @@ class SessionManager:
|
||||
return LLMProviderConfig(
|
||||
provider=default_model.llm_provider.provider,
|
||||
model_name=default_model.name,
|
||||
api_key=default_model.llm_provider.api_key,
|
||||
api_key=(
|
||||
default_model.llm_provider.api_key.get_value(apply_mask=False)
|
||||
if default_model.llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=default_model.llm_provider.api_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from onyx.db.mcp import delete_all_user_connection_configs_for_server_no_commit
|
||||
from onyx.db.mcp import delete_connection_config
|
||||
from onyx.db.mcp import delete_mcp_server
|
||||
from onyx.db.mcp import delete_user_connection_configs_for_server
|
||||
from onyx.db.mcp import extract_connection_data
|
||||
from onyx.db.mcp import get_all_mcp_servers
|
||||
from onyx.db.mcp import get_connection_config_by_id
|
||||
from onyx.db.mcp import get_mcp_server_by_id
|
||||
@@ -79,6 +80,7 @@ from onyx.server.features.tool.models import ToolSnapshot
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import discover_mcp_tools
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import initialize_mcp_client
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import log_exception_group
|
||||
from onyx.utils.encryption import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -143,7 +145,8 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
tokens_raw = config.config.get(MCPOAuthKeys.TOKENS.value)
|
||||
config_data = extract_connection_data(config)
|
||||
tokens_raw = config_data.get(MCPOAuthKeys.TOKENS.value)
|
||||
if tokens_raw:
|
||||
return OAuthToken.model_validate(tokens_raw)
|
||||
return None
|
||||
@@ -151,14 +154,14 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
config.config[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
|
||||
cfg_headers = {
|
||||
config_data = extract_connection_data(config)
|
||||
config_data[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
|
||||
config_data["headers"] = {
|
||||
"Authorization": f"{tokens.token_type} {tokens.access_token}"
|
||||
}
|
||||
config.config["headers"] = cfg_headers
|
||||
update_connection_config(config.id, db_session, config.config)
|
||||
update_connection_config(config.id, db_session, config_data)
|
||||
if self.alt_config_id:
|
||||
update_connection_config(self.alt_config_id, db_session, config.config)
|
||||
update_connection_config(self.alt_config_id, db_session, config_data)
|
||||
|
||||
# signal the oauth callback that token exchange is complete
|
||||
r = get_redis_client()
|
||||
@@ -168,19 +171,21 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
client_info_raw = config.config.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
config_data = extract_connection_data(config)
|
||||
client_info_raw = config_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if client_info_raw:
|
||||
return OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
if self.alt_config_id:
|
||||
alt_config = get_connection_config_by_id(self.alt_config_id, db_session)
|
||||
if alt_config:
|
||||
alt_client_info = alt_config.config.get(
|
||||
alt_config_data = extract_connection_data(alt_config)
|
||||
alt_client_info = alt_config_data.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
if alt_client_info:
|
||||
# Cache the admin client info on the user config for future calls
|
||||
config.config[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
|
||||
update_connection_config(config.id, db_session, config.config)
|
||||
config_data[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
|
||||
update_connection_config(config.id, db_session, config_data)
|
||||
return OAuthClientInformationFull.model_validate(
|
||||
alt_client_info
|
||||
)
|
||||
@@ -189,10 +194,11 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def set_client_info(self, info: OAuthClientInformationFull) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
config.config[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
|
||||
update_connection_config(config.id, db_session, config.config)
|
||||
config_data = extract_connection_data(config)
|
||||
config_data[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
|
||||
update_connection_config(config.id, db_session, config_data)
|
||||
if self.alt_config_id:
|
||||
update_connection_config(self.alt_config_id, db_session, config.config)
|
||||
update_connection_config(self.alt_config_id, db_session, config_data)
|
||||
|
||||
|
||||
def make_oauth_provider(
|
||||
@@ -436,9 +442,12 @@ async def _connect_oauth(
|
||||
|
||||
db.commit()
|
||||
|
||||
connection_config_dict = extract_connection_data(
|
||||
connection_config, apply_mask=False
|
||||
)
|
||||
is_connected = (
|
||||
MCPOAuthKeys.CLIENT_INFO.value in connection_config.config
|
||||
and connection_config.config.get("headers")
|
||||
MCPOAuthKeys.CLIENT_INFO.value in connection_config_dict
|
||||
and connection_config_dict.get("headers")
|
||||
)
|
||||
# Step 1: make unauthenticated request and parse returned www authenticate header
|
||||
# Ensure we have a trailing slash for the MCP endpoint
|
||||
@@ -471,7 +480,7 @@ async def _connect_oauth(
|
||||
try:
|
||||
x = await initialize_mcp_client(
|
||||
probe_url,
|
||||
connection_headers=connection_config.config.get("headers", {}),
|
||||
connection_headers=connection_config_dict.get("headers", {}),
|
||||
transport=transport,
|
||||
auth=oauth_auth,
|
||||
)
|
||||
@@ -684,15 +693,18 @@ def save_user_credentials(
|
||||
# Use template to create the full connection config
|
||||
try:
|
||||
# TODO: fix and/or type correctly w/base model
|
||||
auth_template_dict = extract_connection_data(
|
||||
auth_template, apply_mask=False
|
||||
)
|
||||
config_data = MCPConnectionData(
|
||||
headers=auth_template.config.get("headers", {}),
|
||||
headers=auth_template_dict.get("headers", {}),
|
||||
header_substitutions=request.credentials,
|
||||
)
|
||||
for oauth_field_key in MCPOAuthKeys:
|
||||
field_key: Literal["client_info", "tokens", "metadata"] = (
|
||||
oauth_field_key.value
|
||||
)
|
||||
if field_val := auth_template.config.get(field_key):
|
||||
if field_val := auth_template_dict.get(field_key):
|
||||
config_data[field_key] = field_val
|
||||
|
||||
except Exception as e:
|
||||
@@ -839,18 +851,20 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
and db_server.admin_connection_config is not None
|
||||
and include_auth_config
|
||||
):
|
||||
admin_config_dict = extract_connection_data(
|
||||
db_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
if db_server.auth_type == MCPAuthenticationType.API_TOKEN:
|
||||
raw_api_key = admin_config_dict["headers"]["Authorization"].split(" ")[
|
||||
-1
|
||||
]
|
||||
admin_credentials = {
|
||||
"api_key": db_server.admin_connection_config.config["headers"][
|
||||
"Authorization"
|
||||
].split(" ")[-1]
|
||||
"api_key": mask_string(raw_api_key),
|
||||
}
|
||||
elif db_server.auth_type == MCPAuthenticationType.OAUTH:
|
||||
user_authenticated = False
|
||||
client_info = None
|
||||
client_info_raw = db_server.admin_connection_config.config.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
client_info_raw = admin_config_dict.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(
|
||||
client_info_raw
|
||||
@@ -861,8 +875,8 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
"Stored client info had empty client ID or secret"
|
||||
)
|
||||
admin_credentials = {
|
||||
"client_id": client_info.client_id,
|
||||
"client_secret": client_info.client_secret,
|
||||
"client_id": mask_string(client_info.client_id),
|
||||
"client_secret": mask_string(client_info.client_secret),
|
||||
}
|
||||
else:
|
||||
admin_credentials = {}
|
||||
@@ -879,14 +893,18 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
include_auth_config
|
||||
and db_server.auth_type != MCPAuthenticationType.OAUTH
|
||||
):
|
||||
user_credentials = user_config.config.get(HEADER_SUBSTITUTIONS, {})
|
||||
user_config_dict = extract_connection_data(user_config, apply_mask=True)
|
||||
user_credentials = user_config_dict.get(HEADER_SUBSTITUTIONS, {})
|
||||
|
||||
if (
|
||||
db_server.auth_type == MCPAuthenticationType.OAUTH
|
||||
and db_server.admin_connection_config
|
||||
):
|
||||
client_info = None
|
||||
client_info_raw = db_server.admin_connection_config.config.get(
|
||||
oauth_admin_config_dict = extract_connection_data(
|
||||
db_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
client_info_raw = oauth_admin_config_dict.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
if client_info_raw:
|
||||
@@ -896,8 +914,8 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
raise ValueError("Stored client info had empty client ID or secret")
|
||||
if can_view_admin_credentials:
|
||||
admin_credentials = {
|
||||
"client_id": client_info.client_id,
|
||||
"client_secret": client_info.client_secret,
|
||||
"client_id": mask_string(client_info.client_id),
|
||||
"client_secret": mask_string(client_info.client_secret),
|
||||
}
|
||||
elif can_view_admin_credentials:
|
||||
admin_credentials = {}
|
||||
@@ -909,7 +927,10 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
try:
|
||||
template_config = db_server.admin_connection_config
|
||||
if template_config:
|
||||
headers = template_config.config.get("headers", {})
|
||||
template_config_dict = extract_connection_data(
|
||||
template_config, apply_mask=False
|
||||
)
|
||||
headers = template_config_dict.get("headers", {})
|
||||
auth_template = MCPAuthTemplate(
|
||||
headers=headers,
|
||||
required_fields=[], # would need to regex, not worth it
|
||||
@@ -1232,7 +1253,10 @@ def _list_mcp_tools_by_id(
|
||||
)
|
||||
|
||||
if connection_config:
|
||||
headers.update(connection_config.config.get("headers", {}))
|
||||
connection_config_dict = extract_connection_data(
|
||||
connection_config, apply_mask=False
|
||||
)
|
||||
headers.update(connection_config_dict.get("headers", {}))
|
||||
|
||||
import time
|
||||
|
||||
@@ -1320,7 +1344,10 @@ def _upsert_mcp_server(
|
||||
_ensure_mcp_server_owner_or_admin(mcp_server, user)
|
||||
client_info = None
|
||||
if mcp_server.admin_connection_config:
|
||||
client_info_raw = mcp_server.admin_connection_config.config.get(
|
||||
existing_admin_config_dict = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
client_info_raw = existing_admin_config_dict.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
if client_info_raw:
|
||||
|
||||
@@ -32,11 +32,16 @@ def get_user_oauth_token_status(
|
||||
and whether their tokens are expired.
|
||||
"""
|
||||
user_tokens = get_all_user_oauth_tokens(user.id, db_session)
|
||||
return [
|
||||
OAuthTokenStatus(
|
||||
oauth_config_id=token.oauth_config_id,
|
||||
expires_at=OAuthTokenManager.token_expiration_time(token.token_data),
|
||||
is_expired=OAuthTokenManager.is_token_expired(token.token_data),
|
||||
result = []
|
||||
for token in user_tokens:
|
||||
token_data = (
|
||||
token.token_data.get_value(apply_mask=False) if token.token_data else {}
|
||||
)
|
||||
for token in user_tokens
|
||||
]
|
||||
result.append(
|
||||
OAuthTokenStatus(
|
||||
oauth_config_id=token.oauth_config_id,
|
||||
expires_at=OAuthTokenManager.token_expiration_time(token_data),
|
||||
is_expired=OAuthTokenManager.is_token_expired(token_data),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -75,7 +75,7 @@ def _get_active_search_provider(
|
||||
has_api_key=bool(provider_model.api_key),
|
||||
)
|
||||
|
||||
if not provider_model.api_key:
|
||||
if provider_model.api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Web search provider requires an API key.",
|
||||
@@ -84,7 +84,7 @@ def _get_active_search_provider(
|
||||
try:
|
||||
provider: WebSearchProvider = build_search_provider_from_config(
|
||||
provider_type=provider_view.provider_type,
|
||||
api_key=provider_model.api_key,
|
||||
api_key=provider_model.api_key.get_value(apply_mask=False),
|
||||
config=provider_model.config or {},
|
||||
)
|
||||
except ValueError as exc:
|
||||
@@ -121,7 +121,7 @@ def _get_active_content_provider(
|
||||
|
||||
provider: WebContentProvider | None = build_content_provider_from_config(
|
||||
provider_type=provider_type,
|
||||
api_key=provider_model.api_key,
|
||||
api_key=provider_model.api_key.get_value(apply_mask=False),
|
||||
config=config,
|
||||
)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -114,9 +114,14 @@ def get_entities(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
entities_spec = connector_instance.configuration_schema()
|
||||
|
||||
@@ -151,9 +156,14 @@ def get_credentials_schema(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
credentials_spec = connector_instance.credentials_schema()
|
||||
|
||||
@@ -275,6 +285,8 @@ def validate_entities(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
return Response(status_code=400)
|
||||
|
||||
# For HEAD requests, we'll expect entities as query parameters
|
||||
# since HEAD requests shouldn't have request bodies
|
||||
@@ -288,7 +300,8 @@ def validate_entities(
|
||||
return Response(status_code=400)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
is_valid = connector_instance.validate_entities(entities_dict)
|
||||
|
||||
@@ -318,9 +331,15 @@ def get_authorize_url(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
# Update credentials to include the correct redirect URI with the connector ID
|
||||
updated_credentials = federated_connector.credentials.copy()
|
||||
updated_credentials = federated_connector.credentials.get_value(
|
||||
apply_mask=False
|
||||
).copy()
|
||||
if "redirect_uri" in updated_credentials and updated_credentials["redirect_uri"]:
|
||||
# Replace the {id} placeholder with the actual federated connector ID
|
||||
updated_credentials["redirect_uri"] = updated_credentials[
|
||||
@@ -391,9 +410,14 @@ def handle_oauth_callback_generic(
|
||||
)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
oauth_result = connector_instance.callback(callback_data, get_oauth_callback_uri())
|
||||
|
||||
@@ -460,9 +484,9 @@ def get_user_oauth_status(
|
||||
|
||||
# Generate authorize URL if needed
|
||||
authorize_url = None
|
||||
if not oauth_token:
|
||||
if not oauth_token and fc.credentials is not None:
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
fc.source, fc.credentials
|
||||
fc.source, fc.credentials.get_value(apply_mask=False)
|
||||
)
|
||||
base_authorize_url = connector_instance.authorize(get_oauth_callback_uri())
|
||||
|
||||
@@ -496,6 +520,10 @@ def get_federated_connector_detail(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
# Get OAuth token information for the current user
|
||||
oauth_token = None
|
||||
@@ -521,7 +549,9 @@ def get_federated_connector_detail(
|
||||
id=federated_connector.id,
|
||||
source=federated_connector.source,
|
||||
name=f"{federated_connector.source.replace('_', ' ').title()}",
|
||||
credentials=FederatedConnectorCredentials(**federated_connector.credentials),
|
||||
credentials=FederatedConnectorCredentials(
|
||||
**federated_connector.credentials.get_value(apply_mask=True)
|
||||
),
|
||||
config=federated_connector.config,
|
||||
oauth_token_exists=oauth_token is not None,
|
||||
oauth_token_expires_at=oauth_token.expires_at if oauth_token else None,
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.embedding.models import TestEmbeddingRequest
|
||||
from onyx.server.utils import mask_string
|
||||
from onyx.utils.encryption import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
@@ -37,7 +37,11 @@ class CloudEmbeddingProvider(BaseModel):
|
||||
) -> "CloudEmbeddingProvider":
|
||||
return cls(
|
||||
provider_type=cloud_provider_model.provider_type,
|
||||
api_key=cloud_provider_model.api_key,
|
||||
api_key=(
|
||||
cloud_provider_model.api_key.get_value(apply_mask=True)
|
||||
if cloud_provider_model.api_key
|
||||
else None
|
||||
),
|
||||
api_url=cloud_provider_model.api_url,
|
||||
api_version=cloud_provider_model.api_version,
|
||||
deployment_name=cloud_provider_model.deployment_name,
|
||||
|
||||
@@ -90,7 +90,11 @@ def _build_llm_provider_request(
|
||||
return LLMProviderUpsertRequest(
|
||||
name=f"Image Gen - {image_provider_id}",
|
||||
provider=source_provider.provider,
|
||||
api_key=source_provider.api_key, # Only this from source
|
||||
api_key=(
|
||||
source_provider.api_key.get_value(apply_mask=False)
|
||||
if source_provider.api_key
|
||||
else None
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
@@ -227,7 +231,11 @@ def test_image_generation(
|
||||
api_key_changed=False, # Using stored key from source provider
|
||||
)
|
||||
|
||||
api_key = source_provider.api_key
|
||||
api_key = (
|
||||
source_provider.api_key.get_value(apply_mask=False)
|
||||
if source_provider.api_key
|
||||
else None
|
||||
)
|
||||
provider = source_provider.provider
|
||||
|
||||
if provider is None:
|
||||
@@ -431,7 +439,11 @@ def update_config(
|
||||
api_key_changed=False,
|
||||
)
|
||||
# Preserve existing API key when user didn't change it
|
||||
actual_api_key = old_provider.api_key
|
||||
actual_api_key = (
|
||||
old_provider.api_key.get_value(apply_mask=False)
|
||||
if old_provider.api_key
|
||||
else None
|
||||
)
|
||||
|
||||
# 3. Build and create new LLM provider
|
||||
provider_request = _build_llm_provider_request(
|
||||
|
||||
@@ -140,7 +140,11 @@ class ImageGenerationCredentials(BaseModel):
|
||||
"""
|
||||
llm_provider = config.model_configuration.llm_provider
|
||||
return cls(
|
||||
api_key=_mask_api_key(llm_provider.api_key),
|
||||
api_key=_mask_api_key(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
@@ -168,7 +172,11 @@ class DefaultImageGenerationConfig(BaseModel):
|
||||
model_configuration_id=config.model_configuration_id,
|
||||
model_name=config.model_configuration.name,
|
||||
provider=llm_provider.provider,
|
||||
api_key=llm_provider.api_key,
|
||||
api_key=(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
|
||||
@@ -203,7 +203,11 @@ def test_llm_configuration(
|
||||
new_custom_config=test_llm_request.custom_config,
|
||||
api_key_changed=False,
|
||||
)
|
||||
test_api_key = existing_provider.api_key
|
||||
test_api_key = (
|
||||
existing_provider.api_key.get_value(apply_mask=False)
|
||||
if existing_provider.api_key
|
||||
else None
|
||||
)
|
||||
if existing_provider and not test_llm_request.custom_config_changed:
|
||||
test_custom_config = existing_provider.custom_config
|
||||
|
||||
@@ -255,7 +259,7 @@ def list_llm_providers(
|
||||
llm_provider_list: list[LLMProviderView] = []
|
||||
for llm_provider_model in fetch_existing_llm_providers(
|
||||
db_session=db_session,
|
||||
flow_type_filter=[],
|
||||
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
|
||||
exclude_image_generation_providers=not include_image_gen,
|
||||
):
|
||||
from_model_start = datetime.now(timezone.utc)
|
||||
@@ -351,7 +355,11 @@ def put_llm_provider(
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
llm_provider_upsert_request.api_key = existing_provider.api_key
|
||||
llm_provider_upsert_request.api_key = (
|
||||
existing_provider.api_key.get_value(apply_mask=False)
|
||||
if existing_provider.api_key
|
||||
else None
|
||||
)
|
||||
if existing_provider and not llm_provider_upsert_request.custom_config_changed:
|
||||
llm_provider_upsert_request.custom_config = existing_provider.custom_config
|
||||
|
||||
@@ -503,7 +511,9 @@ def list_llm_provider_basics(
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch user-accessible LLM providers")
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session, [])
|
||||
all_providers = fetch_existing_llm_providers(
|
||||
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
|
||||
)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user)
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
|
||||
@@ -512,9 +522,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -539,7 +549,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
Public providers are always included.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -553,7 +563,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -574,7 +584,7 @@ def list_llm_providers_for_persona(
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -603,7 +613,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Check access with persona context — respects persona restrictions
|
||||
# Use simplified access check - public providers always included
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -644,7 +654,11 @@ def get_provider_contextual_cost(
|
||||
provider=provider.provider,
|
||||
model=model_configuration.name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_key=(
|
||||
provider.api_key.get_value(apply_mask=False)
|
||||
if provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
@@ -924,6 +938,11 @@ def get_ollama_available_models(
|
||||
)
|
||||
)
|
||||
|
||||
sorted_results = sorted(
|
||||
all_models_with_context_size_and_vision,
|
||||
key=lambda m: m.name.lower(),
|
||||
)
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
@@ -934,7 +953,7 @@ def get_ollama_available_models(
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in all_models_with_context_size_and_vision
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
@@ -948,7 +967,7 @@ def get_ollama_available_models(
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
|
||||
return all_models_with_context_size_and_vision
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
|
||||
@@ -190,7 +190,11 @@ class LLMProviderView(LLMProvider):
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
api_key=llm_provider_model.api_key,
|
||||
api_key=(
|
||||
llm_provider_model.api_key.get_value(apply_mask=False)
|
||||
if llm_provider_model.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
|
||||
@@ -79,6 +79,7 @@ class UserPersonalization(BaseModel):
|
||||
role: str = ""
|
||||
use_memories: bool = True
|
||||
memories: list[str] = Field(default_factory=list)
|
||||
user_preferences: str = ""
|
||||
|
||||
|
||||
class TenantSnapshot(BaseModel):
|
||||
@@ -160,6 +161,7 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
memories=[memory.memory_text for memory in (user.memories or [])],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -213,6 +215,7 @@ class PersonalizationUpdateRequest(BaseModel):
|
||||
role: str | None = None
|
||||
use_memories: bool | None = None
|
||||
memories: list[str] | None = None
|
||||
user_preferences: str | None = Field(default=None, max_length=500)
|
||||
|
||||
|
||||
class SlackBotCreationRequest(BaseModel):
|
||||
@@ -341,9 +344,21 @@ class SlackBot(BaseModel):
|
||||
name=slack_bot_model.name,
|
||||
enabled=slack_bot_model.enabled,
|
||||
configs_count=len(slack_bot_model.slack_channel_configs),
|
||||
bot_token=slack_bot_model.bot_token,
|
||||
app_token=slack_bot_model.app_token,
|
||||
user_token=slack_bot_model.user_token,
|
||||
bot_token=(
|
||||
slack_bot_model.bot_token.get_value(apply_mask=True)
|
||||
if slack_bot_model.bot_token
|
||||
else ""
|
||||
),
|
||||
app_token=(
|
||||
slack_bot_model.app_token.get_value(apply_mask=True)
|
||||
if slack_bot_model.app_token
|
||||
else ""
|
||||
),
|
||||
user_token=(
|
||||
slack_bot_model.user_token.get_value(apply_mask=True)
|
||||
if slack_bot_model.user_token
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -30,14 +30,12 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import enforce_seat_limit
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
@@ -92,7 +90,6 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.server.usage_limits import is_tenant_on_trial_fn
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -394,20 +391,14 @@ def bulk_invite_users(
|
||||
if e not in existing_users and e not in already_invited
|
||||
]
|
||||
|
||||
# Limit bulk invites for trial tenants to prevent email spam
|
||||
# Only count new invites, not re-invites of existing users
|
||||
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
|
||||
current_invited = len(already_invited)
|
||||
if current_invited + len(emails_needing_seats) > NUM_FREE_TRIAL_USER_INVITES:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You have hit your invite limit. "
|
||||
"Please upgrade for unlimited invites.",
|
||||
)
|
||||
|
||||
# Check seat availability for new users
|
||||
if emails_needing_seats:
|
||||
enforce_seat_limit(db_session, seats_needed=len(emails_needing_seats))
|
||||
# Only for self-hosted (non-multi-tenant) deployments
|
||||
if not MULTI_TENANT and emails_needing_seats:
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=len(emails_needing_seats))
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
@@ -423,10 +414,10 @@ def bulk_invite_users(
|
||||
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
|
||||
number_of_invited_users = write_invited_users(all_emails)
|
||||
|
||||
# send out email invitations only to new users (not already invited or existing)
|
||||
# send out email invitations if enabled
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in emails_needing_seats:
|
||||
for email in new_invited_emails:
|
||||
send_user_email_invite(email, current_user, AUTH_TYPE)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
@@ -573,7 +564,12 @@ def activate_user_api(
|
||||
|
||||
# Check seat availability before activating
|
||||
# Only for self-hosted (non-multi-tenant) deployments
|
||||
enforce_seat_limit(db_session)
|
||||
if not MULTI_TENANT:
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=1)
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
|
||||
activate_user(user_to_activate, db_session)
|
||||
|
||||
@@ -597,17 +593,11 @@ def get_valid_domains(
|
||||
|
||||
@router.get("/users", tags=PUBLIC_API_TAGS)
|
||||
def list_all_users_basic_info(
|
||||
include_api_keys: bool = False,
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserSnapshot]:
|
||||
users = get_all_users(db_session)
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.role != UserRole.SLACK_USER
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
return [MinimalUserSnapshot(id=user.id, email=user.email) for user in users]
|
||||
|
||||
|
||||
@router.get("/get-user-role", tags=PUBLIC_API_TAGS)
|
||||
@@ -854,6 +844,11 @@ def update_user_personalization_api(
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
)
|
||||
new_user_preferences = (
|
||||
request.user_preferences
|
||||
if request.user_preferences is not None
|
||||
else user.user_preferences
|
||||
)
|
||||
|
||||
update_user_personalization(
|
||||
user.id,
|
||||
@@ -861,6 +856,7 @@ def update_user_personalization_api(
|
||||
personal_role=new_role,
|
||||
use_memories=new_use_memories,
|
||||
memories=new_memories,
|
||||
user_preferences=new_user_preferences,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ def test_search_provider(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if provider_requires_api_key and not api_key:
|
||||
raise HTTPException(
|
||||
@@ -391,7 +391,7 @@ def test_content_provider(
|
||||
detail="Base URL cannot differ from stored provider when using stored API key",
|
||||
)
|
||||
|
||||
api_key = existing_provider.api_key
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -57,11 +57,9 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
deep_research_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
|
||||
# or the license is expired (GATED_ACCESS).
|
||||
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
|
||||
ee_features_enabled: bool = False
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
|
||||
@@ -8,10 +8,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
@@ -45,42 +41,6 @@ def get_json_line(
|
||||
return json.dumps(json_dict, cls=encoder) + "\n"
|
||||
|
||||
|
||||
def mask_string(sensitive_str: str) -> str:
|
||||
return "****...**" + sensitive_str[-4:]
|
||||
|
||||
|
||||
MASK_CREDENTIALS_WHITELIST = {
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
"wiki_base",
|
||||
"cloud_name",
|
||||
"cloud_id",
|
||||
}
|
||||
|
||||
|
||||
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
masked_creds = {}
|
||||
for key, val in credential_dict.items():
|
||||
if isinstance(val, str):
|
||||
# we want to pass the authentication_method field through so the frontend
|
||||
# can disambiguate credentials created by different methods
|
||||
if key in MASK_CREDENTIALS_WHITELIST:
|
||||
masked_creds[key] = val
|
||||
else:
|
||||
masked_creds[key] = mask_string(val)
|
||||
continue
|
||||
|
||||
if isinstance(val, int):
|
||||
masked_creds[key] = "*****"
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to mask credentials of type other than string or int, cannot process request."
|
||||
f"Received type: {type(val)}"
|
||||
)
|
||||
|
||||
return masked_creds
|
||||
|
||||
|
||||
def make_short_id() -> str:
|
||||
"""Fast way to generate a random 8 character id ... useful for tagging data
|
||||
to trace it through a flow. This is definitely not guaranteed to be unique and is
|
||||
|
||||
@@ -446,7 +446,7 @@ def run_research_agent_call(
|
||||
tool_calls=tool_calls,
|
||||
tools=current_tools,
|
||||
message_history=msg_history,
|
||||
memories=None,
|
||||
user_memory_context=None,
|
||||
user_info=None,
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
|
||||
@@ -165,7 +166,7 @@ class SearchToolOverrideKwargs(BaseModel):
|
||||
# without help and a specific custom prompt for this
|
||||
original_query: str | None = None
|
||||
message_history: list[ChatMinimalTextMessage] | None = None
|
||||
memories: list[str] | None = None
|
||||
user_memory_context: UserMemoryContext | None = None
|
||||
user_info: str | None = None
|
||||
|
||||
# Used for tool calls after the first one but in the same chat turn. The reason for this is that if the initial pass through
|
||||
|
||||
@@ -82,7 +82,11 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
model_provider=llm_provider.provider,
|
||||
model_name=default_config.model_configuration.name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm_provider.api_key,
|
||||
api_key=(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
|
||||
@@ -94,7 +94,11 @@ class ImageGenerationTool(Tool[None]):
|
||||
|
||||
llm_provider = config.model_configuration.llm_provider
|
||||
credentials = ImageGenerationProviderCredentials(
|
||||
api_key=llm_provider.api_key,
|
||||
api_key=(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
|
||||
@@ -142,8 +142,9 @@ class MCPTool(Tool[None]):
|
||||
)
|
||||
|
||||
# Priority 2: Base headers from connection config (DB) - overrides request
|
||||
if self.connection_config:
|
||||
headers.update(self.connection_config.config.get("headers", {}))
|
||||
if self.connection_config and self.connection_config.config:
|
||||
config_dict = self.connection_config.config.get_value(apply_mask=False)
|
||||
headers.update(config_dict.get("headers", {}))
|
||||
|
||||
# Priority 3: For pass-through OAuth, use the user's login OAuth token
|
||||
if self._user_oauth_token:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from onyx.file_processing.html_utils import ParsedHTML
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -22,22 +21,10 @@ from onyx.utils.web_content import title_from_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DEFAULT_READ_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
|
||||
DEFAULT_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
|
||||
DEFAULT_MAX_WORKERS = 5
|
||||
|
||||
|
||||
def _failed_result(url: str) -> WebContent:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
|
||||
class OnyxWebCrawler(WebContentProvider):
|
||||
@@ -50,14 +37,12 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
max_pdf_size_bytes: int | None = None,
|
||||
max_html_size_bytes: int | None = None,
|
||||
) -> None:
|
||||
self._read_timeout_seconds = timeout_seconds
|
||||
self._connect_timeout_seconds = connect_timeout_seconds
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._max_pdf_size_bytes = max_pdf_size_bytes
|
||||
self._max_html_size_bytes = max_html_size_bytes
|
||||
self._headers = {
|
||||
@@ -66,68 +51,75 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
}
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
return list(executor.map(self._fetch_url_safe, urls))
|
||||
|
||||
def _fetch_url_safe(self, url: str) -> WebContent:
|
||||
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
|
||||
try:
|
||||
return self._fetch_url(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler unexpected error for %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
results: list[WebContent] = []
|
||||
for url in urls:
|
||||
results.append(self._fetch_url(url))
|
||||
return results
|
||||
|
||||
def _fetch_url(self, url: str) -> WebContent:
|
||||
try:
|
||||
# Use SSRF-safe request to prevent DNS rebinding attacks
|
||||
response = ssrf_safe_get(
|
||||
url,
|
||||
headers=self._headers,
|
||||
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
|
||||
url, headers=self._headers, timeout=self._timeout_seconds
|
||||
)
|
||||
except SSRFException as exc:
|
||||
logger.error(
|
||||
"SSRF protection blocked request to %s (%s)",
|
||||
"SSRF protection blocked request to %s: %s",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
str(exc),
|
||||
)
|
||||
return _failed_result(url)
|
||||
except Exception as exc:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - network failures vary
|
||||
logger.warning(
|
||||
"Onyx crawler failed to fetch %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
|
||||
return _failed_result(url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
content = response.content
|
||||
|
||||
content_sniff = content[:1024] if content else None
|
||||
content_sniff = response.content[:1024] if response.content else None
|
||||
if is_pdf_resource(url, content_type, content_sniff):
|
||||
if (
|
||||
self._max_pdf_size_bytes is not None
|
||||
and len(content) > self._max_pdf_size_bytes
|
||||
and len(response.content) > self._max_pdf_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"PDF content too large (%d bytes) for %s, max is %d",
|
||||
len(content),
|
||||
len(response.content),
|
||||
url,
|
||||
self._max_pdf_size_bytes,
|
||||
)
|
||||
return _failed_result(url)
|
||||
text_content, metadata = extract_pdf_text(content)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
text_content, metadata = extract_pdf_text(response.content)
|
||||
title = title_from_pdf_metadata(metadata) or title_from_url(url)
|
||||
return WebContent(
|
||||
title=title,
|
||||
@@ -139,19 +131,25 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
|
||||
if (
|
||||
self._max_html_size_bytes is not None
|
||||
and len(content) > self._max_html_size_bytes
|
||||
and len(response.content) > self._max_html_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"HTML content too large (%d bytes) for %s, max is %d",
|
||||
len(content),
|
||||
len(response.content),
|
||||
url,
|
||||
self._max_html_size_bytes,
|
||||
)
|
||||
return _failed_result(url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
try:
|
||||
decoded_html = decode_html_bytes(
|
||||
content,
|
||||
response.content,
|
||||
content_type=content_type,
|
||||
fallback_encoding=response.apparent_encoding or response.encoding,
|
||||
)
|
||||
|
||||
@@ -47,7 +47,6 @@ from onyx.tools.tool_implementations.web_search.utils import (
|
||||
from onyx.tools.tool_implementations.web_search.utils import MAX_CHARS_PER_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.url import normalize_url as normalize_web_content_url
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -792,9 +791,7 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
for url in all_urls:
|
||||
doc_id = url_to_doc_id.get(url)
|
||||
indexed_section = indexed_by_doc_id.get(doc_id) if doc_id else None
|
||||
# WebContent.link is normalized (query/fragment stripped). Match on the
|
||||
# same normalized form to avoid dropping successful crawl results.
|
||||
crawled_section = crawled_by_url.get(normalize_web_content_url(url))
|
||||
crawled_section = crawled_by_url.get(url)
|
||||
|
||||
if indexed_section and indexed_section.combined_content:
|
||||
# Prefer indexed
|
||||
|
||||
@@ -352,10 +352,17 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
if tenant_slack_bot:
|
||||
bot_token = tenant_slack_bot.bot_token
|
||||
access_token = (
|
||||
tenant_slack_bot.user_token or tenant_slack_bot.bot_token
|
||||
bot_token = (
|
||||
tenant_slack_bot.bot_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.bot_token
|
||||
else None
|
||||
)
|
||||
user_token = (
|
||||
tenant_slack_bot.user_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.user_token
|
||||
else None
|
||||
)
|
||||
access_token = user_token or bot_token
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch Slack bot tokens: {e}")
|
||||
|
||||
@@ -375,8 +382,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
None,
|
||||
)
|
||||
|
||||
if slack_oauth_token:
|
||||
access_token = slack_oauth_token.token
|
||||
if slack_oauth_token and slack_oauth_token.token:
|
||||
access_token = slack_oauth_token.token.get_value(
|
||||
apply_mask=False
|
||||
)
|
||||
entities = slack_oauth_token.federated_connector.config or {}
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch Slack OAuth token: {e}")
|
||||
@@ -550,7 +559,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
if override_kwargs.message_history
|
||||
else []
|
||||
)
|
||||
memories = override_kwargs.memories
|
||||
memories = (
|
||||
override_kwargs.user_memory_context.as_formatted_list()
|
||||
if override_kwargs.user_memory_context
|
||||
else []
|
||||
)
|
||||
user_info = override_kwargs.user_info
|
||||
|
||||
# Skip query expansion if this is a repeat search call
|
||||
|
||||
@@ -1,260 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
BRAVE_WEB_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
BRAVE_MAX_RESULTS_PER_REQUEST = 20
|
||||
BRAVE_SAFESEARCH_OPTIONS = {"off", "moderate", "strict"}
|
||||
BRAVE_FRESHNESS_OPTIONS = {"pd", "pw", "pm", "py"}
|
||||
|
||||
|
||||
class RetryableBraveSearchError(Exception):
|
||||
"""Error type used to trigger retry for transient Brave search failures."""
|
||||
|
||||
|
||||
class BraveClient(WebSearchProvider):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
*,
|
||||
num_results: int = 10,
|
||||
timeout_seconds: int = 10,
|
||||
country: str | None = None,
|
||||
search_lang: str | None = None,
|
||||
ui_lang: str | None = None,
|
||||
safesearch: str | None = None,
|
||||
freshness: str | None = None,
|
||||
) -> None:
|
||||
if timeout_seconds <= 0:
|
||||
raise ValueError("Brave provider config 'timeout_seconds' must be > 0.")
|
||||
|
||||
self._headers = {
|
||||
"Accept": "application/json",
|
||||
"X-Subscription-Token": api_key,
|
||||
}
|
||||
logger.debug(f"Count of results passed to BraveClient: {num_results}")
|
||||
self._num_results = max(1, min(num_results, BRAVE_MAX_RESULTS_PER_REQUEST))
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._country = _normalize_country(country)
|
||||
self._search_lang = _normalize_language_code(
|
||||
search_lang, field_name="search_lang"
|
||||
)
|
||||
self._ui_lang = _normalize_language_code(ui_lang, field_name="ui_lang")
|
||||
self._safesearch = _normalize_option(
|
||||
safesearch,
|
||||
field_name="safesearch",
|
||||
allowed_values=BRAVE_SAFESEARCH_OPTIONS,
|
||||
)
|
||||
self._freshness = _normalize_option(
|
||||
freshness,
|
||||
field_name="freshness",
|
||||
allowed_values=BRAVE_FRESHNESS_OPTIONS,
|
||||
)
|
||||
|
||||
def _build_search_params(self, query: str) -> dict[str, str]:
|
||||
params = {
|
||||
"q": query,
|
||||
"count": str(self._num_results),
|
||||
}
|
||||
if self._country:
|
||||
params["country"] = self._country
|
||||
if self._search_lang:
|
||||
params["search_lang"] = self._search_lang
|
||||
if self._ui_lang:
|
||||
params["ui_lang"] = self._ui_lang
|
||||
if self._safesearch:
|
||||
params["safesearch"] = self._safesearch
|
||||
if self._freshness:
|
||||
params["freshness"] = self._freshness
|
||||
return params
|
||||
|
||||
@retry_builder(
|
||||
tries=3,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
exceptions=(RetryableBraveSearchError,),
|
||||
)
|
||||
def _search_with_retries(self, query: str) -> list[WebSearchResult]:
|
||||
params = self._build_search_params(query)
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
BRAVE_WEB_SEARCH_URL,
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=self._timeout_seconds,
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
raise RetryableBraveSearchError(
|
||||
f"Brave search request failed: {exc}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as exc:
|
||||
error_msg = _build_error_message(response)
|
||||
if _is_retryable_status(response.status_code):
|
||||
raise RetryableBraveSearchError(error_msg) from exc
|
||||
raise ValueError(error_msg) from exc
|
||||
|
||||
data = response.json()
|
||||
web_results = (data.get("web") or {}).get("results") or []
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
for result in web_results:
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
|
||||
link = _clean_string(result.get("url"))
|
||||
if not link:
|
||||
continue
|
||||
|
||||
title = _clean_string(result.get("title"))
|
||||
description = _clean_string(result.get("description"))
|
||||
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
link=link,
|
||||
snippet=description,
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
try:
|
||||
return self._search_with_retries(query)
|
||||
except RetryableBraveSearchError as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
test_results = self.search("test")
|
||||
if not test_results or not any(result.link for result in test_results):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Brave API key validation failed: search returned no results.",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except (ValueError, requests.RequestException) as e:
|
||||
error_msg = str(e)
|
||||
lower = error_msg.lower()
|
||||
if (
|
||||
"status 401" in lower
|
||||
or "status 403" in lower
|
||||
or "api key" in lower
|
||||
or "auth" in lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid Brave API key: {error_msg}",
|
||||
) from e
|
||||
if "status 429" in lower or "rate limit" in lower:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Brave API rate limit exceeded: {error_msg}",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Brave API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info("Web search provider test succeeded for Brave.")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def _build_error_message(response: requests.Response) -> str:
|
||||
return (
|
||||
"Brave search failed "
|
||||
f"(status {response.status_code}): {_extract_error_detail(response)}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_error_detail(response: requests.Response) -> str:
|
||||
try:
|
||||
payload: Any = response.json()
|
||||
except Exception:
|
||||
text = response.text.strip()
|
||||
return text[:200] if text else "No error details"
|
||||
|
||||
if isinstance(payload, dict):
|
||||
error = payload.get("error")
|
||||
if isinstance(error, dict):
|
||||
detail = error.get("detail") or error.get("message")
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
if isinstance(error, str):
|
||||
return error
|
||||
|
||||
message = payload.get("message")
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
return str(payload)[:200]
|
||||
|
||||
|
||||
def _is_retryable_status(status_code: int) -> bool:
|
||||
return status_code == 429 or status_code >= 500
|
||||
|
||||
|
||||
def _clean_string(value: Any) -> str:
|
||||
return value.strip() if isinstance(value, str) else ""
|
||||
|
||||
|
||||
def _normalize_country(country: str | None) -> str | None:
|
||||
if country is None:
|
||||
return None
|
||||
normalized = country.strip().upper()
|
||||
if not normalized:
|
||||
return None
|
||||
if len(normalized) != 2 or not normalized.isalpha():
|
||||
raise ValueError(
|
||||
"Brave provider config 'country' must be a 2-letter ISO country code."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_language_code(value: str | None, *, field_name: str) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if len(normalized) > 20:
|
||||
raise ValueError(f"Brave provider config '{field_name}' is too long.")
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_option(
|
||||
value: str | None,
|
||||
*,
|
||||
field_name: str,
|
||||
allowed_values: set[str],
|
||||
) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
return None
|
||||
if normalized not in allowed_values:
|
||||
allowed = ", ".join(sorted(allowed_values))
|
||||
raise ValueError(
|
||||
f"Brave provider config '{field_name}' must be one of: {allowed}."
|
||||
)
|
||||
return normalized
|
||||
@@ -13,9 +13,6 @@ from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
|
||||
from onyx.tools.tool_implementations.web_search.clients.brave_client import (
|
||||
BraveClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
@@ -38,76 +35,16 @@ from shared_configs.enums import WebSearchProviderType
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _parse_positive_int_config(
|
||||
*,
|
||||
raw_value: str | None,
|
||||
default: int,
|
||||
provider_name: str,
|
||||
config_key: str,
|
||||
) -> int:
|
||||
if not raw_value:
|
||||
return default
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"{provider_name} provider config '{config_key}' must be an integer."
|
||||
) from exc
|
||||
if value <= 0:
|
||||
raise ValueError(
|
||||
f"{provider_name} provider config '{config_key}' must be greater than 0."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def provider_requires_api_key(provider_type: WebSearchProviderType) -> bool:
|
||||
"""Return True if the given provider type requires an API key.
|
||||
This list is most likely just going to contain SEARXNG. The way it works is that it uses public search engines that do not
|
||||
require an API key. You can also set it up in a way which requires a key but SearXNG itself does not require a key.
|
||||
"""
|
||||
return provider_type != WebSearchProviderType.SEARXNG
|
||||
|
||||
|
||||
def build_search_provider_from_config(
|
||||
provider_type: WebSearchProviderType,
|
||||
api_key: str | None,
|
||||
api_key: str,
|
||||
config: dict[str, str] | None, # TODO use a typed object
|
||||
) -> WebSearchProvider:
|
||||
config = config or {}
|
||||
num_results = int(config.get("num_results") or DEFAULT_MAX_RESULTS)
|
||||
|
||||
# SearXNG does not require an API key
|
||||
if provider_type == WebSearchProviderType.SEARXNG:
|
||||
searxng_base_url = config.get("searxng_base_url")
|
||||
if not searxng_base_url:
|
||||
raise ValueError("Please provide a URL for your private SearXNG instance.")
|
||||
return SearXNGClient(
|
||||
searxng_base_url,
|
||||
num_results=num_results,
|
||||
)
|
||||
|
||||
# All other providers require an API key
|
||||
if not api_key:
|
||||
raise ValueError(f"API key is required for {provider_type.value} provider.")
|
||||
|
||||
if provider_type == WebSearchProviderType.EXA:
|
||||
return ExaClient(api_key=api_key, num_results=num_results)
|
||||
if provider_type == WebSearchProviderType.BRAVE:
|
||||
return BraveClient(
|
||||
api_key=api_key,
|
||||
num_results=num_results,
|
||||
timeout_seconds=_parse_positive_int_config(
|
||||
raw_value=config.get("timeout_seconds"),
|
||||
default=10,
|
||||
provider_name="Brave",
|
||||
config_key="timeout_seconds",
|
||||
),
|
||||
country=config.get("country"),
|
||||
search_lang=config.get("search_lang"),
|
||||
ui_lang=config.get("ui_lang"),
|
||||
safesearch=config.get("safesearch"),
|
||||
freshness=config.get("freshness"),
|
||||
)
|
||||
if provider_type == WebSearchProviderType.SERPER:
|
||||
return SerperClient(api_key=api_key, num_results=num_results)
|
||||
if provider_type == WebSearchProviderType.GOOGLE_PSE:
|
||||
@@ -127,13 +64,24 @@ def build_search_provider_from_config(
|
||||
num_results=num_results,
|
||||
timeout_seconds=int(config.get("timeout_seconds") or 10),
|
||||
)
|
||||
raise ValueError(f"Unknown provider type: {provider_type.value}")
|
||||
if provider_type == WebSearchProviderType.SEARXNG:
|
||||
searxng_base_url = config.get("searxng_base_url")
|
||||
if not searxng_base_url:
|
||||
raise ValueError("Please provide a URL for your private SearXNG instance.")
|
||||
return SearXNGClient(
|
||||
searxng_base_url,
|
||||
num_results=num_results,
|
||||
)
|
||||
|
||||
|
||||
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
|
||||
return build_search_provider_from_config(
|
||||
provider_type=WebSearchProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key,
|
||||
api_key=(
|
||||
provider_model.api_key.get_value(apply_mask=False)
|
||||
if provider_model.api_key
|
||||
else ""
|
||||
),
|
||||
config=provider_model.config or {},
|
||||
)
|
||||
|
||||
@@ -185,7 +133,11 @@ def get_default_content_provider() -> WebContentProvider:
|
||||
if provider_model:
|
||||
provider = build_content_provider_from_config(
|
||||
provider_type=WebContentProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key or "",
|
||||
api_key=(
|
||||
provider_model.api_key.get_value(apply_mask=False)
|
||||
if provider_model.api_key
|
||||
else ""
|
||||
),
|
||||
config=provider_model.config or WebContentProviderConfig(),
|
||||
)
|
||||
if provider:
|
||||
|
||||
@@ -69,7 +69,11 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
if provider_model is None:
|
||||
raise RuntimeError("No web search provider configured.")
|
||||
provider_type = WebSearchProviderType(provider_model.provider_type)
|
||||
api_key = provider_model.api_key
|
||||
api_key = (
|
||||
provider_model.api_key.get_value(apply_mask=False)
|
||||
if provider_model.api_key
|
||||
else None
|
||||
)
|
||||
config = provider_model.config
|
||||
|
||||
# TODO - This should just be enforced at the DB level
|
||||
|
||||
@@ -6,6 +6,7 @@ import onyx.tracing.framework._error_tracing as _error_tracing
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -220,7 +221,7 @@ def run_tool_calls(
|
||||
tools: list[Tool],
|
||||
# The stuff below is needed for the different individual built-in tools
|
||||
message_history: list[ChatMessageSimple],
|
||||
memories: list[str] | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
user_info: str | None,
|
||||
citation_mapping: dict[int, str],
|
||||
next_citation_num: int,
|
||||
@@ -252,7 +253,7 @@ def run_tool_calls(
|
||||
tools: List of available tool instances.
|
||||
message_history: Chat message history (used to find the most recent user query
|
||||
for `SearchTool` override kwargs).
|
||||
memories: User memories, if available (passed through to `SearchTool`).
|
||||
user_memory_context: User memory context, if available (passed through to `SearchTool`).
|
||||
user_info: User information string, if available (passed through to `SearchTool`).
|
||||
citation_mapping: Current citation number to URL mapping. May be updated with
|
||||
new citations produced by search tools.
|
||||
@@ -342,7 +343,7 @@ def run_tool_calls(
|
||||
starting_citation_num=starting_citation_num,
|
||||
original_query=last_user_message,
|
||||
message_history=minimal_history,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
user_info=user_info,
|
||||
skip_query_expansion=skip_search_query_expansion,
|
||||
)
|
||||
|
||||
@@ -1,22 +1,91 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
def mask_string(sensitive_str: str) -> str:
|
||||
"""Masks a sensitive string, showing first and last few characters.
|
||||
If the string is too short to safely mask, returns a fully masked placeholder.
|
||||
"""
|
||||
visible_start = 4
|
||||
visible_end = 4
|
||||
min_masked_chars = 6
|
||||
|
||||
if len(sensitive_str) < visible_start + visible_end + min_masked_chars:
|
||||
return "••••••••••••"
|
||||
|
||||
return f"{sensitive_str[:visible_start]}...{sensitive_str[-visible_end:]}"
|
||||
|
||||
|
||||
MASK_CREDENTIALS_WHITELIST = {
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
"wiki_base",
|
||||
"cloud_name",
|
||||
"cloud_id",
|
||||
}
|
||||
|
||||
|
||||
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
masked_creds: dict[str, Any] = {}
|
||||
for key, val in credential_dict.items():
|
||||
if isinstance(val, str):
|
||||
# we want to pass the authentication_method field through so the frontend
|
||||
# can disambiguate credentials created by different methods
|
||||
if key in MASK_CREDENTIALS_WHITELIST:
|
||||
masked_creds[key] = val
|
||||
else:
|
||||
masked_creds[key] = mask_string(val)
|
||||
elif isinstance(val, dict):
|
||||
masked_creds[key] = mask_credential_dict(val)
|
||||
elif isinstance(val, list):
|
||||
masked_creds[key] = _mask_list(val)
|
||||
elif isinstance(val, (bool, type(None))):
|
||||
masked_creds[key] = val
|
||||
elif isinstance(val, (int, float)):
|
||||
masked_creds[key] = "*****"
|
||||
else:
|
||||
masked_creds[key] = "*****"
|
||||
|
||||
return masked_creds
|
||||
|
||||
|
||||
def _mask_list(items: list[Any]) -> list[Any]:
|
||||
masked: list[Any] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
masked.append(mask_credential_dict(item))
|
||||
elif isinstance(item, str):
|
||||
masked.append(mask_string(item))
|
||||
elif isinstance(item, list):
|
||||
masked.append(_mask_list(item))
|
||||
elif isinstance(item, (bool, type(None))):
|
||||
masked.append(item)
|
||||
else:
|
||||
masked.append("*****")
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
|
||||
205
backend/onyx/utils/sensitive.py
Normal file
205
backend/onyx/utils/sensitive.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Wrapper class for sensitive values that require explicit masking decisions.
|
||||
|
||||
This module provides a wrapper for encrypted values that forces developers to
|
||||
make an explicit decision about whether to mask the value when accessing it.
|
||||
This prevents accidental exposure of sensitive data in API responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import NoReturn
|
||||
from typing import TypeVar
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.utils.encryption import mask_credential_dict
|
||||
from onyx.utils.encryption import mask_string
|
||||
|
||||
|
||||
T = TypeVar("T", str, dict[str, Any])
|
||||
|
||||
|
||||
def make_mock_sensitive_value(value: dict[str, Any] | str | None) -> MagicMock:
|
||||
"""
|
||||
Create a mock SensitiveValue for use in tests.
|
||||
|
||||
This helper makes it easy to create mock objects that behave like
|
||||
SensitiveValue for testing code that uses credentials.
|
||||
|
||||
Args:
|
||||
value: The value to return from get_value(). Can be a dict, string, or None.
|
||||
|
||||
Returns:
|
||||
A MagicMock configured to behave like a SensitiveValue.
|
||||
|
||||
Example:
|
||||
>>> mock_credential = MagicMock()
|
||||
>>> mock_credential.credential_json = make_mock_sensitive_value({"api_key": "secret"})
|
||||
>>> # Now mock_credential.credential_json.get_value(apply_mask=False) returns {"api_key": "secret"}
|
||||
"""
|
||||
if value is None:
|
||||
return None # type: ignore[return-value]
|
||||
|
||||
mock = MagicMock(spec=SensitiveValue)
|
||||
mock.get_value.return_value = value
|
||||
mock.__bool__ = lambda self: True # noqa: ARG005
|
||||
return mock
|
||||
|
||||
|
||||
class SensitiveAccessError(Exception):
|
||||
"""Raised when attempting to access a SensitiveValue without explicit masking decision."""
|
||||
|
||||
|
||||
class SensitiveValue(Generic[T]):
|
||||
"""
|
||||
Wrapper requiring explicit masking decisions for sensitive data.
|
||||
|
||||
This class wraps encrypted data and forces callers to make an explicit
|
||||
decision about whether to mask the value when accessing it. This prevents
|
||||
accidental exposure of sensitive data.
|
||||
|
||||
Usage:
|
||||
# Get raw value (for internal use like connectors)
|
||||
raw_value = sensitive.get_value(apply_mask=False)
|
||||
|
||||
# Get masked value (for API responses)
|
||||
masked_value = sensitive.get_value(apply_mask=True)
|
||||
|
||||
Raises SensitiveAccessError when:
|
||||
- Attempting to convert to string via str() or repr()
|
||||
- Attempting to iterate over the value
|
||||
- Attempting to subscript the value (e.g., value["key"])
|
||||
- Attempting to serialize to JSON without explicit get_value()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
encrypted_bytes: bytes,
|
||||
decrypt_fn: Callable[[bytes], str],
|
||||
is_json: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a SensitiveValue wrapper.
|
||||
|
||||
Args:
|
||||
encrypted_bytes: The encrypted bytes to wrap
|
||||
decrypt_fn: Function to decrypt bytes to string
|
||||
is_json: If True, the decrypted value is JSON and will be parsed to dict
|
||||
"""
|
||||
self._encrypted_bytes = encrypted_bytes
|
||||
self._decrypt_fn = decrypt_fn
|
||||
self._is_json = is_json
|
||||
# Cache for decrypted value to avoid repeated decryption
|
||||
self._decrypted_value: T | None = None
|
||||
|
||||
def _decrypt(self) -> T:
|
||||
"""Lazily decrypt and cache the value."""
|
||||
if self._decrypted_value is None:
|
||||
decrypted_str = self._decrypt_fn(self._encrypted_bytes)
|
||||
if self._is_json:
|
||||
self._decrypted_value = json.loads(decrypted_str)
|
||||
else:
|
||||
self._decrypted_value = decrypted_str # type: ignore[assignment]
|
||||
# The return type should always match T based on is_json flag
|
||||
return self._decrypted_value # type: ignore[return-value]
|
||||
|
||||
def get_value(
|
||||
self,
|
||||
*,
|
||||
apply_mask: bool,
|
||||
mask_fn: Callable[[T], T] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Get the value with explicit masking decision.
|
||||
|
||||
Args:
|
||||
apply_mask: Required. True = return masked value, False = return raw value
|
||||
mask_fn: Optional custom masking function. Defaults to mask_string for
|
||||
strings and mask_credential_dict for dicts.
|
||||
|
||||
Returns:
|
||||
The value, either masked or raw depending on apply_mask.
|
||||
"""
|
||||
value = self._decrypt()
|
||||
|
||||
if not apply_mask:
|
||||
return value
|
||||
|
||||
# Apply masking
|
||||
if mask_fn is not None:
|
||||
return mask_fn(value)
|
||||
|
||||
# Use default masking based on type
|
||||
# Type narrowing doesn't work well here due to the generic T,
|
||||
# but at runtime the types will match
|
||||
if isinstance(value, dict):
|
||||
return mask_credential_dict(value)
|
||||
elif isinstance(value, str):
|
||||
return mask_string(value)
|
||||
else:
|
||||
raise ValueError(f"Cannot mask value of type {type(value)}")
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Allow truthiness checks without exposing the value."""
|
||||
return True
|
||||
|
||||
def __str__(self) -> NoReturn:
|
||||
"""Prevent accidental string conversion."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot convert SensitiveValue to string. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Prevent accidental repr exposure."""
|
||||
return "<SensitiveValue: use .get_value(apply_mask=True/False) to access>"
|
||||
|
||||
def __iter__(self) -> NoReturn:
|
||||
"""Prevent iteration over the value."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot iterate over SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
def __getitem__(self, key: Any) -> NoReturn:
|
||||
"""Prevent subscript access."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot subscript SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Prevent direct comparison which might expose value."""
|
||||
if isinstance(other, SensitiveValue):
|
||||
# Compare encrypted bytes for equality check
|
||||
return self._encrypted_bytes == other._encrypted_bytes
|
||||
raise SensitiveAccessError(
|
||||
"Cannot compare SensitiveValue with non-SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value for comparison."
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Allow hashing based on encrypted bytes."""
|
||||
return hash(self._encrypted_bytes)
|
||||
|
||||
# Prevent JSON serialization
|
||||
def __json__(self) -> Any:
|
||||
"""Prevent JSON serialization."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot serialize SensitiveValue to JSON. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
# For Pydantic compatibility
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
|
||||
"""Prevent Pydantic from serializing without explicit get_value()."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot serialize SensitiveValue in Pydantic model. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value before serialization."
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user