mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-10 18:22:40 +00:00
Compare commits
14 Commits
fix/search
...
nikg/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aef819edee | ||
|
|
cf826871cd | ||
|
|
2d32326eca | ||
|
|
b2f9d12625 | ||
|
|
eba9f1e06e | ||
|
|
a6756a950b | ||
|
|
ddbc8fadc8 | ||
|
|
b9dce93a6f | ||
|
|
f3c3e29fb7 | ||
|
|
eaa1e94bed | ||
|
|
a5bb2d8130 | ||
|
|
dbbca1c4e5 | ||
|
|
ff92928b31 | ||
|
|
f452a777cc |
@@ -106,34 +106,13 @@ onyx-cli ask --json "What authentication methods do we support?"
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
| `citation_info` | Source citation with `citation_number` and `document_id` |
|
||||
|
||||
### Specify an agent
|
||||
|
||||
@@ -150,10 +129,6 @@ Uses a specific Onyx agent/persona instead of the default.
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -151,7 +151,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -316,7 +316,6 @@ jobs:
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -419,7 +418,6 @@ jobs:
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
@@ -639,7 +637,6 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -694,7 +691,6 @@ jobs:
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -468,7 +468,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -707,7 +707,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
|
||||
178
.github/workflows/release-cli.yml
vendored
178
.github/workflows/release-cli.yml
vendored
@@ -26,7 +26,8 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -37,178 +38,3 @@ jobs:
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
|
||||
4
.github/workflows/release-devtools.yml
vendored
4
.github/workflows/release-devtools.yml
vendored
@@ -22,11 +22,13 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
69
.github/workflows/storybook-deploy.yml
vendored
69
.github/workflows/storybook-deploy.yml
vendored
@@ -1,69 +0,0 @@
|
||||
name: Storybook Deploy
|
||||
env:
|
||||
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
|
||||
VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
|
||||
VERCEL_CLI: vercel@50.14.1
|
||||
VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
concurrency:
|
||||
group: storybook-deploy-production
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "web/lib/opal/**"
|
||||
- "web/src/refresh-components/**"
|
||||
- "web/.storybook/**"
|
||||
- "web/package.json"
|
||||
- "web/package-lock.json"
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
Deploy-Storybook:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: npm ci
|
||||
|
||||
- name: Build Storybook
|
||||
working-directory: web
|
||||
run: npm run storybook:build
|
||||
|
||||
- name: Deploy to Vercel (Production)
|
||||
working-directory: web
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: Deploy-Storybook
|
||||
if: always() && needs.Deploy-Storybook.result == 'failure'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
sparse-checkout: .github/actions/slack-notify
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• Deploy-Storybook"
|
||||
title: "🚨 Storybook Deploy Failed"
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -544,8 +544,6 @@ To run them:
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
@@ -598,7 +596,7 @@ Before writing your plan, make sure to do research. Explore the relevant section
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
|
||||
@@ -46,9 +46,7 @@ RUN apt-get update && \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
vim && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -143,7 +141,6 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
@@ -167,13 +164,6 @@ ENV PYTHONPATH=/app
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
|
||||
# in long-running Python processes (API server, Celery workers).
|
||||
# The soname is architecture-independent; the dynamic linker resolves
|
||||
# the correct path from standard library directories.
|
||||
# Placed after all RUN steps so build-time processes are unaffected.
|
||||
ENV LD_PRELOAD=libjemalloc.so.2
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""add timestamps to user table
|
||||
|
||||
Revision ID: 27fb147a843f
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-03-08 17:18:40.828644
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "27fb147a843f"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "updated_at")
|
||||
op.drop_column("user", "created_at")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""add hierarchy_node_by_connector_credential_pair table
|
||||
|
||||
Revision ID: b5c4d7e8f9a1
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-03-04
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "b5c4d7e8f9a1"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_id", sa.Integer(), nullable=False),
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["hierarchy_node_id"],
|
||||
["hierarchy_node.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
table_name="hierarchy_node_by_connector_credential_pair",
|
||||
)
|
||||
op.drop_table("hierarchy_node_by_connector_credential_pair")
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import text
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -21,52 +22,59 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
if MULTI_TENANT:
|
||||
|
||||
# Create the read-only db user if it does not already exist.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
|
||||
@@ -9,15 +9,12 @@ from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from onyx.access.access import collect_user_file_access
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -119,68 +116,6 @@ def _get_access_for_documents(
|
||||
return access_map
|
||||
|
||||
|
||||
def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
|
||||
"""Extract user-group names from the already-loaded Persona.groups
|
||||
relationships on a UserFile (skipping deleted personas)."""
|
||||
groups: set[str] = set()
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
for group in persona.groups:
|
||||
groups.add(group.name)
|
||||
return groups
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: extends the MIT user file ACL with user group names
|
||||
from personas shared via user groups.
|
||||
|
||||
Uses a single DB query (via fetch_user_files_with_access_relationships)
|
||||
that eagerly loads both the MIT-needed and EE-needed relationships.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
user_file_ids, db_session, eager_load_groups=True
|
||||
)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: works on pre-loaded UserFile objects.
|
||||
Expects Persona.groups to be eagerly loaded.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
if user_file.user is None:
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=True,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
continue
|
||||
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
group_names = _collect_user_file_group_names(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
user_groups=list(group_names),
|
||||
is_public=is_public,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
@@ -21,13 +20,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
# All the Auth flows are valid for EE version
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode
|
||||
|
||||
|
||||
def _build_hierarchy_access_filter(
|
||||
user_email: str,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> ColumnElement[bool]:
|
||||
"""Build SQLAlchemy filter for hierarchy node access.
|
||||
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.persona import mark_persona_user_files_for_sync
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
@@ -27,9 +26,7 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -38,7 +35,6 @@ def update_persona_access(
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -58,7 +54,6 @@ def update_persona_access(
|
||||
)
|
||||
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -68,7 +63,3 @@ def update_persona_access(
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
@@ -68,7 +68,6 @@ def get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -80,11 +79,6 @@ def get_external_access_for_raw_gdrive_file(
|
||||
set add_prefix to True so group IDs are prefixed with the source type.
|
||||
When invoked from doc_sync (permission sync), use the default (False)
|
||||
since upsert_document_external_perms handles prefixing.
|
||||
fallback_user_email: When we cannot retrieve any permission info for a file
|
||||
(e.g. externally-owned files where the API returns no permissions
|
||||
and permissions.list returns 403), fall back to granting access
|
||||
to this user. This is typically the impersonated org user whose
|
||||
drive contained the file.
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
@@ -123,26 +117,6 @@ def get_external_access_for_raw_gdrive_file(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
# For externally-owned files, the Drive API may return no permissions
|
||||
# and permissions.list may return 403. In this case, fall back to
|
||||
# granting access to the user who found the file in their drive.
|
||||
# Note, even if other users also have access to this file,
|
||||
# they will not be granted access in Onyx.
|
||||
# We check permissions_list (the final result after all fetch attempts)
|
||||
# rather than the raw fields, because permission_ids may be present
|
||||
# but the actual fetch can still return empty due to a 403.
|
||||
if not permissions_list:
|
||||
logger.info(
|
||||
f"No permission info available for file {doc_id} "
|
||||
f"(likely owned by a user outside of your organization). "
|
||||
f"Falling back to granting access to retriever user: {fallback_user_email}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails={fallback_user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -126,16 +125,10 @@ def _seed_llms(
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers: list[LLMProviderView] = []
|
||||
for llm_upsert_request in llm_upsert_requests:
|
||||
try:
|
||||
seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Failed to upsert LLM provider '%s' during seeding: %s",
|
||||
llm_upsert_request.name,
|
||||
e,
|
||||
)
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
|
||||
@@ -14,91 +14,67 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
return encoded_key
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
return input_str.encode()
|
||||
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
return input_bytes.decode()
|
||||
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
return decrypted_data.decode()
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
return versioned_encryption_fn(input_str)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
@@ -11,7 +12,6 @@ from onyx.db.document import get_access_info_for_document
|
||||
from onyx.db.document import get_access_info_for_documents
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -132,61 +132,19 @@ def get_access_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "get_access_for_user_files_impl"
|
||||
user_files = (
|
||||
db_session.query(UserFile)
|
||||
.options(joinedload(UserFile.user)) # Eager load the user relationship
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
return versioned_fn(user_file_ids, db_session)
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Compute access from pre-loaded UserFile objects (with relationships).
|
||||
Callers must ensure UserFile.user, Persona.users, and Persona.user are
|
||||
eagerly loaded (and Persona.groups for the EE path)."""
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "build_access_for_user_files_impl"
|
||||
)
|
||||
return versioned_fn(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
return {
|
||||
str(user_file.id): DocumentAccess.build(
|
||||
user_emails=[user_file.user.email] if user_file.user else [],
|
||||
user_groups=[],
|
||||
is_public=is_public,
|
||||
is_public=True if user_file.user is None else False,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
|
||||
"""Collect all user emails that should have access to this user file.
|
||||
Includes the owner plus any users who have access via shared personas.
|
||||
Returns (emails, is_public)."""
|
||||
emails: set[str] = {user_file.user.email}
|
||||
is_public = False
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
if persona.is_public:
|
||||
is_public = True
|
||||
if persona.user_id is not None and persona.user:
|
||||
emails.add(persona.user.email)
|
||||
for shared_user in persona.users:
|
||||
emails.add(shared_user.email)
|
||||
return emails, is_public
|
||||
for user_file in user_files
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
@@ -146,22 +145,10 @@ def is_user_admin(user: User) -> bool:
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
"""Log warnings for AUTH_TYPE issues.
|
||||
|
||||
This only runs on app startup not during migrations/scripts.
|
||||
"""
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
|
||||
if raw_auth_type == "cloud":
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise ValueError(
|
||||
"'cloud' is not a valid auth type for self-hosted deployments."
|
||||
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
|
||||
)
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -115,6 +115,8 @@ def _extract_from_batch(
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
@@ -123,7 +125,8 @@ def _extract_from_batch(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
ids[item.id] = item.parent_hierarchy_raw_node_id
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
|
||||
|
||||
@@ -189,7 +192,9 @@ def extract_ids_from_runnable_connector(
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -40,7 +40,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
@@ -290,14 +289,6 @@ def _run_hierarchy_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes
|
||||
|
||||
@@ -48,15 +48,10 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
@@ -65,7 +60,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
@@ -585,12 +579,11 @@ def connector_pruning_generator_task(
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
@@ -599,14 +592,6 @@ def connector_pruning_generator_task(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -622,6 +607,7 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
@@ -678,43 +664,6 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
# --- Hierarchy node pruning ---
|
||||
live_node_ids = {n.id for n in upserted_nodes}
|
||||
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
live_hierarchy_node_ids=live_node_ids,
|
||||
commit=True,
|
||||
)
|
||||
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
reparented_nodes = reparent_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
if deleted_raw_ids:
|
||||
evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids)
|
||||
if reparented_nodes:
|
||||
reparented_cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in reparented_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client, source, reparented_cache_entries
|
||||
)
|
||||
if stale_removed or deleted_raw_ids or reparented_nodes:
|
||||
task_logger.info(
|
||||
f"Hierarchy node pruning: cc_pair={cc_pair_id} "
|
||||
f"stale_entries_removed={stale_removed} "
|
||||
f"nodes_deleted={len(deleted_raw_ids)} "
|
||||
f"nodes_reparented={len(reparented_nodes)}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Pruning exceptioned: cc_pair={cc_pair_id} "
|
||||
|
||||
@@ -12,9 +12,9 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -43,9 +43,7 @@ from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -56,7 +54,6 @@ from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAd
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def _as_uuid(value: str | UUID) -> UUID:
|
||||
@@ -794,12 +791,11 @@ def project_sync_user_file_impl(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
[user_file_id],
|
||||
db_session,
|
||||
eager_load_groups=global_version.is_ee_version(),
|
||||
)
|
||||
user_file = user_files[0] if user_files else None
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
@@ -827,21 +823,12 @@ def project_sync_user_file_impl(
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
|
||||
file_id_str = str(user_file.id)
|
||||
access_map = build_access_for_user_files([user_file])
|
||||
access = access_map.get(file_id_str)
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=file_id_str,
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=(
|
||||
VespaDocumentFields(access=access)
|
||||
if access is not None
|
||||
else None
|
||||
),
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
|
||||
@@ -45,7 +45,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -588,14 +587,6 @@ def connector_document_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution during doc processing
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache_entries = [
|
||||
|
||||
@@ -50,7 +50,6 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
@@ -981,10 +980,6 @@ def run_llm_loop(
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, CustomToolCallSummary):
|
||||
saved_response = json.dumps(
|
||||
tool_response.rich_response.model_dump()
|
||||
)
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -55,7 +54,6 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
@@ -1011,7 +1009,6 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1218,14 +1215,7 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -68,10 +68,6 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -96,12 +92,19 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
# Silently default to basic - warnings/errors logged in verify_auth_setting()
|
||||
# which only runs on app startup, not during migrations/scripts
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Defaulting to 'basic'. Please update your configuration. "
|
||||
"Your existing data will be migrated automatically."
|
||||
)
|
||||
_auth_type_str = AuthType.BASIC.value
|
||||
try:
|
||||
AUTH_TYPE = AuthType(_auth_type_str)
|
||||
else:
|
||||
except ValueError:
|
||||
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
|
||||
AUTH_TYPE = AuthType.BASIC
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
@@ -285,9 +288,8 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
# NOTE: Now enabled on by default, unless the env indicates otherwise.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
@@ -205,7 +204,7 @@ def _manage_async_retrieval(
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncGenerator[Document, None]:
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
@@ -228,23 +227,22 @@ def _manage_async_retrieval(
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
async_gen = _async_fetch()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
doc = loop.run_until_complete(anext(async_gen))
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
# Must close the async generator before the loop so the Discord
|
||||
# client's `async with` block can await its shutdown coroutine.
|
||||
# The nested try/finally ensures the loop always closes even if
|
||||
# aclose() raises (same pattern as cursor.close() before conn.close()).
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
finally:
|
||||
loop.close()
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
@@ -1722,7 +1722,6 @@ class GoogleDriveConnector(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
|
||||
|
||||
@@ -476,7 +476,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -485,8 +484,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
add_prefix: When True, prefix group IDs with source type (for indexing path).
|
||||
When False (default), leave unprefixed (for permission sync path
|
||||
where upsert_document_external_perms handles prefixing).
|
||||
fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
|
||||
files), fall back to granting access to this user.
|
||||
"""
|
||||
external_access_fn = cast(
|
||||
Callable[
|
||||
@@ -495,7 +492,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
str,
|
||||
GoogleDriveService | None,
|
||||
GoogleDriveService,
|
||||
str,
|
||||
bool,
|
||||
],
|
||||
ExternalAccess,
|
||||
@@ -511,7 +507,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain,
|
||||
retriever_drive_service,
|
||||
admin_drive_service,
|
||||
fallback_user_email,
|
||||
add_prefix,
|
||||
)
|
||||
|
||||
@@ -677,7 +672,6 @@ def _convert_drive_item_to_document(
|
||||
creds, user_email=permission_sync_context.primary_admin_email
|
||||
),
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
@@ -759,7 +753,6 @@ def build_slim_document(
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_email: str,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
@@ -781,7 +774,6 @@ def build_slim_document(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
|
||||
@@ -44,7 +44,6 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import unwrap_str
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -90,7 +89,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id))))
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -179,9 +178,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id),
|
||||
{"value": params.get("state", [None])[0]},
|
||||
encrypt=True,
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
)
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -2,10 +2,7 @@
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -13,7 +10,6 @@ from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -462,7 +458,7 @@ def get_all_hierarchy_nodes_for_source(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str, # noqa: ARG001
|
||||
user_email: str | None, # noqa: ARG001
|
||||
external_group_ids: list[str], # noqa: ARG001
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
@@ -489,7 +485,7 @@ def _get_accessible_hierarchy_nodes_for_source(
|
||||
def get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
@@ -624,154 +620,3 @@ def update_hierarchy_node_permissions(
|
||||
db_session.flush()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session: Session,
|
||||
hierarchy_node_ids: list[int],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.
|
||||
|
||||
This records that the given cc_pair "owns" these hierarchy nodes. Used by
|
||||
indexing, pruning, and hierarchy-fetching paths.
|
||||
"""
|
||||
if not hierarchy_node_ids:
|
||||
return
|
||||
|
||||
_M = HierarchyNodeByConnectorCredentialPair
|
||||
stmt = pg_insert(_M).values(
|
||||
[
|
||||
{
|
||||
_M.hierarchy_node_id: node_id,
|
||||
_M.connector_id: connector_id,
|
||||
_M.credential_id: credential_id,
|
||||
}
|
||||
for node_id in hierarchy_node_ids
|
||||
]
|
||||
)
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
db_session.execute(stmt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
live_hierarchy_node_ids: set[int],
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Delete join-table rows for this cc_pair that are NOT in the live set.
|
||||
|
||||
If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
|
||||
(i.e. the connector no longer has any hierarchy nodes). Callers that want a
|
||||
no-op when there are no live nodes must guard before calling.
|
||||
|
||||
Returns the number of deleted rows.
|
||||
"""
|
||||
stmt = delete(HierarchyNodeByConnectorCredentialPair).where(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
if live_hierarchy_node_ids:
|
||||
stmt = stmt.where(
|
||||
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_(
|
||||
live_hierarchy_node_ids
|
||||
)
|
||||
)
|
||||
|
||||
result: CursorResult = db_session.execute(stmt) # type: ignore[assignment]
|
||||
deleted = result.rowcount
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
elif deleted:
|
||||
db_session.flush()
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def delete_orphaned_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
commit: bool = True,
|
||||
) -> list[str]:
|
||||
"""Delete hierarchy nodes for a source that have zero cc_pair associations.
|
||||
|
||||
SOURCE-type nodes are excluded (they are synthetic roots).
|
||||
|
||||
Returns the list of raw_node_ids that were deleted (for cache eviction).
|
||||
"""
|
||||
# Find orphaned nodes: no rows in the join table
|
||||
orphan_stmt = (
|
||||
select(HierarchyNode.id, HierarchyNode.raw_node_id)
|
||||
.outerjoin(
|
||||
HierarchyNodeByConnectorCredentialPair,
|
||||
HierarchyNode.id
|
||||
== HierarchyNodeByConnectorCredentialPair.hierarchy_node_id,
|
||||
)
|
||||
.where(
|
||||
HierarchyNode.source == source,
|
||||
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
|
||||
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None),
|
||||
)
|
||||
)
|
||||
orphans = db_session.execute(orphan_stmt).all()
|
||||
if not orphans:
|
||||
return []
|
||||
|
||||
orphan_ids = [row[0] for row in orphans]
|
||||
deleted_raw_ids = [row[1] for row in orphans]
|
||||
|
||||
db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids)))
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return deleted_raw_ids
|
||||
|
||||
|
||||
def reparent_orphaned_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
commit: bool = True,
|
||||
) -> list[HierarchyNode]:
|
||||
"""Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node.
|
||||
|
||||
After pruning deletes stale nodes, their former children get parent_id=NULL
|
||||
via the SET NULL cascade. This function points them back to the SOURCE root.
|
||||
|
||||
Returns the reparented HierarchyNode objects (with updated parent_id)
|
||||
so callers can refresh downstream caches.
|
||||
"""
|
||||
source_node = get_source_hierarchy_node(db_session, source)
|
||||
if not source_node:
|
||||
return []
|
||||
|
||||
stmt = select(HierarchyNode).where(
|
||||
HierarchyNode.source == source,
|
||||
HierarchyNode.parent_id.is_(None),
|
||||
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
|
||||
)
|
||||
orphans = list(db_session.execute(stmt).scalars().all())
|
||||
if not orphans:
|
||||
return []
|
||||
|
||||
for node in orphans:
|
||||
node.parent_id = source_node.id
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return orphans
|
||||
|
||||
@@ -270,35 +270,10 @@ def upsert_llm_provider(
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of requested visibility by model name
|
||||
requested_visibility = {
|
||||
mc.name: mc.is_visible
|
||||
for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Delete removed models
|
||||
removed_ids = [
|
||||
mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
|
||||
]
|
||||
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
# Prevent removing and hiding the default model
|
||||
if default_model:
|
||||
for name, mc in existing_by_name.items():
|
||||
if mc.id == default_model.id:
|
||||
if default_model.id in removed_ids:
|
||||
raise ValueError(
|
||||
f"Cannot remove the default model '{name}'. "
|
||||
"Please change the default model before removing."
|
||||
)
|
||||
if not requested_visibility.get(name, True):
|
||||
raise ValueError(
|
||||
f"Cannot hide the default model '{name}'. "
|
||||
"Please change the default model before hiding."
|
||||
)
|
||||
break
|
||||
|
||||
if removed_ids:
|
||||
db_session.query(ModelConfiguration).filter(
|
||||
ModelConfiguration.id.in_(removed_ids)
|
||||
@@ -563,6 +538,7 @@ def fetch_default_model(
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
@@ -838,30 +814,44 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default.
|
||||
# We flush (but don't commit) so that _update_default_model can see the new
|
||||
# model rows, then commit everything atomically to avoid a window where the
|
||||
# old default is invisible but still pointed-to.
|
||||
db_session.flush()
|
||||
db_session.commit()
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default
|
||||
recommended_default = llm_recommendations.get_default_model(provider.provider)
|
||||
if recommended_default:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
current_default_name = db_session.scalar(
|
||||
select(ModelConfiguration.name)
|
||||
.join(
|
||||
LLMModelFlow,
|
||||
LLMModelFlow.model_configuration_id == ModelConfiguration.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.llm_provider_id == provider.id,
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
current_default
|
||||
and current_default.llm_provider_id == provider.id
|
||||
and current_default.name != recommended_default.name
|
||||
current_default_name is not None
|
||||
and current_default_name != recommended_default.name
|
||||
):
|
||||
_update_default_model__no_commit(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
try:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Recommended default model '%s' not found "
|
||||
"for provider_id=%s; skipping default update.",
|
||||
recommended_default.name,
|
||||
provider.id,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -992,7 +982,7 @@ def update_model_configuration__no_commit(
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def _update_default_model__no_commit(
|
||||
def _update_default_model(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
@@ -1030,14 +1020,6 @@ def _update_default_model__no_commit(
|
||||
new_default.is_default = True
|
||||
model_config.is_visible = True
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
flow_type: LLMModelFlowType,
|
||||
) -> None:
|
||||
_update_default_model__no_commit(db_session, provider_id, model, flow_type)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
@@ -37,11 +36,9 @@ from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapper
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
@@ -120,50 +117,10 @@ class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class _EncryptedBase(TypeDecorator):
|
||||
"""Base for encrypted column types that wrap values in SensitiveValue."""
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def wrap_raw(self, value: Any) -> SensitiveValue:
|
||||
"""Encrypt a raw value and wrap it in SensitiveValue.
|
||||
|
||||
Called by the attribute set event so the Python-side type is always
|
||||
SensitiveValue, regardless of whether the value was loaded from the DB
|
||||
or assigned in application code.
|
||||
"""
|
||||
if self._is_json:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f"EncryptedJson column expected dict, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = json.dumps(value)
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"EncryptedString column expected str, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=encrypt_string_to_bytes(raw_str),
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=self._is_json,
|
||||
)
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
@@ -187,9 +144,20 @@ class EncryptedString(_EncryptedBase):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
_is_json: bool = True
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
@@ -197,7 +165,9 @@ class EncryptedJson(_EncryptedBase):
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
@@ -214,40 +184,14 @@ class EncryptedJson(_EncryptedBase):
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
_REGISTERED_ATTRS: set[str] = set()
|
||||
|
||||
|
||||
@event.listens_for(Mapper, "mapper_configured")
|
||||
def _register_sensitive_value_set_events(
|
||||
mapper: Mapper,
|
||||
class_: type,
|
||||
) -> None:
|
||||
"""Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, _EncryptedBase):
|
||||
col_type = col.type
|
||||
attr = getattr(class_, prop.key)
|
||||
|
||||
# Guard against double-registration (e.g. if mapper is
|
||||
# re-configured in test setups)
|
||||
attr_key = f"{class_.__qualname__}.{prop.key}"
|
||||
if attr_key in _REGISTERED_ATTRS:
|
||||
continue
|
||||
_REGISTERED_ATTRS.add(attr_key)
|
||||
|
||||
@event.listens_for(attr, "set", retval=True)
|
||||
def _wrap_value(
|
||||
target: Any, # noqa: ARG001
|
||||
value: Any,
|
||||
oldvalue: Any, # noqa: ARG001
|
||||
initiator: Any, # noqa: ARG001
|
||||
_col_type: _EncryptedBase = col_type,
|
||||
) -> Any:
|
||||
if value is not None and not isinstance(value, SensitiveValue):
|
||||
return _col_type.wrap_raw(value)
|
||||
return value
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -336,6 +280,16 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
default_model: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
@@ -2426,38 +2380,6 @@ class SyncRecord(Base):
|
||||
)
|
||||
|
||||
|
||||
class HierarchyNodeByConnectorCredentialPair(Base):
|
||||
"""Tracks which cc_pairs reference each hierarchy node.
|
||||
|
||||
During pruning, stale entries are removed for the current cc_pair.
|
||||
Hierarchy nodes with zero remaining entries are then deleted.
|
||||
"""
|
||||
|
||||
__tablename__ = "hierarchy_node_by_connector_credential_pair"
|
||||
|
||||
hierarchy_node_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
connector_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
Index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DocumentByConnectorCredentialPair(Base):
|
||||
"""Represents an indexing of a document by a specific connector / credential pair"""
|
||||
|
||||
|
||||
@@ -205,9 +205,7 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -215,7 +213,6 @@ def update_persona_access(
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -236,7 +233,6 @@ def update_persona_access(
|
||||
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
|
||||
# there shouldn't be any) but raise an error if trying to add actual groups.
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -244,10 +240,6 @@ def update_persona_access(
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support group-based sharing")
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
|
||||
def create_update_persona(
|
||||
persona_id: int | None,
|
||||
@@ -859,24 +851,6 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_persona_user_files_for_sync(
|
||||
persona_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When persona sharing changes, mark all of its user files for sync
|
||||
so that their ACLs get updated in the vector DB."""
|
||||
persona = (
|
||||
db_session.query(Persona)
|
||||
.options(selectinload(Persona.user_files))
|
||||
.filter(Persona.id == persona_id)
|
||||
.first()
|
||||
)
|
||||
if not persona:
|
||||
return
|
||||
file_ids = [uf.id for uf in persona.user_files]
|
||||
_mark_files_need_persona_sync(db_session, file_ids)
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
"""Rotate encryption key for all encrypted columns.
|
||||
|
||||
Dynamically discovers all columns using EncryptedString / EncryptedJson,
|
||||
decrypts each value with the old key, and re-encrypts with the current
|
||||
ENCRYPTION_KEY_SECRET.
|
||||
|
||||
The operation is idempotent: rows already encrypted with the current key
|
||||
are skipped. Commits are made in batches so a crash mid-rotation can be
|
||||
safely resumed by re-running.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.db.models import Base
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_BATCH_SIZE = 500
|
||||
|
||||
|
||||
def _can_decrypt_with_current_key(data: bytes) -> bool:
|
||||
"""Check if data is already encrypted with the current key.
|
||||
|
||||
Passes the key explicitly so the fallback-to-raw-decode path in
|
||||
_decrypt_bytes is NOT triggered — a clean success/failure signal.
|
||||
"""
|
||||
try:
|
||||
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
|
||||
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
|
||||
|
||||
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
|
||||
"""
|
||||
results: list[tuple[type, str, list[str], bool]] = []
|
||||
|
||||
for mapper in Base.registry.mappers:
|
||||
model_cls = mapper.class_
|
||||
pk_names = [col.key for col in mapper.primary_key]
|
||||
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
results.append((model_cls, prop.key, pk_names, True))
|
||||
elif isinstance(col.type, EncryptedString):
|
||||
results.append((model_cls, prop.key, pk_names, False))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rotate_encryption_key(
|
||||
db_session: Session,
|
||||
old_key: str | None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
old_key: The previous encryption key. Pass None or "" if values were
|
||||
not previously encrypted with a key.
|
||||
dry_run: If True, count rows that need rotation without modifying data.
|
||||
|
||||
Returns:
|
||||
Dict of "table.column" -> number of rows re-encrypted (or would be).
|
||||
|
||||
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
|
||||
is preserved on crash. Already-rotated rows are detected and skipped,
|
||||
making the operation safe to re-run.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
|
||||
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
raise RuntimeError(
|
||||
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
|
||||
"Set the target encryption key in the environment before running."
|
||||
)
|
||||
|
||||
encrypted_columns = _discover_encrypted_columns()
|
||||
totals: dict[str, int] = {}
|
||||
|
||||
for model_cls, col_name, pk_names, is_json in encrypted_columns:
|
||||
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
|
||||
col_attr = getattr(model_cls, col_name)
|
||||
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
|
||||
|
||||
# Read raw bytes directly, bypassing the TypeDecorator
|
||||
raw_col = col_attr.property.columns[0]
|
||||
|
||||
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
reencrypted = 0
|
||||
batch_pending = 0
|
||||
for row in rows:
|
||||
raw_bytes: bytes | None = row[-1]
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
if _can_decrypt_with_current_key(raw_bytes):
|
||||
continue
|
||||
|
||||
try:
|
||||
if not old_key:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
else:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
|
||||
|
||||
# For EncryptedJson, parse back to dict so the TypeDecorator
|
||||
# can json.dumps() it cleanly (avoids double-encoding).
|
||||
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
|
||||
except (ValueError, UnicodeDecodeError) as e:
|
||||
pk_vals = [row[i] for i in range(len(pk_names))]
|
||||
logger.warning(
|
||||
f"Could not decrypt/parse {table_name}.{col_name} "
|
||||
f"row {pk_vals} — skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not dry_run:
|
||||
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
|
||||
update_stmt = (
|
||||
update(model_cls).where(*pk_filters).values({col_name: value})
|
||||
)
|
||||
db_session.execute(update_stmt)
|
||||
batch_pending += 1
|
||||
|
||||
if batch_pending >= _BATCH_SIZE:
|
||||
db_session.commit()
|
||||
batch_pending = 0
|
||||
reencrypted += 1
|
||||
|
||||
# Flush remaining rows in this column
|
||||
if batch_pending > 0:
|
||||
db_session.commit()
|
||||
|
||||
if reencrypted > 0:
|
||||
totals[f"{table_name}.{col_name}"] = reencrypted
|
||||
logger.info(
|
||||
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
|
||||
f"{reencrypted} value(s) in {table_name}.{col_name}"
|
||||
)
|
||||
|
||||
return totals
|
||||
@@ -3,11 +3,9 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
@@ -120,31 +118,3 @@ def get_file_ids_by_user_file_ids(
|
||||
) -> list[str]:
|
||||
user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
|
||||
return [user_file.file_id for user_file in user_files]
|
||||
|
||||
|
||||
def fetch_user_files_with_access_relationships(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
eager_load_groups: bool = False,
|
||||
) -> list[UserFile]:
|
||||
"""Fetch user files with the owner and assistant relationships
|
||||
eagerly loaded (needed for computing access control).
|
||||
|
||||
When eager_load_groups is True, Persona.groups is also loaded so that
|
||||
callers can extract user-group names without a second DB round-trip."""
|
||||
persona_sub_options = [
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.user),
|
||||
]
|
||||
if eager_load_groups:
|
||||
persona_sub_options.append(selectinload(Persona.groups))
|
||||
|
||||
return (
|
||||
db_session.query(UserFile)
|
||||
.options(
|
||||
joinedload(UserFile.user),
|
||||
selectinload(UserFile.assistants).options(*persona_sub_options),
|
||||
)
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from sqlalchemy.sql.expression import or_
|
||||
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -24,6 +25,7 @@ from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
@@ -162,7 +164,13 @@ def _get_accepted_user_where_clause(
|
||||
where_clause.append(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string is not None:
|
||||
where_clause.append(email_col.ilike(f"%{email_filter_string}%"))
|
||||
personal_name_col: KeyedColumnElement[Any] = User.__table__.c.personal_name
|
||||
where_clause.append(
|
||||
or_(
|
||||
email_col.ilike(f"%{email_filter_string}%"),
|
||||
personal_name_col.ilike(f"%{email_filter_string}%"),
|
||||
)
|
||||
)
|
||||
|
||||
if roles_filter:
|
||||
where_clause.append(User.role.in_(roles_filter))
|
||||
@@ -358,3 +366,28 @@ def delete_user_from_db(
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
remove_user_from_invited_users(user_to_delete.email)
|
||||
|
||||
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
UserGroup.name,
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
result[user_id].append((group_id, group_name))
|
||||
return result
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -698,6 +698,41 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -723,53 +758,41 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
)
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -23,8 +23,11 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -34,22 +37,30 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -62,12 +73,16 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -89,7 +104,8 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -98,14 +114,16 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -117,8 +135,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -129,7 +147,8 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -141,94 +160,73 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# filter_str += _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# ))
|
||||
# )
|
||||
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# ))
|
||||
# )
|
||||
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -91,11 +91,11 @@ class OnyxErrorCode(Enum):
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the detail.
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"detail": message or self.code,
|
||||
"message": message or self.code,
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "detail": "..."}`` shape.
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -37,21 +37,21 @@ class OnyxError(Exception):
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
detail: Human-readable detail (defaults to the error code string).
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
detail: str | None = None,
|
||||
message: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
resolved_detail = detail or error_code.code
|
||||
super().__init__(resolved_detail)
|
||||
resolved_message = message or error_code.code
|
||||
super().__init__(resolved_message)
|
||||
self.error_code = error_code
|
||||
self.detail = resolved_detail
|
||||
self.message = resolved_message
|
||||
self._status_code_override = status_code_override
|
||||
|
||||
@property
|
||||
@@ -73,11 +73,11 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.detail),
|
||||
content=exc.error_code.detail(exc.message),
|
||||
)
|
||||
|
||||
@@ -19,16 +19,12 @@ class OnyxMimeTypes:
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-log",
|
||||
"text/x-config",
|
||||
"text/tab-separated-values",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"text/yaml",
|
||||
"text/x-yaml",
|
||||
}
|
||||
DOCUMENT_MIME_TYPES = {
|
||||
PDF_MIME_TYPE,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import abc
|
||||
from typing import cast
|
||||
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -8,19 +7,6 @@ class KvKeyNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def unwrap_str(val: JSON_ro) -> str:
|
||||
"""Unwrap a string stored as {"value": str} in the encrypted KV store.
|
||||
Also handles legacy plain-string values cached in Redis."""
|
||||
if isinstance(val, dict):
|
||||
try:
|
||||
return cast(str, val["value"])
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Expected dict with 'value' key, got keys: {list(val.keys())}"
|
||||
)
|
||||
return cast(str, val)
|
||||
|
||||
|
||||
class KeyValueStore:
|
||||
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
|
||||
# It's read from the global thread level variable
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -27,14 +26,6 @@ async def search_indexed_documents(
|
||||
Use this tool for information that is not public knowledge and specific to the user,
|
||||
their team, their work, or their organization/company.
|
||||
|
||||
Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM
|
||||
on every call, consuming tokens and adding latency.
|
||||
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk,
|
||||
but this should still be sufficient for most use cases. CE mode functionality should be swapped
|
||||
when a dedicated CE search endpoint is implemented.
|
||||
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
@@ -120,72 +111,47 @@ async def search_indexed_documents(
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
|
||||
search_request: dict[str, Any]
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
# Build the search request using the new SendSearchQueryRequest format
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
# Call the API server using the new send-search-message route
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
endpoint,
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
|
||||
json=search_request,
|
||||
headers=auth_headers,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get(error_key):
|
||||
if result.get("error"):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get(error_key),
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
documents = [
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
# Return simplified format for MCP clients
|
||||
fields_to_return = [
|
||||
"semantic_identifier",
|
||||
"content",
|
||||
"source_type",
|
||||
"link",
|
||||
"score",
|
||||
]
|
||||
documents = [
|
||||
{key: doc.get(key) for key in fields_to_return}
|
||||
for doc in result.get("search_docs", [])
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
# backend retrieves fewer documents than requested.
|
||||
documents = documents[:limit]
|
||||
|
||||
logger.info(
|
||||
f"Onyx MCP Server: Internal search returned {len(documents)} results"
|
||||
@@ -194,6 +160,7 @@ async def search_indexed_documents(
|
||||
"documents": documents,
|
||||
"total_results": len(documents),
|
||||
"query": query,
|
||||
"executed_queries": result.get("all_executed_queries", [query]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)
|
||||
|
||||
@@ -16,7 +16,6 @@ Cache Strategy:
|
||||
using only the SOURCE-type node as the ancestor
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -205,30 +204,6 @@ def cache_hierarchy_nodes_batch(
|
||||
redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS)
|
||||
|
||||
|
||||
def evict_hierarchy_nodes_from_cache(
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
raw_node_ids: list[str],
|
||||
) -> None:
|
||||
"""Remove specific hierarchy nodes from the Redis cache.
|
||||
|
||||
Deletes entries from both the parent-chain hash and the raw_id→node_id hash.
|
||||
"""
|
||||
if not raw_node_ids:
|
||||
return
|
||||
|
||||
cache_key = _cache_key(source)
|
||||
raw_id_key = _raw_id_cache_key(source)
|
||||
|
||||
# Look up node_ids so we can remove them from the parent-chain hash
|
||||
raw_values = cast(list[str | None], redis_client.hmget(raw_id_key, raw_node_ids))
|
||||
node_id_strs = [v for v in raw_values if v is not None]
|
||||
|
||||
if node_id_strs:
|
||||
redis_client.hdel(cache_key, *node_id_strs)
|
||||
redis_client.hdel(raw_id_key, *raw_node_ids)
|
||||
|
||||
|
||||
def get_node_id_from_raw_id(
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
|
||||
@@ -1905,7 +1905,7 @@ def get_connector_by_id(
|
||||
@router.post("/connector-request")
|
||||
def submit_connector_request(
|
||||
request_data: ConnectorRequestSubmission,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> StatusResponse:
|
||||
"""
|
||||
Submit a connector request for Cloud deployments.
|
||||
@@ -1918,7 +1918,7 @@ def submit_connector_request(
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email
|
||||
user_email = user.email if user else None
|
||||
distinct_id = user_email or tenant_id
|
||||
|
||||
# Track connector request via PostHog telemetry (Cloud only)
|
||||
|
||||
@@ -57,6 +57,9 @@ def list_messages(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageListResponse:
|
||||
"""Get all messages for a build session."""
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
messages = session_manager.list_messages(session_id, user.id)
|
||||
|
||||
@@ -54,14 +54,18 @@ def _require_opensearch(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
|
||||
def _get_user_access_info(
|
||||
user: User | None, db_session: Session
|
||||
) -> tuple[str | None, list[str]]:
|
||||
if not user:
|
||||
return None, []
|
||||
return user.email, get_user_external_group_ids(db_session, user)
|
||||
|
||||
|
||||
@router.get(HIERARCHY_NODES_LIST_PATH)
|
||||
def list_accessible_hierarchy_nodes(
|
||||
source: DocumentSource,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodesResponse:
|
||||
_require_opensearch(db_session)
|
||||
@@ -88,7 +92,7 @@ def list_accessible_hierarchy_nodes(
|
||||
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
|
||||
def list_accessible_hierarchy_node_documents(
|
||||
documents_request: HierarchyNodeDocumentsRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodeDocumentsResponse:
|
||||
_require_opensearch(db_session)
|
||||
|
||||
@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
|
||||
@router.get("/servers", response_model=MCPServersResponse)
|
||||
def get_mcp_servers_for_user(
|
||||
db: Session = Depends(get_session),
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> MCPServersResponse:
|
||||
"""List all MCP servers for use in agent configuration and chat UI.
|
||||
|
||||
|
||||
@@ -10,8 +10,6 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -37,38 +35,6 @@ def get_safe_filename(upload: UploadFile) -> str:
|
||||
return upload.filename
|
||||
|
||||
|
||||
def get_upload_size_bytes(upload: UploadFile) -> int | None:
|
||||
"""Best-effort file size in bytes without consuming the stream."""
|
||||
if upload.size is not None:
|
||||
return upload.size
|
||||
|
||||
try:
|
||||
current_pos = upload.file.tell()
|
||||
upload.file.seek(0, 2)
|
||||
size = upload.file.tell()
|
||||
upload.file.seek(current_pos)
|
||||
return size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not determine upload size via stream seek "
|
||||
f"(filename='{get_safe_filename(upload)}', "
|
||||
f"error_type={type(e).__name__}, error={e})"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
|
||||
"""Return True when upload size is known and exceeds max_bytes."""
|
||||
size_bytes = get_upload_size_bytes(upload)
|
||||
if size_bytes is None:
|
||||
logger.warning(
|
||||
"Could not determine upload size; skipping size-limit check for "
|
||||
f"'{get_safe_filename(upload)}'"
|
||||
)
|
||||
return False
|
||||
return size_bytes > max_bytes
|
||||
|
||||
|
||||
# Guard against extremely large images
|
||||
Image.MAX_IMAGE_PIXELS = 12000 * 12000
|
||||
|
||||
@@ -193,18 +159,6 @@ def categorize_uploaded_files(
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
extension = get_file_ext(filename)
|
||||
|
||||
# If image, estimate tokens via dedicated method first
|
||||
|
||||
@@ -65,7 +65,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -446,17 +445,16 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# When transitioning to auto mode, preserve existing model configurations
|
||||
# so the upsert doesn't try to delete them (which would trip the default
|
||||
# model protection guard). sync_auto_mode_models will handle the model
|
||||
# lifecycle afterward — adding new models, hiding removed ones, and
|
||||
# updating the default. This is safe even if sync fails: the provider
|
||||
# keeps its old models and default rather than losing them.
|
||||
if transitioning_to_auto_mode and existing_provider:
|
||||
llm_provider_upsert_request.model_configurations = [
|
||||
ModelConfigurationUpsertRequest.from_model(mc)
|
||||
for mc in existing_provider.model_configurations
|
||||
]
|
||||
# Before the upsert, check if this provider currently owns the global
|
||||
# CHAT default. The upsert may cascade-delete model_configurations
|
||||
# (and their flow mappings), so we need to remember this beforehand.
|
||||
was_default_provider = False
|
||||
if existing_provider and transitioning_to_auto_mode:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
was_default_provider = (
|
||||
current_default is not None
|
||||
and current_default.llm_provider_id == existing_provider.id
|
||||
)
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
@@ -470,6 +468,7 @@ def put_llm_provider(
|
||||
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
@@ -479,6 +478,20 @@ def put_llm_provider(
|
||||
updated_provider,
|
||||
config,
|
||||
)
|
||||
|
||||
# If this provider was the default before the transition,
|
||||
# restore the default using the recommended model.
|
||||
if was_default_provider:
|
||||
recommended = config.get_default_model(
|
||||
llm_provider_upsert_request.provider
|
||||
)
|
||||
if recommended:
|
||||
update_default_provider(
|
||||
provider_id=updated_provider.id,
|
||||
model_name=recommended.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Refresh result with synced models
|
||||
result = LLMProviderView.from_model(updated_provider)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -18,6 +19,7 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.anonymous_user import fetch_anonymous_user_info
|
||||
@@ -67,6 +69,7 @@ from onyx.db.user_preferences import update_user_role
|
||||
from onyx.db.user_preferences import update_user_shortcut_enabled
|
||||
from onyx.db.user_preferences import update_user_temperature_override_enabled
|
||||
from onyx.db.user_preferences import update_user_theme_preference
|
||||
from onyx.db.users import batch_get_user_groups
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.db.users import get_page_of_filtered_users
|
||||
@@ -98,6 +101,7 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
from onyx.server.usage_limits import is_tenant_on_trial_fn
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -203,9 +207,32 @@ def list_accepted_users(
|
||||
total_items=0,
|
||||
)
|
||||
|
||||
user_ids = [user.id for user in filtered_accepted_users]
|
||||
groups_by_user = batch_get_user_groups(db_session, user_ids)
|
||||
|
||||
# Batch-fetch SCIM mappings to mark synced users
|
||||
scim_synced_ids: set[UUID] = set()
|
||||
try:
|
||||
from onyx.db.models import ScimUserMapping
|
||||
|
||||
scim_mappings = db_session.scalars(
|
||||
select(ScimUserMapping.user_id).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
scim_synced_ids = set(scim_mappings)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
FullUserSnapshot.from_user_model(user) for user in filtered_accepted_users
|
||||
FullUserSnapshot.from_user_model(
|
||||
user,
|
||||
groups=[
|
||||
UserGroupInfo(id=gid, name=gname)
|
||||
for gid, gname in groups_by_user.get(user.id, [])
|
||||
],
|
||||
is_scim_synced=user.id in scim_synced_ids,
|
||||
)
|
||||
for user in filtered_accepted_users
|
||||
],
|
||||
total_items=total_accepted_users_count,
|
||||
)
|
||||
@@ -269,24 +296,10 @@ def list_all_users(
|
||||
if accepted_page is None or invited_page is None or slack_users_page is None:
|
||||
return AllUsersResponse(
|
||||
accepted=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in accepted_users
|
||||
FullUserSnapshot.from_user_model(user) for user in accepted_users
|
||||
],
|
||||
slack_users=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in slack_users
|
||||
FullUserSnapshot.from_user_model(user) for user in slack_users
|
||||
],
|
||||
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
|
||||
accepted_pages=1,
|
||||
@@ -296,26 +309,10 @@ def list_all_users(
|
||||
|
||||
# Otherwise, return paginated results
|
||||
return AllUsersResponse(
|
||||
accepted=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in accepted_users
|
||||
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
|
||||
slack_users=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in slack_users
|
||||
][
|
||||
accepted=[FullUserSnapshot.from_user_model(user) for user in accepted_users][
|
||||
accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE
|
||||
],
|
||||
slack_users=[FullUserSnapshot.from_user_model(user) for user in slack_users][
|
||||
slack_users_page
|
||||
* USERS_PAGE_SIZE : (slack_users_page + 1)
|
||||
* USERS_PAGE_SIZE
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
@@ -31,21 +32,41 @@ class MinimalUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class UserGroupInfo(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
groups: list[UserGroupInfo]
|
||||
is_scim_synced: bool
|
||||
|
||||
@classmethod
|
||||
def from_user_model(cls, user: User) -> "FullUserSnapshot":
|
||||
def from_user_model(
|
||||
cls,
|
||||
user: User,
|
||||
groups: list[UserGroupInfo] | None = None,
|
||||
is_scim_synced: bool = False,
|
||||
) -> "FullUserSnapshot":
|
||||
return cls(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
groups=groups or [],
|
||||
is_scim_synced=is_scim_synced,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.citation_utils import extract_citation_order_from_text
|
||||
@@ -22,9 +20,7 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderResult
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderStart
|
||||
@@ -184,37 +180,24 @@ def create_custom_tool_packets(
|
||||
tab_index: int = 0,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
error: CustomToolErrorInfo | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
tool_id: int | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolStart(tool_name=tool_name, tool_id=tool_id),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolArgs(tool_name=tool_name, tool_args=tool_args),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
tool_id=tool_id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -674,55 +657,13 @@ def translate_assistant_message_to_packets(
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
# Try to parse as structured CustomToolCallSummary JSON
|
||||
custom_data: dict | list | str | int | float | bool | None = (
|
||||
tool_call.tool_call_response
|
||||
)
|
||||
custom_error: CustomToolErrorInfo | None = None
|
||||
custom_response_type = "text"
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_call.tool_call_response)
|
||||
if isinstance(parsed, dict) and "tool_name" in parsed:
|
||||
custom_data = parsed.get("tool_result")
|
||||
custom_response_type = parsed.get(
|
||||
"response_type", "text"
|
||||
)
|
||||
if parsed.get("error"):
|
||||
custom_error = CustomToolErrorInfo(
|
||||
**parsed["error"]
|
||||
)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
ValidationError,
|
||||
):
|
||||
pass
|
||||
|
||||
custom_file_ids: list[str] | None = None
|
||||
if custom_response_type in ("image", "csv") and isinstance(
|
||||
custom_data, dict
|
||||
):
|
||||
custom_file_ids = custom_data.get("file_ids")
|
||||
custom_data = None
|
||||
|
||||
custom_args = {
|
||||
k: v
|
||||
for k, v in (tool_call.tool_call_arguments or {}).items()
|
||||
if k != "requestBody"
|
||||
}
|
||||
turn_tool_packets.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool.display_name or tool.name,
|
||||
response_type=custom_response_type,
|
||||
response_type="text",
|
||||
turn_index=turn_num,
|
||||
tab_index=tool_call.tab_index,
|
||||
data=custom_data,
|
||||
file_ids=custom_file_ids,
|
||||
error=custom_error,
|
||||
tool_args=custom_args if custom_args else None,
|
||||
tool_id=tool_call.tool_id,
|
||||
data=tool_call.tool_call_response,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ class StreamingType(Enum):
|
||||
PYTHON_TOOL_START = "python_tool_start"
|
||||
PYTHON_TOOL_DELTA = "python_tool_delta"
|
||||
CUSTOM_TOOL_START = "custom_tool_start"
|
||||
CUSTOM_TOOL_ARGS = "custom_tool_args"
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta"
|
||||
FILE_READER_START = "file_reader_start"
|
||||
FILE_READER_RESULT = "file_reader_result"
|
||||
@@ -42,7 +41,6 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -247,20 +245,6 @@ class CustomToolStart(BaseObj):
|
||||
type: Literal["custom_tool_start"] = StreamingType.CUSTOM_TOOL_START.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
|
||||
|
||||
class CustomToolArgs(BaseObj):
|
||||
type: Literal["custom_tool_args"] = StreamingType.CUSTOM_TOOL_ARGS.value
|
||||
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class CustomToolErrorInfo(BaseModel):
|
||||
is_auth_error: bool = False
|
||||
status_code: int
|
||||
message: str
|
||||
|
||||
|
||||
# The allowed streamed packets for a custom tool
|
||||
@@ -268,22 +252,11 @@ class CustomToolDelta(BaseObj):
|
||||
type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
response_type: str
|
||||
# For non-file responses
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
# For file-based responses like image/csv
|
||||
file_ids: list[str] | None = None
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
@@ -393,7 +366,6 @@ PacketObj = Union[
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
CustomToolStart,
|
||||
CustomToolArgs,
|
||||
CustomToolDelta,
|
||||
FileReaderStart,
|
||||
FileReaderResult,
|
||||
@@ -407,7 +379,6 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -8,6 +8,8 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
@@ -163,6 +165,39 @@ def create_image_generation_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_custom_tool_packets(
|
||||
tool_name: str,
|
||||
response_type: str,
|
||||
turn_index: int,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_fetch_packets(
|
||||
fetch_docs: list[SavedSearchDoc],
|
||||
urls: list[str],
|
||||
|
||||
@@ -78,7 +78,6 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
|
||||
@@ -3,7 +3,6 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -51,7 +50,6 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
return settings
|
||||
|
||||
@@ -275,13 +275,9 @@ def setup_postgres(db_session: Session) -> None:
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to upsert LLM provider during setup: %s", e)
|
||||
return
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
@@ -56,23 +56,3 @@ def get_built_in_tool_ids() -> list[str]:
|
||||
|
||||
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
|
||||
return BUILT_IN_TOOL_MAP[in_code_tool_id]
|
||||
|
||||
|
||||
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
|
||||
"""Build a mapping from LLM-facing tool name to tool class."""
|
||||
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
|
||||
for cls in BUILT_IN_TOOL_MAP.values():
|
||||
name_attr = cls.__dict__.get("name")
|
||||
if isinstance(name_attr, property) and name_attr.fget is not None:
|
||||
tool_name = name_attr.fget(cls)
|
||||
elif isinstance(name_attr, str):
|
||||
tool_name = name_attr
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
|
||||
)
|
||||
result[tool_name] = cls
|
||||
return result
|
||||
|
||||
|
||||
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()
|
||||
|
||||
@@ -92,7 +92,3 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return False
|
||||
|
||||
@@ -18,7 +18,6 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
|
||||
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
@@ -62,7 +61,6 @@ class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
|
||||
tool_result: Any # The response data
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
class ToolCallKickoff(BaseModel):
|
||||
|
||||
@@ -15,9 +15,7 @@ from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
@@ -141,7 +139,7 @@ class CustomTool(Tool[None]):
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=CustomToolStart(tool_name=self._name, tool_id=self._id),
|
||||
obj=CustomToolStart(tool_name=self._name),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -151,8 +149,10 @@ class CustomTool(Tool[None]):
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
# Build path params
|
||||
request_body = llm_kwargs.get(REQUEST_BODY)
|
||||
|
||||
path_params = {}
|
||||
|
||||
for path_param_schema in self._method_spec.get_path_param_schemas():
|
||||
param_name = path_param_schema["name"]
|
||||
if param_name not in llm_kwargs:
|
||||
@@ -165,7 +165,6 @@ class CustomTool(Tool[None]):
|
||||
)
|
||||
path_params[param_name] = llm_kwargs[param_name]
|
||||
|
||||
# Build query params
|
||||
query_params = {}
|
||||
for query_param_schema in self._method_spec.get_query_param_schemas():
|
||||
if query_param_schema["name"] in llm_kwargs:
|
||||
@@ -173,20 +172,6 @@ class CustomTool(Tool[None]):
|
||||
query_param_schema["name"]
|
||||
]
|
||||
|
||||
# Emit args packet (path + query params only, no request body)
|
||||
tool_args = {**path_params, **query_params}
|
||||
if tool_args:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=CustomToolArgs(
|
||||
tool_name=self._name,
|
||||
tool_args=tool_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
request_body = llm_kwargs.get(REQUEST_BODY)
|
||||
url = self._method_spec.build_url(self._base_url, path_params, query_params)
|
||||
method = self._method_spec.method
|
||||
|
||||
@@ -195,18 +180,6 @@ class CustomTool(Tool[None]):
|
||||
)
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
||||
# Detect HTTP errors — only 401/403 are flagged as auth errors
|
||||
error_info: CustomToolErrorInfo | None = None
|
||||
if response.status_code in (401, 403):
|
||||
error_info = CustomToolErrorInfo(
|
||||
is_auth_error=True,
|
||||
status_code=response.status_code,
|
||||
message=f"{self._name} action failed because of authentication error",
|
||||
)
|
||||
logger.warning(
|
||||
f"Auth error from custom tool '{self._name}': HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
tool_result: Any
|
||||
response_type: str
|
||||
file_ids: List[str] | None = None
|
||||
@@ -249,11 +222,9 @@ class CustomTool(Tool[None]):
|
||||
placement=placement,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
tool_id=self._id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error_info,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -265,7 +236,6 @@ class CustomTool(Tool[None]):
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
tool_result=tool_result,
|
||||
error=error_info,
|
||||
),
|
||||
llm_facing_response=llm_facing_response,
|
||||
)
|
||||
|
||||
@@ -376,8 +376,3 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -11,20 +11,16 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT encrypt called with explicit key — key ignored.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support decryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT decrypt called with explicit key — key ignored.")
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
@@ -90,15 +86,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(intput_str, key=key)
|
||||
return versioned_encryption_fn(intput_str)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(intput_bytes, key=key)
|
||||
return versioned_decryption_fn(intput_bytes)
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
"""
|
||||
jsonriver - A streaming JSON parser for Python
|
||||
|
||||
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
|
||||
Gives you a sequence of increasingly complete values.
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from .parse import _Parser as Parser
|
||||
from .parse import JsonObject
|
||||
from .parse import JsonValue
|
||||
|
||||
__all__ = ["Parser", "JsonValue", "JsonObject"]
|
||||
__version__ = "0.0.1"
|
||||
@@ -1,427 +0,0 @@
|
||||
"""
|
||||
JSON parser for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from .tokenize import _Input
|
||||
from .tokenize import json_token_type_to_string
|
||||
from .tokenize import JsonTokenType
|
||||
from .tokenize import Tokenizer
|
||||
|
||||
|
||||
# Type definitions for JSON values
|
||||
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
|
||||
JsonObject = dict[str, JsonValue]
|
||||
|
||||
|
||||
class _StateEnum(IntEnum):
|
||||
"""Parser state machine states"""
|
||||
|
||||
Initial = 0
|
||||
InString = 1
|
||||
InArray = 2
|
||||
InObjectExpectingKey = 3
|
||||
InObjectExpectingValue = 4
|
||||
|
||||
|
||||
class _State:
|
||||
"""Base class for parser states"""
|
||||
|
||||
type: _StateEnum
|
||||
value: JsonValue | tuple[str, JsonObject] | None
|
||||
|
||||
|
||||
class _InitialState(_State):
|
||||
"""Initial state before any parsing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.Initial
|
||||
self.value = None
|
||||
|
||||
|
||||
class _InStringState(_State):
|
||||
"""State while parsing a string"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InString
|
||||
self.value = ""
|
||||
|
||||
|
||||
class _InArrayState(_State):
|
||||
"""State while parsing an array"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InArray
|
||||
self.value: list[JsonValue] = []
|
||||
|
||||
|
||||
class _InObjectExpectingKeyState(_State):
|
||||
"""State while parsing an object, expecting a key"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingKey
|
||||
self.value: JsonObject = {}
|
||||
|
||||
|
||||
class _InObjectExpectingValueState(_State):
|
||||
"""State while parsing an object, expecting a value"""
|
||||
|
||||
def __init__(self, key: str, obj: JsonObject) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingValue
|
||||
self.value = (key, obj)
|
||||
|
||||
|
||||
# Sentinel value to distinguish "not set" from "set to None/null"
|
||||
class _Unset:
|
||||
pass
|
||||
|
||||
|
||||
_UNSET = _Unset()
|
||||
|
||||
|
||||
class _Parser:
|
||||
"""
|
||||
Incremental JSON parser
|
||||
|
||||
Feed chunks of JSON text via feed() and get back progressively
|
||||
more complete JSON values.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_stack: list[_State] = [_InitialState()]
|
||||
self._toplevel_value: JsonValue | _Unset = _UNSET
|
||||
self._input = _Input()
|
||||
self.tokenizer = Tokenizer(self._input, self)
|
||||
self._finished = False
|
||||
self._progressed = False
|
||||
self._prev_snapshot: JsonValue | _Unset = _UNSET
|
||||
|
||||
def feed(self, chunk: str) -> list[JsonValue]:
|
||||
"""
|
||||
Feed a chunk of JSON text and return deltas from the previous state.
|
||||
|
||||
Each element in the returned list represents what changed since the
|
||||
last yielded value. For dicts, only changed/new keys are included,
|
||||
with string values containing only the newly appended characters.
|
||||
"""
|
||||
if self._finished:
|
||||
return []
|
||||
|
||||
self._input.feed(chunk)
|
||||
return self._collect_deltas()
|
||||
|
||||
@staticmethod
|
||||
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
|
||||
if prev is None:
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(prev, dict):
|
||||
result: JsonObject = {}
|
||||
for key in current:
|
||||
cur_val = current[key]
|
||||
prev_val = prev.get(key)
|
||||
if key not in prev:
|
||||
result[key] = cur_val
|
||||
elif isinstance(cur_val, str) and isinstance(prev_val, str):
|
||||
if cur_val != prev_val:
|
||||
result[key] = cur_val[len(prev_val) :]
|
||||
elif isinstance(cur_val, list) and isinstance(prev_val, list):
|
||||
if cur_val != prev_val:
|
||||
new_items = cur_val[len(prev_val) :]
|
||||
# check if the last existing element was updated
|
||||
if (
|
||||
prev_val
|
||||
and len(cur_val) >= len(prev_val)
|
||||
and cur_val[len(prev_val) - 1] != prev_val[-1]
|
||||
):
|
||||
result[key] = [cur_val[len(prev_val) - 1]] + new_items
|
||||
elif new_items:
|
||||
result[key] = new_items
|
||||
elif cur_val != prev_val:
|
||||
result[key] = cur_val
|
||||
return result if result else None
|
||||
|
||||
if isinstance(current, str) and isinstance(prev, str):
|
||||
delta = current[len(prev) :]
|
||||
return delta if delta else None
|
||||
|
||||
if isinstance(current, list) and isinstance(prev, list):
|
||||
if current != prev:
|
||||
new_items = current[len(prev) :]
|
||||
if (
|
||||
prev
|
||||
and len(current) >= len(prev)
|
||||
and current[len(prev) - 1] != prev[-1]
|
||||
):
|
||||
return [current[len(prev) - 1]] + new_items
|
||||
return new_items if new_items else None
|
||||
return None
|
||||
|
||||
if current != prev:
|
||||
return current
|
||||
return None
|
||||
|
||||
def finish(self) -> list[JsonValue]:
|
||||
"""Signal that no more chunks will be fed. Validates trailing content.
|
||||
|
||||
Returns any final deltas produced by flushing pending tokens (e.g.
|
||||
numbers, which have no terminator and wait for more input).
|
||||
"""
|
||||
self._input.mark_complete()
|
||||
# Pump once more so the tokenizer can emit tokens that were waiting
|
||||
# for more input (e.g. numbers need buffer_complete to finalize).
|
||||
results = self._collect_deltas()
|
||||
self._input.expect_end_of_content()
|
||||
return results
|
||||
|
||||
def _collect_deltas(self) -> list[JsonValue]:
|
||||
"""Run one pump cycle and return any deltas produced."""
|
||||
results: list[JsonValue] = []
|
||||
while True:
|
||||
self._progressed = False
|
||||
self.tokenizer.pump()
|
||||
|
||||
if self._progressed:
|
||||
if self._toplevel_value is _UNSET:
|
||||
raise RuntimeError(
|
||||
"Internal error: toplevel_value should not be unset "
|
||||
"after progressing"
|
||||
)
|
||||
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
|
||||
if isinstance(self._prev_snapshot, _Unset):
|
||||
results.append(current)
|
||||
else:
|
||||
delta = self._compute_delta(self._prev_snapshot, current)
|
||||
if delta is not None:
|
||||
results.append(delta)
|
||||
self._prev_snapshot = current
|
||||
else:
|
||||
if not self._state_stack:
|
||||
self._finished = True
|
||||
break
|
||||
return results
|
||||
|
||||
# TokenHandler protocol implementation
|
||||
|
||||
def handle_null(self) -> None:
|
||||
"""Handle null token"""
|
||||
self._handle_value_token(JsonTokenType.Null, None)
|
||||
|
||||
def handle_boolean(self, value: bool) -> None:
|
||||
"""Handle boolean token"""
|
||||
self._handle_value_token(JsonTokenType.Boolean, value)
|
||||
|
||||
def handle_number(self, value: float) -> None:
|
||||
"""Handle number token"""
|
||||
self._handle_value_token(JsonTokenType.Number, value)
|
||||
|
||||
def handle_string_start(self) -> None:
|
||||
"""Handle string start token"""
|
||||
state = self._current_state()
|
||||
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(JsonTokenType.StringStart, None)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
self._state_stack.append(_InStringState())
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
sv = self._progress_value(JsonTokenType.StringStart, None)
|
||||
obj[key] = sv
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
def handle_string_middle(self, value: str) -> None:
|
||||
"""Handle string middle token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
if len(self._state_stack) >= 2:
|
||||
prev = self._state_stack[-2]
|
||||
if prev.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
else:
|
||||
self._progressed = True
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
assert isinstance(state.value, str)
|
||||
state.value += value
|
||||
|
||||
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_string_end(self) -> None:
|
||||
"""Handle string end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
self._state_stack.pop()
|
||||
parent_state = self._state_stack[-1] if self._state_stack else None
|
||||
assert isinstance(state.value, str)
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_array_start(self) -> None:
|
||||
"""Handle array start token"""
|
||||
self._handle_value_token(JsonTokenType.ArrayStart, None)
|
||||
|
||||
def handle_array_end(self) -> None:
|
||||
"""Handle array end token"""
|
||||
state = self._current_state()
|
||||
if state.type != _StateEnum.InArray:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
|
||||
)
|
||||
self._state_stack.pop()
|
||||
|
||||
def handle_object_start(self) -> None:
|
||||
"""Handle object start token"""
|
||||
self._handle_value_token(JsonTokenType.ObjectStart, None)
|
||||
|
||||
def handle_object_end(self) -> None:
|
||||
"""Handle object end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type in (
|
||||
_StateEnum.InObjectExpectingKey,
|
||||
_StateEnum.InObjectExpectingValue,
|
||||
):
|
||||
self._state_stack.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
|
||||
)
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _current_state(self) -> _State:
|
||||
"""Get current parser state"""
|
||||
if not self._state_stack:
|
||||
raise ValueError("Unexpected trailing input")
|
||||
return self._state_stack[-1]
|
||||
|
||||
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
|
||||
"""Handle a complete value token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(token_type, value)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(token_type, value)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
if token_type != JsonTokenType.StringStart:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
v = self._progress_value(token_type, value)
|
||||
obj[key] = v
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of object expecting key"
|
||||
)
|
||||
|
||||
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
|
||||
"""Update parent container with updated string value"""
|
||||
if parent_state is None:
|
||||
self._toplevel_value = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InArray:
|
||||
arr = cast(list[JsonValue], parent_state.value)
|
||||
arr[-1] = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], parent_state.value)
|
||||
obj[key] = updated
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingKey:
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
obj = cast(JsonObject, parent_state.value)
|
||||
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
|
||||
|
||||
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
|
||||
"""Create initial value for a token and push appropriate state"""
|
||||
if token_type == JsonTokenType.Null:
|
||||
return None
|
||||
|
||||
elif token_type == JsonTokenType.Boolean:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.Number:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.StringStart:
|
||||
string_state = _InStringState()
|
||||
self._state_stack.append(string_state)
|
||||
return ""
|
||||
|
||||
elif token_type == JsonTokenType.ArrayStart:
|
||||
array_state = _InArrayState()
|
||||
self._state_stack.append(array_state)
|
||||
return array_state.value
|
||||
|
||||
elif token_type == JsonTokenType.ObjectStart:
|
||||
object_state = _InObjectExpectingKeyState()
|
||||
self._state_stack.append(object_state)
|
||||
return object_state.value
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected token type: {json_token_type_to_string(token_type)}"
|
||||
)
|
||||
@@ -1,514 +0,0 @@
|
||||
"""
|
||||
JSON tokenizer for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenHandler(Protocol):
|
||||
"""Protocol for handling JSON tokens"""
|
||||
|
||||
def handle_null(self) -> None: ...
|
||||
def handle_boolean(self, value: bool) -> None: ...
|
||||
def handle_number(self, value: float) -> None: ...
|
||||
def handle_string_start(self) -> None: ...
|
||||
def handle_string_middle(self, value: str) -> None: ...
|
||||
def handle_string_end(self) -> None: ...
|
||||
def handle_array_start(self) -> None: ...
|
||||
def handle_array_end(self) -> None: ...
|
||||
def handle_object_start(self) -> None: ...
|
||||
def handle_object_end(self) -> None: ...
|
||||
|
||||
|
||||
class JsonTokenType(IntEnum):
|
||||
"""Types of JSON tokens"""
|
||||
|
||||
Null = 0
|
||||
Boolean = 1
|
||||
Number = 2
|
||||
StringStart = 3
|
||||
StringMiddle = 4
|
||||
StringEnd = 5
|
||||
ArrayStart = 6
|
||||
ArrayEnd = 7
|
||||
ObjectStart = 8
|
||||
ObjectEnd = 9
|
||||
|
||||
|
||||
def json_token_type_to_string(token_type: JsonTokenType) -> str:
|
||||
"""Convert token type to readable string"""
|
||||
names = {
|
||||
JsonTokenType.Null: "null",
|
||||
JsonTokenType.Boolean: "boolean",
|
||||
JsonTokenType.Number: "number",
|
||||
JsonTokenType.StringStart: "string start",
|
||||
JsonTokenType.StringMiddle: "string middle",
|
||||
JsonTokenType.StringEnd: "string end",
|
||||
JsonTokenType.ArrayStart: "array start",
|
||||
JsonTokenType.ArrayEnd: "array end",
|
||||
JsonTokenType.ObjectStart: "object start",
|
||||
JsonTokenType.ObjectEnd: "object end",
|
||||
}
|
||||
return names[token_type]
|
||||
|
||||
|
||||
class _State(IntEnum):
|
||||
"""Internal tokenizer states"""
|
||||
|
||||
ExpectingValue = 0
|
||||
InString = 1
|
||||
StartArray = 2
|
||||
AfterArrayValue = 3
|
||||
StartObject = 4
|
||||
AfterObjectKey = 5
|
||||
AfterObjectValue = 6
|
||||
BeforeObjectKey = 7
|
||||
|
||||
|
||||
# Regex for validating JSON numbers
|
||||
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
|
||||
|
||||
|
||||
def _parse_json_number(s: str) -> float:
|
||||
"""Parse a JSON number string, validating format"""
|
||||
if not _JSON_NUMBER_PATTERN.match(s):
|
||||
raise ValueError("Invalid number")
|
||||
return float(s)
|
||||
|
||||
|
||||
class _Input:
|
||||
"""
|
||||
Input buffer for chunk-based JSON parsing
|
||||
|
||||
Manages buffering of input chunks and provides methods for
|
||||
consuming and inspecting the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._start_index = 0
|
||||
self.buffer_complete = False
|
||||
|
||||
def feed(self, chunk: str) -> None:
|
||||
"""Add a chunk of data to the buffer"""
|
||||
self._buffer += chunk
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Signal that no more chunks will be fed"""
|
||||
self.buffer_complete = True
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""Number of characters remaining in buffer"""
|
||||
return len(self._buffer) - self._start_index
|
||||
|
||||
def advance(self, length: int) -> None:
|
||||
"""Advance the start position by length characters"""
|
||||
self._start_index += length
|
||||
|
||||
def peek(self, offset: int) -> str | None:
|
||||
"""Peek at character at offset, or None if not available"""
|
||||
idx = self._start_index + offset
|
||||
if idx < len(self._buffer):
|
||||
return self._buffer[idx]
|
||||
return None
|
||||
|
||||
def peek_char_code(self, offset: int) -> int:
|
||||
"""Get character code at offset"""
|
||||
return ord(self._buffer[self._start_index + offset])
|
||||
|
||||
def slice(self, start: int, end: int) -> str:
|
||||
"""Slice buffer from start to end (relative to current position)"""
|
||||
return self._buffer[self._start_index + start : self._start_index + end]
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Commit consumed content, removing it from buffer"""
|
||||
if self._start_index > 0:
|
||||
self._buffer = self._buffer[self._start_index :]
|
||||
self._start_index = 0
|
||||
|
||||
def remaining(self) -> str:
|
||||
"""Get all remaining content in buffer"""
|
||||
return self._buffer[self._start_index :]
|
||||
|
||||
def expect_end_of_content(self) -> None:
|
||||
"""Verify no non-whitespace content remains"""
|
||||
self.commit()
|
||||
self.skip_past_whitespace()
|
||||
if self.length != 0:
|
||||
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
|
||||
|
||||
def skip_past_whitespace(self) -> None:
|
||||
"""Skip whitespace characters"""
|
||||
i = self._start_index
|
||||
while i < len(self._buffer):
|
||||
c = ord(self._buffer[i])
|
||||
if c in (32, 9, 10, 13): # space, tab, \n, \r
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
self._start_index = i
|
||||
|
||||
def try_to_take_prefix(self, prefix: str) -> bool:
|
||||
"""Try to consume prefix from buffer, return True if successful"""
|
||||
if self._buffer.startswith(prefix, self._start_index):
|
||||
self._start_index += len(prefix)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_to_take(self, length: int) -> str | None:
|
||||
"""Try to take length characters, or None if not enough available"""
|
||||
if self.length < length:
|
||||
return None
|
||||
result = self._buffer[self._start_index : self._start_index + length]
|
||||
self._start_index += length
|
||||
return result
|
||||
|
||||
def try_to_take_char_code(self) -> int | None:
|
||||
"""Try to take a single character as char code, or None if buffer empty"""
|
||||
if self.length == 0:
|
||||
return None
|
||||
code = ord(self._buffer[self._start_index])
|
||||
self._start_index += 1
|
||||
return code
|
||||
|
||||
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
|
||||
"""
|
||||
Consume input up to first quote or backslash
|
||||
|
||||
Returns tuple of (consumed_content, pattern_found)
|
||||
"""
|
||||
buf = self._buffer
|
||||
i = self._start_index
|
||||
while i < len(buf):
|
||||
c = ord(buf[i])
|
||||
if c <= 0x1F:
|
||||
raise ValueError("Unescaped control character in string")
|
||||
if c == 34 or c == 92: # " or \
|
||||
result = buf[self._start_index : i]
|
||||
self._start_index = i
|
||||
return (result, True)
|
||||
i += 1
|
||||
|
||||
result = buf[self._start_index :]
|
||||
self._start_index = len(buf)
|
||||
return (result, False)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer for chunk-based JSON parsing
|
||||
|
||||
Processes chunks fed into its input buffer and calls handler methods
|
||||
as JSON tokens are recognized.
|
||||
"""
|
||||
|
||||
def __init__(self, input: _Input, handler: TokenHandler) -> None:
|
||||
self.input = input
|
||||
self._handler = handler
|
||||
self._stack: list[_State] = [_State.ExpectingValue]
|
||||
self._emitted_tokens = 0
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if tokenization is complete"""
|
||||
return len(self._stack) == 0 and self.input.length == 0
|
||||
|
||||
def pump(self) -> None:
|
||||
"""Process all available tokens in the buffer"""
|
||||
while True:
|
||||
before = self._emitted_tokens
|
||||
self._tokenize_more()
|
||||
if self._emitted_tokens == before:
|
||||
self.input.commit()
|
||||
return
|
||||
|
||||
def _tokenize_more(self) -> None:
|
||||
"""Process one step of tokenization based on current state"""
|
||||
if not self._stack:
|
||||
return
|
||||
|
||||
state = self._stack[-1]
|
||||
|
||||
if state == _State.ExpectingValue:
|
||||
self._tokenize_value()
|
||||
elif state == _State.InString:
|
||||
self._tokenize_string()
|
||||
elif state == _State.StartArray:
|
||||
self._tokenize_array_start()
|
||||
elif state == _State.AfterArrayValue:
|
||||
self._tokenize_after_array_value()
|
||||
elif state == _State.StartObject:
|
||||
self._tokenize_object_start()
|
||||
elif state == _State.AfterObjectKey:
|
||||
self._tokenize_after_object_key()
|
||||
elif state == _State.AfterObjectValue:
|
||||
self._tokenize_after_object_value()
|
||||
elif state == _State.BeforeObjectKey:
|
||||
self._tokenize_before_object_key()
|
||||
|
||||
def _tokenize_value(self) -> None:
|
||||
"""Tokenize a JSON value"""
|
||||
self.input.skip_past_whitespace()
|
||||
|
||||
if self.input.try_to_take_prefix("null"):
|
||||
self._handler.handle_null()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("true"):
|
||||
self._handler.handle_boolean(True)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("false"):
|
||||
self._handler.handle_boolean(False)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.length > 0:
|
||||
ch = self.input.peek_char_code(0)
|
||||
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
|
||||
# Scan for end of number
|
||||
i = 0
|
||||
while i < self.input.length:
|
||||
c = self.input.peek_char_code(i)
|
||||
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if i == self.input.length and not self.input.buffer_complete:
|
||||
# Need more input (numbers have no terminator)
|
||||
return
|
||||
|
||||
number_chars = self.input.slice(0, i)
|
||||
self.input.advance(i)
|
||||
number = _parse_json_number(number_chars)
|
||||
self._handler.handle_number(number)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix('"'):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("["):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartArray)
|
||||
self._handler.handle_array_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_array_start()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("{"):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartObject)
|
||||
self._handler.handle_object_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_object_start()
|
||||
return
|
||||
|
||||
def _tokenize_string(self) -> None:
|
||||
"""Tokenize string content"""
|
||||
while True:
|
||||
chunk, interrupted = self.input.take_until_quote_or_backslash()
|
||||
if chunk:
|
||||
self._handler.handle_string_middle(chunk)
|
||||
self._emitted_tokens += 1
|
||||
elif not interrupted:
|
||||
return
|
||||
|
||||
if interrupted:
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
next_char = self.input.peek(0)
|
||||
if next_char == '"':
|
||||
self.input.advance(1)
|
||||
self._handler.handle_string_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
# Handle escape sequences
|
||||
next_char2 = self.input.peek(1)
|
||||
if next_char2 is None:
|
||||
return
|
||||
|
||||
value: str
|
||||
if next_char2 == "u":
|
||||
# Unicode escape: need 4 hex digits
|
||||
if self.input.length < 6:
|
||||
return
|
||||
|
||||
code = 0
|
||||
for j in range(2, 6):
|
||||
c = self.input.peek_char_code(j)
|
||||
if 48 <= c <= 57: # 0-9
|
||||
digit = c - 48
|
||||
elif 65 <= c <= 70: # A-F
|
||||
digit = c - 55
|
||||
elif 97 <= c <= 102: # a-f
|
||||
digit = c - 87
|
||||
else:
|
||||
raise ValueError("Bad Unicode escape in JSON")
|
||||
code = (code << 4) | digit
|
||||
|
||||
self.input.advance(6)
|
||||
self._handler.handle_string_middle(chr(code))
|
||||
self._emitted_tokens += 1
|
||||
continue
|
||||
|
||||
elif next_char2 == "n":
|
||||
value = "\n"
|
||||
elif next_char2 == "r":
|
||||
value = "\r"
|
||||
elif next_char2 == "t":
|
||||
value = "\t"
|
||||
elif next_char2 == "b":
|
||||
value = "\b"
|
||||
elif next_char2 == "f":
|
||||
value = "\f"
|
||||
elif next_char2 == "\\":
|
||||
value = "\\"
|
||||
elif next_char2 == "/":
|
||||
value = "/"
|
||||
elif next_char2 == '"':
|
||||
value = '"'
|
||||
else:
|
||||
raise ValueError("Bad escape in string")
|
||||
|
||||
self.input.advance(2)
|
||||
self._handler.handle_string_middle(value)
|
||||
self._emitted_tokens += 1
|
||||
|
||||
def _tokenize_array_start(self) -> None:
|
||||
"""Tokenize start of array (check for empty or first element)"""
|
||||
self.input.skip_past_whitespace()
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("]"):
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterArrayValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
|
||||
def _tokenize_after_array_value(self) -> None:
|
||||
"""Tokenize after an array value (expect , or ])"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x5D: # ]
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_object_start(self) -> None:
|
||||
"""Tokenize start of object (check for empty or first key)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_key(self) -> None:
|
||||
"""Tokenize after object key (expect :)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x3A: # :
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_value(self) -> None:
|
||||
"""Tokenize after object value (expect , or })"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.BeforeObjectKey)
|
||||
self._tokenize_before_object_key()
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected , or }} after object value, got {chr(next_char)!r}"
|
||||
)
|
||||
|
||||
def _tokenize_before_object_key(self) -> None:
|
||||
"""Tokenize before object key (after comma)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
@@ -128,8 +128,6 @@ class SensitiveValue(Generic[T]):
|
||||
value = self._decrypt()
|
||||
|
||||
if not apply_mask:
|
||||
# Callers must not mutate the returned dict — doing so would
|
||||
# desync the cache from the encrypted bytes and the DB.
|
||||
return value
|
||||
|
||||
# Apply masking
|
||||
@@ -176,20 +174,18 @@ class SensitiveValue(Generic[T]):
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compare SensitiveValues by their decrypted content."""
|
||||
# NOTE: if you attempt to compare a string/dict to a SensitiveValue,
|
||||
# this comparison will return NotImplemented, which then evaluates to False.
|
||||
# This is the convention and required for SQLAlchemy's attribute tracking.
|
||||
if not isinstance(other, SensitiveValue):
|
||||
return NotImplemented
|
||||
return self._decrypt() == other._decrypt()
|
||||
"""Prevent direct comparison which might expose value."""
|
||||
if isinstance(other, SensitiveValue):
|
||||
# Compare encrypted bytes for equality check
|
||||
return self._encrypted_bytes == other._encrypted_bytes
|
||||
raise SensitiveAccessError(
|
||||
"Cannot compare SensitiveValue with non-SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value for comparison."
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on decrypted content."""
|
||||
value = self._decrypt()
|
||||
if isinstance(value, dict):
|
||||
return hash(json.dumps(value, sort_keys=True))
|
||||
return hash(value)
|
||||
"""Allow hashing based on encrypted bytes."""
|
||||
return hash(self._encrypted_bytes)
|
||||
|
||||
# Prevent JSON serialization
|
||||
def __json__(self) -> Any:
|
||||
|
||||
@@ -2,6 +2,7 @@ import contextvars
|
||||
import threading
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
|
||||
@@ -14,7 +15,6 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.interface import unwrap_str
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
@@ -25,7 +25,6 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
|
||||
_CACHED_UUID: str | None = None
|
||||
_CACHED_INSTANCE_DOMAIN: str | None = None
|
||||
@@ -63,10 +62,10 @@ def get_or_generate_uuid() -> str:
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
_CACHED_UUID = unwrap_str(kv_store.load(KV_CUSTOMER_UUID_KEY))
|
||||
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
_CACHED_UUID = str(uuid.uuid4())
|
||||
kv_store.store(KV_CUSTOMER_UUID_KEY, {"value": _CACHED_UUID}, encrypt=True)
|
||||
kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True)
|
||||
|
||||
return _CACHED_UUID
|
||||
|
||||
@@ -80,16 +79,14 @@ def _get_or_generate_instance_domain() -> str | None: #
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
_CACHED_INSTANCE_DOMAIN = unwrap_str(kv_store.load(KV_INSTANCE_DOMAIN_KEY))
|
||||
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
first_user = db_session.query(User).first()
|
||||
if first_user:
|
||||
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
|
||||
kv_store.store(
|
||||
KV_INSTANCE_DOMAIN_KEY,
|
||||
{"value": _CACHED_INSTANCE_DOMAIN},
|
||||
encrypt=True,
|
||||
KV_INSTANCE_DOMAIN_KEY, _CACHED_INSTANCE_DOMAIN, encrypt=True
|
||||
)
|
||||
|
||||
return _CACHED_INSTANCE_DOMAIN
|
||||
|
||||
@@ -24,9 +24,6 @@ class OnyxVersion:
|
||||
def set_ee(self) -> None:
|
||||
self._is_ee = True
|
||||
|
||||
def unset_ee(self) -> None:
|
||||
self._is_ee = False
|
||||
|
||||
def is_ee_version(self) -> bool:
|
||||
return self._is_ee
|
||||
|
||||
|
||||
@@ -1,93 +1,48 @@
|
||||
"""Decrypt a raw hex-encoded credential value.
|
||||
|
||||
Usage:
|
||||
python -m scripts.decrypt <hex_value>
|
||||
python -m scripts.decrypt <hex_value> --key "my-encryption-key"
|
||||
python -m scripts.decrypt <hex_value> --key ""
|
||||
|
||||
Pass --key "" to skip decryption and just decode the raw bytes as UTF-8.
|
||||
Omit --key to use the current ENCRYPTION_KEY_SECRET from the environment.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
|
||||
|
||||
def decrypt_raw_credential(encrypted_value: str, key: str | None = None) -> None:
|
||||
"""Decrypt and display a raw encrypted credential value.
|
||||
def decrypt_raw_credential(encrypted_value: str) -> None:
|
||||
"""Decrypt and display a raw encrypted credential value
|
||||
|
||||
Args:
|
||||
encrypted_value: The hex-encoded encrypted credential value.
|
||||
key: Encryption key to use. None means use ENCRYPTION_KEY_SECRET,
|
||||
empty string means just decode as UTF-8.
|
||||
encrypted_value: The hex encoded encrypted credential value
|
||||
"""
|
||||
# Strip common hex prefixes
|
||||
if encrypted_value.startswith("\\x"):
|
||||
encrypted_value = encrypted_value[2:]
|
||||
elif encrypted_value.startswith("x"):
|
||||
encrypted_value = encrypted_value[1:]
|
||||
print(encrypted_value)
|
||||
|
||||
try:
|
||||
raw_bytes = binascii.unhexlify(encrypted_value)
|
||||
# If string starts with 'x', remove it as it's just a prefix indicating hex
|
||||
if encrypted_value.startswith("x"):
|
||||
encrypted_value = encrypted_value[1:]
|
||||
elif encrypted_value.startswith("\\x"):
|
||||
encrypted_value = encrypted_value[2:]
|
||||
|
||||
# Convert hex string to bytes
|
||||
encrypted_bytes = binascii.unhexlify(encrypted_value)
|
||||
|
||||
# Decrypt the bytes
|
||||
decrypted_str = decrypt_bytes_to_string(encrypted_bytes)
|
||||
|
||||
# Parse and pretty print the decrypted JSON
|
||||
decrypted_json = json.loads(decrypted_str)
|
||||
print("Decrypted credential value:")
|
||||
print(json.dumps(decrypted_json, indent=2))
|
||||
|
||||
except binascii.Error:
|
||||
print("Error: Invalid hex-encoded string")
|
||||
sys.exit(1)
|
||||
print("Error: Invalid hex encoded string")
|
||||
|
||||
if key == "":
|
||||
# Empty key → just decode as UTF-8, no decryption
|
||||
try:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"Error decoding bytes as UTF-8: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(key)
|
||||
try:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=key)
|
||||
except Exception as e:
|
||||
print(f"Error decrypting value: {e}")
|
||||
sys.exit(1)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Decrypted raw value (not JSON): {e}")
|
||||
|
||||
# Try to pretty-print as JSON, otherwise print raw
|
||||
try:
|
||||
parsed = json.loads(decrypted_str)
|
||||
print(json.dumps(parsed, indent=2))
|
||||
except json.JSONDecodeError:
|
||||
print(decrypted_str)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Decrypt a hex-encoded credential value."
|
||||
)
|
||||
parser.add_argument(
|
||||
"value",
|
||||
help="Hex-encoded encrypted value to decrypt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
default=None,
|
||||
help=(
|
||||
"Encryption key. Omit to use ENCRYPTION_KEY_SECRET from env. "
|
||||
'Pass "" (empty) to just decode as UTF-8 without decryption.'
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
global_version.set_ee()
|
||||
decrypt_raw_credential(args.value, key=args.key)
|
||||
global_version.unset_ee()
|
||||
except Exception as e:
|
||||
print(f"Error decrypting value: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python decrypt.py <hex_encoded_encrypted_value>")
|
||||
sys.exit(1)
|
||||
|
||||
encrypted_value = sys.argv[1]
|
||||
decrypt_raw_credential(encrypted_value)
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Decrypts all encrypted columns using the old key (or raw decode if the old key
|
||||
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Usage (docker):
|
||||
docker exec -it onyx-api_server-1 \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Usage (kubernetes):
|
||||
kubectl exec -it <pod> -- \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Omit --old-key (or pass "") if secrets were not previously encrypted.
|
||||
|
||||
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
|
||||
or --all-tenants to iterate every tenant.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
|
||||
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
|
||||
|
||||
|
||||
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
|
||||
print(f"Re-encrypting secrets for tenant: {tenant_id}")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)
|
||||
|
||||
if results:
|
||||
for col, count in results.items():
|
||||
print(
|
||||
f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
|
||||
)
|
||||
else:
|
||||
print("No rows needed re-encryption.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Re-encrypt secrets under the current encryption key."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old-key",
|
||||
default=None,
|
||||
help="Previous encryption key. Omit or pass empty string if not applicable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Show what would be re-encrypted without making changes.",
|
||||
)
|
||||
|
||||
tenant_group = parser.add_mutually_exclusive_group()
|
||||
tenant_group.add_argument(
|
||||
"--tenant-id",
|
||||
default=None,
|
||||
help="Target a specific tenant schema.",
|
||||
)
|
||||
tenant_group.add_argument(
|
||||
"--all-tenants",
|
||||
action="store_true",
|
||||
help="Iterate all tenants.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
old_key = args.old_key if args.old_key else None
|
||||
|
||||
global_version.set_ee()
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
|
||||
if args.dry_run:
|
||||
print("DRY RUN — no changes will be made")
|
||||
|
||||
if args.all_tenants:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
print(f"Found {len(tenant_ids)} tenant(s)")
|
||||
failed_tenants: list[str] = []
|
||||
for tid in tenant_ids:
|
||||
try:
|
||||
_run_for_tenant(tid, old_key, dry_run=args.dry_run)
|
||||
except Exception as e:
|
||||
print(f" ERROR for tenant {tid}: {e}")
|
||||
failed_tenants.append(tid)
|
||||
if failed_tenants:
|
||||
print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
_run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,71 +0,0 @@
|
||||
# Backend Tests
|
||||
|
||||
## Test Types
|
||||
|
||||
There are four test categories, ordered by increasing scope:
|
||||
|
||||
### Unit Tests (`tests/unit/`)
|
||||
|
||||
No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated
|
||||
logic (e.g. citation processing, encryption).
|
||||
|
||||
```bash
|
||||
pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests (`tests/external_dependency_unit/`)
|
||||
|
||||
External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx
|
||||
application containers are not. Tests call functions directly and can mock selectively.
|
||||
|
||||
Use when you need a real database or real API calls but want control over setup.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests (`tests/integration/`)
|
||||
|
||||
Full Onyx deployment running. No mocking. Prefer this over other test types when possible.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright / E2E Tests (`web/tests/e2e/`)
|
||||
|
||||
Full stack including web server. Use for frontend-backend coordination.
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
## Shared Fixtures
|
||||
|
||||
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
|
||||
their own `conftest.py` for directory-scoped fixtures.
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use `enable_ee` fixture instead of inlining
|
||||
|
||||
Enables EE mode for a test, with proper teardown and cache clearing.
|
||||
|
||||
```python
|
||||
# Whole file (in a test module, NOT in conftest.py)
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Whole directory — add an autouse wrapper to the directory's conftest.py
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
|
||||
# Single test
|
||||
def test_something(enable_ee: None) -> None: ...
|
||||
```
|
||||
|
||||
**Note:** `pytestmark` in a `conftest.py` does NOT apply markers to tests in that
|
||||
directory — it only affects tests defined in the conftest itself (which is none).
|
||||
Use the autouse fixture wrapper pattern shown above instead.
|
||||
|
||||
Do NOT inline `global_version.set_ee()` — always use the fixture.
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Root conftest — shared fixtures available to all test directories."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def enable_ee() -> Generator[None, None, None]:
|
||||
"""Temporarily enable EE mode for a single test.
|
||||
|
||||
Restores the previous EE state and clears the versioned-implementation
|
||||
cache on teardown so state doesn't leak between tests.
|
||||
"""
|
||||
was_ee = global_version.is_ee_version()
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
if not was_ee:
|
||||
global_version.unset_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
@@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
def test_confluence_connector_permissions(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
confluence_connector: ConfluenceConnector,
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
@@ -93,7 +93,7 @@ def test_confluence_connector_permissions(
|
||||
def test_confluence_connector_restriction_handling(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
mock_db_provider_class: MagicMock,
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Test space key
|
||||
test_space_key = "DailyPermS"
|
||||
|
||||
@@ -4,6 +4,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
@@ -12,3 +14,14 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
return_value=None,
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
|
||||
@@ -48,7 +48,7 @@ def test_gitlab_connector_basic(gitlab_connector: GitlabConnector) -> None:
|
||||
|
||||
# --- Specific Document Details to Validate ---
|
||||
target_mr_id = f"https://{gitlab_base_url}/{project_path}/-/merge_requests/1"
|
||||
target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/work_items/2"
|
||||
target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/issues/2"
|
||||
target_code_file_semantic_id = "README.md"
|
||||
# ---
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ def _build_connector(
|
||||
|
||||
def test_gdrive_perm_sync_with_real_data(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive.
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
|
||||
@@ -17,7 +19,16 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.teams.models import TeamsThread
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
@@ -166,9 +168,18 @@ def test_slim_docs_retrieval_from_teams_connector(
|
||||
_assert_is_valid_external_access(external_access=slim_doc.external_access)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for perm sync tests to work since
|
||||
perm syncing is an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
yield
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_perm_sync(
|
||||
teams_connector: TeamsConnector,
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Test that load_from_checkpoint_with_perm_sync returns documents with external_access.
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ Verifies that:
|
||||
3. Upserting is idempotent (running twice doesn't duplicate nodes)
|
||||
4. Document-to-hierarchy-node linkage is updated during pruning
|
||||
5. link_hierarchy_nodes_to_documents links nodes that are also documents
|
||||
6. HierarchyNodeByConnectorCredentialPair join table population and pruning
|
||||
7. Orphaned hierarchy node deletion and re-parenting
|
||||
|
||||
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
|
||||
combined with a real PostgreSQL database for verifying persistence.
|
||||
@@ -26,27 +24,16 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import ensure_source_node_exists
|
||||
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
|
||||
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
@@ -155,80 +142,13 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_cc_pair(
|
||||
db_session: Session,
|
||||
source: DocumentSource = TEST_SOURCE,
|
||||
) -> ConnectorCredentialPair:
|
||||
"""Create a real Connector + Credential + ConnectorCredentialPair for testing."""
|
||||
connector = Connector(
|
||||
name=f"Test {source.value} Connector",
|
||||
source=source,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.flush()
|
||||
|
||||
credential = Credential(
|
||||
source=source,
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
name=f"Test {source.value} CC Pair",
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
access_type=AccessType.PUBLIC,
|
||||
)
|
||||
db_session.add(cc_pair)
|
||||
db_session.commit()
|
||||
db_session.refresh(cc_pair)
|
||||
return cc_pair
|
||||
|
||||
|
||||
def _cleanup_test_data(db_session: Session) -> None:
|
||||
"""Remove all test hierarchy nodes and documents to isolate tests."""
|
||||
for doc_id in SLIM_DOC_IDS:
|
||||
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
|
||||
|
||||
test_connector_ids_q = db_session.query(Connector.id).filter(
|
||||
Connector.source == TEST_SOURCE,
|
||||
Connector.name.like("Test %"),
|
||||
)
|
||||
|
||||
db_session.query(HierarchyNodeByConnectorCredentialPair).filter(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id.in_(test_connector_ids_q)
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == TEST_SOURCE
|
||||
).delete()
|
||||
db_session.flush()
|
||||
|
||||
# Collect credential IDs before deleting cc_pairs (bulk query.delete()
|
||||
# bypasses ORM-level cascade, so credentials won't be auto-removed).
|
||||
credential_ids = [
|
||||
row[0]
|
||||
for row in db_session.query(ConnectorCredentialPair.credential_id)
|
||||
.filter(ConnectorCredentialPair.connector_id.in_(test_connector_ids_q))
|
||||
.all()
|
||||
]
|
||||
|
||||
db_session.query(ConnectorCredentialPair).filter(
|
||||
ConnectorCredentialPair.connector_id.in_(test_connector_ids_q)
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Connector).filter(
|
||||
Connector.source == TEST_SOURCE,
|
||||
Connector.name.like("Test %"),
|
||||
).delete(synchronize_session="fetch")
|
||||
if credential_ids:
|
||||
db_session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -259,8 +179,15 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
|
||||
|
||||
result = extract_ids_from_runnable_connector(connector, callback=None)
|
||||
|
||||
# raw_id_to_parent should contain ONLY document IDs, not hierarchy node IDs
|
||||
assert result.raw_id_to_parent.keys() == set(SLIM_DOC_IDS)
|
||||
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
|
||||
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
|
||||
expected_ids = {
|
||||
CHANNEL_A_ID,
|
||||
CHANNEL_B_ID,
|
||||
CHANNEL_C_ID,
|
||||
*SLIM_DOC_IDS,
|
||||
}
|
||||
assert result.raw_id_to_parent.keys() == expected_ids
|
||||
|
||||
# Hierarchy nodes should be the 3 channels
|
||||
assert len(result.hierarchy_nodes) == 3
|
||||
@@ -468,9 +395,9 @@ def test_extraction_preserves_parent_hierarchy_raw_node_id(
|
||||
result.raw_id_to_parent[doc_id] == expected_parent
|
||||
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
|
||||
|
||||
# Hierarchy node IDs should NOT be in raw_id_to_parent
|
||||
# Hierarchy node entries have None parent (they aren't documents)
|
||||
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
|
||||
assert channel_id not in result.raw_id_to_parent
|
||||
assert result.raw_id_to_parent[channel_id] is None
|
||||
|
||||
|
||||
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
|
||||
@@ -638,241 +565,3 @@ def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
|
||||
commit=False,
|
||||
)
|
||||
assert linked == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Join table + pruning tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_upsert_hierarchy_node_cc_pair_entries(db_session: Session) -> None:
|
||||
"""upsert_hierarchy_node_cc_pair_entries should insert rows and be idempotent."""
|
||||
_cleanup_test_data(db_session)
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair = _create_cc_pair(db_session)
|
||||
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
node_ids = [n.id for n in upserted]
|
||||
|
||||
# First call — should insert rows
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=node_ids,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
rows = (
|
||||
db_session.query(HierarchyNodeByConnectorCredentialPair)
|
||||
.filter(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id
|
||||
== cc_pair.credential_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert len(rows) == 3
|
||||
|
||||
# Second call — idempotent, same count
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=node_ids,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
rows_after = (
|
||||
db_session.query(HierarchyNodeByConnectorCredentialPair)
|
||||
.filter(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id
|
||||
== cc_pair.credential_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert len(rows_after) == 3
|
||||
|
||||
|
||||
def test_remove_stale_entries_and_delete_orphans(db_session: Session) -> None:
|
||||
"""After removing stale join-table entries, orphaned hierarchy nodes should
|
||||
be deleted and the SOURCE node should survive."""
|
||||
_cleanup_test_data(db_session)
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair = _create_cc_pair(db_session)
|
||||
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
all_ids = [n.id for n in upserted]
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=all_ids,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Now simulate a pruning run where only channel A survived
|
||||
channel_a = get_hierarchy_node_by_raw_id(db_session, CHANNEL_A_ID, TEST_SOURCE)
|
||||
assert channel_a is not None
|
||||
live_ids = {channel_a.id}
|
||||
|
||||
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
live_hierarchy_node_ids=live_ids,
|
||||
commit=True,
|
||||
)
|
||||
assert stale_removed == 2
|
||||
|
||||
# Delete orphaned nodes
|
||||
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert set(deleted_raw_ids) == {CHANNEL_B_ID, CHANNEL_C_ID}
|
||||
|
||||
# Verify only channel A + SOURCE remain
|
||||
remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE)
|
||||
remaining_raw = {n.raw_node_id for n in remaining}
|
||||
assert remaining_raw == {CHANNEL_A_ID, source_node.raw_node_id}
|
||||
|
||||
|
||||
def test_multi_cc_pair_prevents_premature_deletion(db_session: Session) -> None:
|
||||
"""A hierarchy node shared by two cc_pairs should NOT be deleted when only
|
||||
one cc_pair removes its association."""
|
||||
_cleanup_test_data(db_session)
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair_1 = _create_cc_pair(db_session)
|
||||
cc_pair_2 = _create_cc_pair(db_session)
|
||||
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
all_ids = [n.id for n in upserted]
|
||||
|
||||
# cc_pair 1 owns all 3
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=all_ids,
|
||||
connector_id=cc_pair_1.connector_id,
|
||||
credential_id=cc_pair_1.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
# cc_pair 2 also owns all 3
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=all_ids,
|
||||
connector_id=cc_pair_2.connector_id,
|
||||
credential_id=cc_pair_2.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# cc_pair 1 prunes — keeps none
|
||||
remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair_1.connector_id,
|
||||
credential_id=cc_pair_1.credential_id,
|
||||
live_hierarchy_node_ids=set(),
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Orphan deletion should find nothing because cc_pair 2 still references them
|
||||
deleted = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert deleted == []
|
||||
|
||||
# All 3 nodes + SOURCE should still exist
|
||||
remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE)
|
||||
assert len(remaining) == 4
|
||||
|
||||
|
||||
def test_reparent_orphaned_children(db_session: Session) -> None:
|
||||
"""After deleting a parent hierarchy node, its children should be
|
||||
re-parented to the SOURCE node."""
|
||||
_cleanup_test_data(db_session)
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair = _create_cc_pair(db_session)
|
||||
|
||||
# Create a parent node and a child node
|
||||
parent_node = PydanticHierarchyNode(
|
||||
raw_node_id="PARENT",
|
||||
raw_parent_id=None,
|
||||
display_name="Parent",
|
||||
node_type=HierarchyNodeType.CHANNEL,
|
||||
)
|
||||
child_node = PydanticHierarchyNode(
|
||||
raw_node_id="CHILD",
|
||||
raw_parent_id="PARENT",
|
||||
display_name="Child",
|
||||
node_type=HierarchyNodeType.CHANNEL,
|
||||
)
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=[parent_node, child_node],
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
assert len(upserted) == 2
|
||||
|
||||
parent_db = get_hierarchy_node_by_raw_id(db_session, "PARENT", TEST_SOURCE)
|
||||
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
|
||||
assert parent_db is not None and child_db is not None
|
||||
assert child_db.parent_id == parent_db.id
|
||||
|
||||
# Associate only the child with a cc_pair (parent is orphaned)
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[child_db.id],
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Delete orphaned nodes (parent has no cc_pair entry)
|
||||
deleted = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert "PARENT" in deleted
|
||||
|
||||
# Child should now have parent_id=NULL (SET NULL cascade)
|
||||
db_session.expire_all()
|
||||
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
|
||||
assert child_db is not None
|
||||
assert child_db.parent_id is None
|
||||
|
||||
# Re-parent orphans to SOURCE
|
||||
reparented = reparent_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert len(reparented) == 1
|
||||
|
||||
db_session.expire_all()
|
||||
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
|
||||
assert child_db is not None
|
||||
assert child_db.parent_id == source_node.id
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -15,14 +14,13 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
# In order to get these tests to run, use the credentials from Bitwarden.
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
class DocExternalAccessSet(BaseModel):
|
||||
"""A version of DocExternalAccess that uses sets for comparison."""
|
||||
@@ -54,6 +52,9 @@ def test_jira_doc_sync(
|
||||
This test uses the AS project which has applicationRole permission,
|
||||
meaning all documents should be marked as public.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use AS project specifically for this test
|
||||
connector_config = {
|
||||
@@ -149,6 +150,9 @@ def test_jira_doc_sync_with_specific_permissions(
|
||||
This test uses a project that has specific user permissions to verify
|
||||
that specific users are correctly extracted.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use SUP project which has specific user permissions
|
||||
connector_config = {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
|
||||
@@ -19,8 +18,6 @@ from tests.daily.connectors.confluence.models import ExternalUserGroupSet
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Expected groups from the danswerai.atlassian.net Jira instance
|
||||
# Note: These groups are shared with Confluence since they're both Atlassian products
|
||||
# App accounts (bots, integrations) are filtered out
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Test that Credential with nested JSON round-trips through SensitiveValue correctly.
|
||||
|
||||
Exercises the full encrypt → store → read → decrypt → SensitiveValue path
|
||||
with realistic nested OAuth credential data, and verifies SQLAlchemy dirty
|
||||
tracking works with nested dict comparison.
|
||||
|
||||
Requires a running Postgres instance.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
# NOTE: this is not the real shape of a Drive credential,
|
||||
# but it is intended to test nested JSON credential handling
|
||||
|
||||
_NESTED_CRED_JSON = {
|
||||
"oauth_tokens": {
|
||||
"access_token": "ya29.abc123",
|
||||
"refresh_token": "1//xEg-def456",
|
||||
},
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"client_config": {
|
||||
"client_id": "123.apps.googleusercontent.com",
|
||||
"client_secret": "GOCSPX-secret",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_nested_credential_json_round_trip(db_session: Session) -> None:
|
||||
"""Nested OAuth credential survives encrypt → store → read → decrypt."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
# Immediate read (no DB round-trip) — tests the set event wrapping
|
||||
assert isinstance(credential.credential_json, SensitiveValue)
|
||||
assert credential.credential_json.get_value(apply_mask=False) == _NESTED_CRED_JSON
|
||||
|
||||
# DB round-trip — tests process_result_value
|
||||
db_session.expire(credential)
|
||||
reloaded = credential.credential_json
|
||||
assert isinstance(reloaded, SensitiveValue)
|
||||
assert reloaded.get_value(apply_mask=False) == _NESTED_CRED_JSON
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_reassign_same_nested_json_not_dirty(db_session: Session) -> None:
|
||||
"""Re-assigning the same nested dict should not mark the session dirty."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
# Clear dirty state from the insert
|
||||
db_session.expire(credential)
|
||||
_ = credential.credential_json # force reload
|
||||
|
||||
# Re-assign identical value
|
||||
credential.credential_json = _NESTED_CRED_JSON # type: ignore[assignment]
|
||||
assert not db_session.is_modified(credential)
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_assign_different_nested_json_is_dirty(db_session: Session) -> None:
|
||||
"""Assigning a different nested dict should mark the session dirty."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
db_session.expire(credential)
|
||||
_ = credential.credential_json # force reload
|
||||
|
||||
modified_cred = {**_NESTED_CRED_JSON, "scopes": ["read"]}
|
||||
credential.credential_json = modified_cred # type: ignore[assignment]
|
||||
assert db_session.is_modified(credential)
|
||||
|
||||
db_session.rollback()
|
||||
@@ -1,305 +0,0 @@
|
||||
"""Tests for rotate_encryption_key against real Postgres.
|
||||
|
||||
Uses real ORM models (Credential, InternetSearchProvider) and the actual
|
||||
Postgres database. Discovery is mocked in rotation tests to scope mutations
|
||||
to only the test rows — the real _discover_encrypted_columns walk is tested
|
||||
separately in TestDiscoverEncryptedColumns.
|
||||
|
||||
Requires a running Postgres instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
|
||||
|
||||
OLD_KEY = "o" * 16
|
||||
NEW_KEY = "n" * 16
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee() -> Generator[None, None, None]:
|
||||
prev = global_version._is_ee
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
global_version._is_ee = prev
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
|
||||
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
|
||||
col = Credential.__table__.c.credential_json
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
Credential.__table__.c.id == credential_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
|
||||
"""Read raw bytes from InternetSearchProvider.api_key."""
|
||||
col = InternetSearchProvider.__table__.c.api_key
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
InternetSearchProvider.__table__.c.id == isp_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
class TestDiscoverEncryptedColumns:
|
||||
"""Verify _discover_encrypted_columns finds real production models."""
|
||||
|
||||
def test_discovers_credential_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("credential", "credential_json", True) in found
|
||||
|
||||
def test_discovers_internet_search_provider_api_key(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("internet_search_provider", "api_key", False) in found
|
||||
|
||||
def test_all_encrypted_string_columns_are_not_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedString):
|
||||
assert not is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
def test_all_encrypted_json_columns_are_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
assert is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
|
||||
class TestRotateCredential:
|
||||
"""Test rotation against the real Credential table (EncryptedJson).
|
||||
|
||||
Discovery is scoped to only the Credential model to avoid mutating
|
||||
other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[(Credential, "credential_json", ["id"], True)],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def credential_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert a Credential row with raw encrypted bytes, clean up after."""
|
||||
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
|
||||
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO credential "
|
||||
"(source, credential_json, admin_public, curator_public) "
|
||||
"VALUES (:source, :cred_json, true, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
|
||||
)
|
||||
cred_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield cred_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_credential_json(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
assert decrypted["endpoint"] == "https://example.com"
|
||||
|
||||
def test_skips_already_rotated(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
|
||||
def test_dry_run_does_not_modify(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
original = _raw_credential_bytes(db_session, credential_id)
|
||||
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw_after = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw_after == original
|
||||
|
||||
|
||||
class TestRotateInternetSearchProvider:
|
||||
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
|
||||
|
||||
Discovery is scoped to only the InternetSearchProvider model to avoid
|
||||
mutating other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[
|
||||
(InternetSearchProvider, "api_key", ["id"], False),
|
||||
],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def isp_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
|
||||
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-rotation-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": encrypted,
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield isp_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
|
||||
|
||||
def test_rotates_from_unencrypted(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> None:
|
||||
"""Test rotating data that was stored without any encryption key."""
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-raw-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": b"raw-api-key",
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=None)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
|
||||
finally:
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
@@ -85,7 +85,7 @@ def test_group_overlap_filter(
|
||||
results = _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
user_email="",
|
||||
user_email=None,
|
||||
external_group_ids=["group_engineering"],
|
||||
)
|
||||
result_ids = {n.raw_node_id for n in results}
|
||||
@@ -124,7 +124,7 @@ def test_no_credentials_returns_only_public(
|
||||
results = _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
user_email="",
|
||||
user_email=None,
|
||||
external_group_ids=[],
|
||||
)
|
||||
result_ids = {n.raw_node_id for n in results}
|
||||
|
||||
@@ -158,7 +158,7 @@ class TestLLMConfigurationEndpoint:
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.detail == error_message
|
||||
assert exc_info.value.message == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -540,7 +540,7 @@ class TestDefaultProviderEndpoint:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "No LLM Provider setup" in exc_info.value.detail
|
||||
assert "No LLM Provider setup" in exc_info.value.message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -585,7 +585,7 @@ class TestDefaultProviderEndpoint:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.detail == error_message
|
||||
assert exc_info.value.message == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -111,7 +111,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -247,7 +247,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -350,7 +350,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -386,7 +386,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -1152,179 +1152,3 @@ class TestAutoModeTransitionsAndResync:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_updates_default_when_recommended_default_changes(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and a sync arrives with a
|
||||
different recommended default model (both models still in config),
|
||||
the global default should be updated to the new recommendation.
|
||||
|
||||
Steps:
|
||||
1. Create auto-mode provider with config v1: default=gpt-4o.
|
||||
2. Set gpt-4o as the global CHAT default.
|
||||
3. Re-sync with config v2: default=gpt-4o-mini (gpt-4o still present).
|
||||
4. Verify the CHAT default switched to gpt-4o-mini and both models
|
||||
remain visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=["gpt-4o"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o as the global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Re-sync with config v2 (recommended default changed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0, "Sync should report changes when default switches"
|
||||
|
||||
# Both models should remain visible
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is True
|
||||
assert visibility["gpt-4o-mini"] is True
|
||||
|
||||
# The CHAT default should now be gpt-4o-mini
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert default_after is not None
|
||||
assert (
|
||||
default_after.name == "gpt-4o-mini"
|
||||
), f"Default should be updated to 'gpt-4o-mini', got '{default_after.name}'"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_idempotent_when_default_already_matches(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and it already matches the
|
||||
recommended default, re-syncing should report zero changes.
|
||||
|
||||
This is a regression test for the bug where changes was unconditionally
|
||||
incremented even when the default was already correct.
|
||||
"""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o (the recommended default) as global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# First sync to stabilize state
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Second sync — default already matches, should be a no-op
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert changes == 0, (
|
||||
f"Expected 0 changes when default already matches recommended, "
|
||||
f"got {changes}"
|
||||
)
|
||||
|
||||
# Default should still be gpt-4o
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == "gpt-4o"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
"""
|
||||
This should act as the main point of reference for testing that default model
|
||||
logic is consisten.
|
||||
|
||||
-
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
|
||||
|
||||
def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
models: list[ModelConfigurationUpsertRequest] | None = None,
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider with multiple models."""
|
||||
if models is None:
|
||||
models = [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True, supports_image_input=False
|
||||
),
|
||||
]
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=models,
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_provider(db_session: Session, name: str) -> None:
|
||||
"""Helper to clean up a test provider by name."""
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if provider:
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_name(db_session: Session) -> Generator[str, None, None]:
|
||||
"""Generate a unique provider name for each test, with automatic cleanup."""
|
||||
name = f"test-provider-{uuid4().hex[:8]}"
|
||||
yield name
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, name)
|
||||
|
||||
|
||||
class TestDefaultModelProtection:
|
||||
"""Tests that the default model cannot be removed or hidden."""
|
||||
|
||||
def test_cannot_remove_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default text model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to update the provider without the default model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_hide_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Setting is_visible=False on the default text model should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to hide the default model
|
||||
with pytest.raises(ValueError, match="Cannot hide the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=False
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_remove_default_vision_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default vision model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
# Set gpt-4o as both the text and vision default
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
update_default_vision_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to remove the default vision model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_can_remove_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Remove gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_names = {mc.name for mc in updated.model_configurations}
|
||||
assert "gpt-4o" in model_names
|
||||
assert "gpt-4o-mini" not in model_names
|
||||
|
||||
def test_can_hide_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Hiding a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Hide gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=False
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in updated.model_configurations
|
||||
}
|
||||
assert model_visibility["gpt-4o"] is True
|
||||
assert model_visibility["gpt-4o-mini"] is False
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user