mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-21 01:16:45 +00:00
Compare commits
78 Commits
cli/v0.1.0
...
v3.0.8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9632dc5746 | ||
|
|
216c9841a7 | ||
|
|
ab99945c9b | ||
|
|
b61109a747 | ||
|
|
78459fb3e7 | ||
|
|
e243d7955b | ||
|
|
77f5411bf7 | ||
|
|
c45caf1f1d | ||
|
|
4f534249d6 | ||
|
|
eb87d88b89 | ||
|
|
4fd6786ce2 | ||
|
|
6919afe022 | ||
|
|
c4ac0fd286 | ||
|
|
d2f8e38e67 | ||
|
|
bbd57c5904 | ||
|
|
546d5cd384 | ||
|
|
f902f49483 | ||
|
|
ed3630e248 | ||
|
|
598e605dd2 | ||
|
|
aee02f6501 | ||
|
|
2959470114 | ||
|
|
7d9a339e0b | ||
|
|
a2742fcabf | ||
|
|
ba4b4f0930 | ||
|
|
74a4d620ad | ||
|
|
51f46bd8f0 | ||
|
|
e6cfe77a6d | ||
|
|
cc3719f356 | ||
|
|
b658ad8985 | ||
|
|
b1632044ed | ||
|
|
9fa8265f00 | ||
|
|
ce53e123dc | ||
|
|
5606ae5e81 | ||
|
|
923e0691aa | ||
|
|
b232e2a771 | ||
|
|
c3ebfeda2f | ||
|
|
6a28dfedb1 | ||
|
|
a123ec083d | ||
|
|
f448f1274d | ||
|
|
d12f8b94aa | ||
|
|
355fe2ff2c | ||
|
|
8ec5423a0c | ||
|
|
79b615db46 | ||
|
|
98756bccd4 | ||
|
|
418f84ccdf | ||
|
|
d37756a884 | ||
|
|
9cdc92441b | ||
|
|
b8ed30644a | ||
|
|
d7d19e5a28 | ||
|
|
948650829d | ||
|
|
b6e689be0f | ||
|
|
85877408c8 | ||
|
|
c00df75c79 | ||
|
|
6352c9a09e | ||
|
|
3065f70d7d | ||
|
|
4befbc49dc | ||
|
|
ae9679e8c4 | ||
|
|
ea0ddee5c8 | ||
|
|
2826405dd2 | ||
|
|
8485bf4368 | ||
|
|
7bb52b0839 | ||
|
|
85a54c01f1 | ||
|
|
e4577bd564 | ||
|
|
f150a7b940 | ||
|
|
f1df36e306 | ||
|
|
1611604269 | ||
|
|
c2a71091dc | ||
|
|
cc008699e5 | ||
|
|
48802618db | ||
|
|
6917953b86 | ||
|
|
e7cf027f8a | ||
|
|
41fb1480bb | ||
|
|
bdc2bfdcee | ||
|
|
8816d52b27 | ||
|
|
6590f1d7ba | ||
|
|
c527f75557 | ||
|
|
472d1788a7 | ||
|
|
99e95f8205 |
@@ -106,13 +106,34 @@ onyx-cli ask --json "What authentication methods do we support?"
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation with `citation_number` and `document_id` |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
@@ -129,6 +150,10 @@ Uses a specific Onyx agent/persona instead of the default.
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
64
.github/workflows/deployment.yml
vendored
64
.github/workflows/deployment.yml
vendored
@@ -29,20 +29,32 @@ jobs:
|
||||
build-backend-craft: ${{ steps.check.outputs.build-backend-craft }}
|
||||
build-model-server: ${{ steps.check.outputs.build-model-server }}
|
||||
is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }}
|
||||
is-stable: ${{ steps.check.outputs.is-stable }}
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-craft-latest: ${{ steps.check.outputs.is-craft-latest }}
|
||||
is-latest: ${{ steps.check.outputs.is-latest }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Checkout (for git tags)
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
set -eo pipefail
|
||||
TAG="${GITHUB_REF_NAME}"
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
@@ -54,9 +66,8 @@ jobs:
|
||||
IS_VERSION_TAG=false
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_CRAFT_LATEST=false
|
||||
IS_LATEST=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
@@ -67,9 +78,6 @@ jobs:
|
||||
BUILD_MODEL_SERVER=true
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == craft-* ]]; then
|
||||
IS_CRAFT_LATEST=true
|
||||
fi
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
fi
|
||||
@@ -97,20 +105,28 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
|
||||
# Craft-latest builds backend with Craft enabled
|
||||
if [[ "$IS_CRAFT_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
BUILD_BACKEND=false
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
fi
|
||||
if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
# Determine if this tag should get the "latest" Docker tag.
|
||||
# Only the highest semver stable tag (vX.Y.Z exactly) gets "latest".
|
||||
if [[ "$IS_STABLE" == "true" ]]; then
|
||||
HIGHEST_STABLE=$(uv run --no-sync --with onyx-devtools ods latest-stable-tag) || {
|
||||
echo "::error::Failed to determine highest stable tag via 'ods latest-stable-tag'"
|
||||
exit 1
|
||||
}
|
||||
if [[ "$TAG" == "$HIGHEST_STABLE" ]]; then
|
||||
IS_LATEST=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Build craft-latest backend alongside the regular latest.
|
||||
if [[ "$IS_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
fi
|
||||
|
||||
# Determine if this is a production tag
|
||||
# Production tags are: version tags (v1.2.3*) or nightly tags
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then
|
||||
@@ -129,11 +145,9 @@ jobs:
|
||||
echo "build-backend-craft=$BUILD_BACKEND_CRAFT"
|
||||
echo "build-model-server=$BUILD_MODEL_SERVER"
|
||||
echo "is-cloud-tag=$IS_CLOUD"
|
||||
echo "is-stable=$IS_STABLE"
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-craft-latest=$IS_CRAFT_LATEST"
|
||||
echo "is-latest=$IS_LATEST"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
@@ -600,7 +614,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1037,7 +1052,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1249,8 +1264,6 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=craft-latest
|
||||
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
|
||||
# to keep tagging strategy consistent across all backend images
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
@@ -1473,7 +1486,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -316,6 +316,7 @@ jobs:
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -418,6 +419,7 @@ jobs:
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
@@ -637,6 +639,7 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -691,6 +694,7 @@ jobs:
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
|
||||
3
.github/workflows/pr-playwright-tests.yml
vendored
3
.github/workflows/pr-playwright-tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
# TODO: Remove this if we enable merge-queues for release branches.
|
||||
branches:
|
||||
- "release/**"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
178
.github/workflows/release-cli.yml
vendored
178
.github/workflows/release-cli.yml
vendored
@@ -26,8 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -38,3 +37,178 @@ jobs:
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -22,12 +22,10 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
|
||||
@@ -544,6 +544,8 @@ To run them:
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
@@ -596,7 +598,7 @@ Before writing your plan, make sure to do research. Explore the relevant section
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
|
||||
@@ -46,7 +46,11 @@ RUN apt-get update && \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
vim \
|
||||
# Install procps so kubernetes exec sessions can use ps aux for debugging
|
||||
procps \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -164,6 +168,13 @@ ENV PYTHONPATH=/app
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
|
||||
# in long-running Python processes (API server, Celery workers).
|
||||
# The soname is architecture-independent; the dynamic linker resolves
|
||||
# the correct path from standard library directories.
|
||||
# Placed after all RUN steps so build-time processes are unaffected.
|
||||
ENV LD_PRELOAD=libjemalloc.so.2
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file(
|
||||
set add_prefix to True so group IDs are prefixed with the source type.
|
||||
When invoked from doc_sync (permission sync), use the default (False)
|
||||
since upsert_document_external_perms handles prefixing.
|
||||
fallback_user_email: When we cannot retrieve any permission info for a file
|
||||
(e.g. externally-owned files where the API returns no permissions
|
||||
and permissions.list returns 403), fall back to granting access
|
||||
to this user. This is typically the impersonated org user whose
|
||||
drive contained the file.
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
@@ -117,6 +123,26 @@ def get_external_access_for_raw_gdrive_file(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
# For externally-owned files, the Drive API may return no permissions
|
||||
# and permissions.list may return 403. In this case, fall back to
|
||||
# granting access to the user who found the file in their drive.
|
||||
# Note, even if other users also have access to this file,
|
||||
# they will not be granted access in Onyx.
|
||||
# We check permissions_list (the final result after all fetch attempts)
|
||||
# rather than the raw fields, because permission_ids may be present
|
||||
# but the actual fetch can still return empty due to a 403.
|
||||
if not permissions_list:
|
||||
logger.info(
|
||||
f"No permission info available for file {doc_id} "
|
||||
f"(likely owned by a user outside of your organization). "
|
||||
f"Falling back to granting access to retriever user: {fallback_user_email}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails={fallback_user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
|
||||
_GROUP_MEMBER_PAGE_SIZE = 50
|
||||
|
||||
def _get_jira_group_members_email(
|
||||
# The GET /group/member endpoint was introduced in Jira 6.0.
|
||||
# Jira versions older than 6.0 do not have group management REST APIs at all.
|
||||
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
|
||||
|
||||
|
||||
def _fetch_group_member_page(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
start_at: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
|
||||
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
|
||||
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
|
||||
directly via the library's internal _get_json helper, following the same pattern
|
||||
as enhanced_search_ids / bulk_fetch_issues in connector.py.
|
||||
|
||||
There is an open PR to the library to switch to this endpoint since last year:
|
||||
https://github.com/pycontribs/jira/pull/2356
|
||||
so once it is merged and released, we can switch to using the library function.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
return jira_client._get_json(
|
||||
"group/member",
|
||||
params={
|
||||
"groupname": group_name,
|
||||
"includeInactiveUsers": "false",
|
||||
"startAt": start_at,
|
||||
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
|
||||
},
|
||||
)
|
||||
except JIRAError as e:
|
||||
if e.status_code == 404:
|
||||
raise RuntimeError(
|
||||
f"GET /group/member returned 404 for group '{group_name}'. "
|
||||
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
|
||||
f"If you are running a self-hosted Jira instance, please upgrade "
|
||||
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
def _get_group_member_emails(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get all member emails for a single Jira group.
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
Uses the non-deprecated GET /group/member endpoint which returns full user
|
||||
objects including accountType, so we can filter out app/customer accounts
|
||||
without making separate user() calls.
|
||||
"""
|
||||
emails: set[str] = set()
|
||||
start_at = 0
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
while True:
|
||||
try:
|
||||
page = _fetch_group_member_page(jira_client, group_name, start_at)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
members: list[dict[str, Any]] = page.get("values", [])
|
||||
for member in members:
|
||||
account_type = member.get("accountType")
|
||||
# On Jira DC < 9.0, accountType is absent; include those users.
|
||||
# On Cloud / DC 9.0+, filter to real user accounts only.
|
||||
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
email = member.get("emailAddress")
|
||||
if email:
|
||||
emails.add(email)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
logger.warning(
|
||||
f"Atlassian user {member.get('accountId', 'unknown')} "
|
||||
f"in group {group_name} has no visible email address"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
if page.get("isLast", True) or not members:
|
||||
break
|
||||
start_at += len(members)
|
||||
|
||||
return group_member_emails
|
||||
return emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str, # noqa: ARG001
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
"""Sync Jira groups and their members, yielding one group at a time.
|
||||
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
Streams group-by-group rather than accumulating all groups in memory.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
@@ -130,12 +127,26 @@ def jira_group_sync(
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
group_names = jira_client.groups()
|
||||
if not group_names:
|
||||
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_group_member_emails(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
if not member_emails:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
|
||||
yield ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -125,10 +126,16 @@ def _seed_llms(
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
seeded_providers: list[LLMProviderView] = []
|
||||
for llm_upsert_request in llm_upsert_requests:
|
||||
try:
|
||||
seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Failed to upsert LLM provider '%s' during seeding: %s",
|
||||
llm_upsert_request.name,
|
||||
e,
|
||||
)
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
@@ -33,6 +34,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
@@ -88,6 +90,17 @@ def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a delete_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
return celery_get_queue_length(
|
||||
@@ -546,7 +559,23 @@ def process_single_user_file(
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with DELETING status and enqueue per-file tasks."""
|
||||
"""Scan for user files with DELETING status and enqueue per-file tasks.
|
||||
|
||||
Three mechanisms prevent queue runaway (mirrors check_user_file_processing):
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still DELETING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
"""
|
||||
task_logger.info("check_for_user_file_delete - Starting")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
@@ -555,8 +584,23 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_for_user_file_delete - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -568,23 +612,40 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
.all()
|
||||
)
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_delete_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -602,6 +663,9 @@ def delete_user_file_impl(
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
# Clear the queued guard so the beat can re-enqueue if deletion fails
|
||||
# and the file remains in DELETING status.
|
||||
redis_client.delete(_user_file_delete_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
|
||||
@@ -50,6 +50,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
@@ -980,6 +981,10 @@ def run_llm_loop(
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, CustomToolCallSummary):
|
||||
saved_response = json.dumps(
|
||||
tool_response.rich_response.model_dump()
|
||||
)
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
|
||||
@@ -288,8 +288,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
# NOTE: Now enabled on by default, unless the env indicates otherwise.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
@@ -787,6 +788,29 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Maximum embedded images allowed in a single file. PDFs (and other formats)
|
||||
# with thousands of embedded images can OOM the user-file-processing worker
|
||||
# because every image is decoded with PIL and then sent to the vision LLM.
|
||||
# Enforced both at upload time (rejects the file) and during extraction
|
||||
# (defense-in-depth: caps the number of images materialized).
|
||||
#
|
||||
# Clamped to >= 0; a negative env value would turn upload validation into
|
||||
# always-fail and extraction into always-stop, which is never desired. 0
|
||||
# disables image extraction entirely, which is a valid (if aggressive) setting.
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
|
||||
)
|
||||
|
||||
# Maximum embedded images allowed across all files in a single upload batch.
|
||||
# Protects against the scenario where a user uploads many files that each
|
||||
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
|
||||
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
|
||||
# worker under concurrency or run up surprise latency/cost. Also clamped
|
||||
# to >= 0.
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
|
||||
)
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
|
||||
@@ -177,6 +177,14 @@ USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file-delete task is valid before workers discard it.
|
||||
# Mirrors the processing task expiry to prevent indefinite queue growth when
|
||||
# files are stuck in DELETING status and the beat keeps re-enqueuing them.
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Max queue depth before the delete beat stops enqueuing more delete tasks.
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -469,6 +477,9 @@ class OnyxRedisLocks:
|
||||
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
# Short-lived key set when a delete task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a delete task is already queued.
|
||||
USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
@@ -204,7 +205,7 @@ def _manage_async_retrieval(
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
async def _async_fetch() -> AsyncGenerator[Document, None]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
@@ -227,22 +228,23 @@ def _manage_async_retrieval(
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
async_gen = _async_fetch()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
doc = loop.run_until_complete(anext(async_gen))
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
# Must close the async generator before the loop so the Discord
|
||||
# client's `async with` block can await its shutdown coroutine.
|
||||
# The nested try/finally ensures the loop always closes even if
|
||||
# aclose() raises (same pattern as cursor.close() before conn.close()).
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
@@ -1722,6 +1722,7 @@ class GoogleDriveConnector(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
|
||||
|
||||
@@ -476,6 +476,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -484,6 +485,8 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
add_prefix: When True, prefix group IDs with source type (for indexing path).
|
||||
When False (default), leave unprefixed (for permission sync path
|
||||
where upsert_document_external_perms handles prefixing).
|
||||
fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
|
||||
files), fall back to granting access to this user.
|
||||
"""
|
||||
external_access_fn = cast(
|
||||
Callable[
|
||||
@@ -492,6 +495,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
str,
|
||||
GoogleDriveService | None,
|
||||
GoogleDriveService,
|
||||
str,
|
||||
bool,
|
||||
],
|
||||
ExternalAccess,
|
||||
@@ -507,6 +511,7 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain,
|
||||
retriever_drive_service,
|
||||
admin_drive_service,
|
||||
fallback_user_email,
|
||||
add_prefix,
|
||||
)
|
||||
|
||||
@@ -672,6 +677,7 @@ def _convert_drive_item_to_document(
|
||||
creds, user_email=permission_sync_context.primary_admin_email
|
||||
),
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
@@ -753,6 +759,7 @@ def build_slim_document(
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_email: str,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
@@ -774,6 +781,7 @@ def build_slim_document(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
|
||||
@@ -157,9 +157,7 @@ def _execute_single_retrieval(
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
results = _execute_with_retry(retrieval_function(**request_kwargs))
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
|
||||
@@ -33,6 +33,7 @@ from office365.runtime.queries.client_query import ClientQuery # type: ignore[i
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
@@ -268,6 +269,32 @@ class SizeCapExceeded(Exception):
|
||||
"""Exception raised when the size cap is exceeded."""
|
||||
|
||||
|
||||
def _log_and_raise_for_status(response: requests.Response) -> None:
|
||||
"""Log the response text and raise for status."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
logger.error(f"HTTP request failed: {response.text}")
|
||||
raise
|
||||
|
||||
|
||||
GRAPH_INVALID_REQUEST_CODE = "invalidRequest"
|
||||
|
||||
|
||||
def _is_graph_invalid_request(response: requests.Response) -> bool:
|
||||
"""Return True if the response body is the generic Graph API
|
||||
``{"error": {"code": "invalidRequest", "message": "Invalid request"}}``
|
||||
shape. This particular error has no actionable inner error code and is
|
||||
returned by the site-pages endpoint when a page has a corrupt canvas layout
|
||||
(e.g. duplicate web-part IDs — see SharePoint/sp-dev-docs#8822)."""
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception:
|
||||
return False
|
||||
error = body.get("error", {})
|
||||
return error.get("code") == GRAPH_INVALID_REQUEST_CODE
|
||||
|
||||
|
||||
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
|
||||
"""Load certificate from .pfx file for MSAL authentication"""
|
||||
try:
|
||||
@@ -344,7 +371,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
|
||||
"""Determine remote size using HEAD or a range GET probe. Returns None if unknown."""
|
||||
try:
|
||||
head_resp = requests.head(url, timeout=timeout, allow_redirects=True)
|
||||
head_resp.raise_for_status()
|
||||
_log_and_raise_for_status(head_resp)
|
||||
cl = head_resp.headers.get("Content-Length")
|
||||
if cl and cl.isdigit():
|
||||
return int(cl)
|
||||
@@ -359,7 +386,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
|
||||
timeout=timeout,
|
||||
stream=True,
|
||||
) as range_resp:
|
||||
range_resp.raise_for_status()
|
||||
_log_and_raise_for_status(range_resp)
|
||||
cr = range_resp.headers.get("Content-Range") # e.g., "bytes 0-0/12345"
|
||||
if cr and "/" in cr:
|
||||
total = cr.split("/")[-1]
|
||||
@@ -384,7 +411,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
|
||||
- Returns the full bytes if the content fits within `cap`.
|
||||
"""
|
||||
with requests.get(url, stream=True, timeout=timeout) as resp:
|
||||
resp.raise_for_status()
|
||||
_log_and_raise_for_status(resp)
|
||||
|
||||
# If the server provides Content-Length, prefer an early decision.
|
||||
cl_header = resp.headers.get("Content-Length")
|
||||
@@ -428,7 +455,7 @@ def _download_via_graph_api(
|
||||
with requests.get(
|
||||
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
_log_and_raise_for_status(resp)
|
||||
buf = io.BytesIO()
|
||||
for chunk in resp.iter_content(64 * 1024):
|
||||
if not chunk:
|
||||
@@ -1238,26 +1265,135 @@ class SharepointConnector(
|
||||
site.execute_query()
|
||||
site_id = site.id
|
||||
|
||||
page_url: str | None = (
|
||||
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
|
||||
site_pages_base = (
|
||||
f"{self.graph_api_base}/sites/{site_id}/pages/microsoft.graph.sitePage"
|
||||
)
|
||||
page_url: str | None = site_pages_base
|
||||
params: dict[str, str] | None = {"$expand": "canvasLayout"}
|
||||
total_yielded = 0
|
||||
yielded_ids: set[str] = set()
|
||||
|
||||
while page_url:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
logger.warning(f"Site page not found: {page_url}")
|
||||
break
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code == 400
|
||||
and _is_graph_invalid_request(e.response)
|
||||
):
|
||||
logger.warning(
|
||||
f"$expand=canvasLayout on the LIST endpoint returned 400 "
|
||||
f"for site {site_descriptor.url}. Falling back to "
|
||||
f"per-page expansion."
|
||||
)
|
||||
yield from self._fetch_site_pages_individually(
|
||||
site_pages_base, start, end, skip_ids=yielded_ids
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
params = None # nextLink already embeds query params
|
||||
|
||||
for page in data.get("value", []):
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
total_yielded += 1
|
||||
page_id = page.get("id")
|
||||
if page_id:
|
||||
yielded_ids.add(page_id)
|
||||
yield page
|
||||
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
|
||||
|
||||
def _fetch_site_pages_individually(
|
||||
self,
|
||||
site_pages_base: str,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
skip_ids: set[str] | None = None,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Fallback for _fetch_site_pages: list pages without $expand, then
|
||||
expand canvasLayout on each page individually.
|
||||
|
||||
The Graph API's LIST endpoint can return 400 when $expand=canvasLayout
|
||||
is used and *any* page in the site has a corrupt canvas layout (e.g.
|
||||
duplicate web part IDs — see SharePoint/sp-dev-docs#8822). Since the
|
||||
LIST expansion is all-or-nothing, a single bad page poisons the entire
|
||||
response. This method works around it by fetching metadata first, then
|
||||
expanding each page individually so only the broken page loses its
|
||||
canvas content.
|
||||
|
||||
``skip_ids`` contains page IDs already yielded by the caller before the
|
||||
fallback was triggered, preventing duplicates.
|
||||
"""
|
||||
page_url: str | None = site_pages_base
|
||||
total_yielded = 0
|
||||
_skip_ids = skip_ids or set()
|
||||
|
||||
while page_url:
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
break
|
||||
raise
|
||||
|
||||
for page in data.get("value", []):
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
|
||||
page_id = page.get("id")
|
||||
if page_id and page_id in _skip_ids:
|
||||
continue
|
||||
|
||||
if not page_id:
|
||||
total_yielded += 1
|
||||
yield page
|
||||
continue
|
||||
|
||||
expanded = self._try_expand_single_page(site_pages_base, page_id, page)
|
||||
total_yielded += 1
|
||||
yield expanded
|
||||
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
logger.debug(
|
||||
f"Yielded {total_yielded} site pages (per-page expansion fallback)"
|
||||
)
|
||||
|
||||
def _try_expand_single_page(
|
||||
self,
|
||||
site_pages_base: str,
|
||||
page_id: str,
|
||||
fallback_page: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Try to GET a single page with $expand=canvasLayout. On 400, return
|
||||
the metadata-only fallback so the page is still indexed (without canvas
|
||||
content)."""
|
||||
pages_collection = site_pages_base.removesuffix("/microsoft.graph.sitePage")
|
||||
single_url = f"{pages_collection}/{page_id}/microsoft.graph.sitePage"
|
||||
try:
|
||||
return self._graph_api_get_json(single_url, {"$expand": "canvasLayout"})
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code == 400
|
||||
and _is_graph_invalid_request(e.response)
|
||||
):
|
||||
page_name = fallback_page.get("name", page_id)
|
||||
logger.warning(
|
||||
f"$expand=canvasLayout failed for page '{page_name}' "
|
||||
f"({page_id}). Indexing metadata only."
|
||||
)
|
||||
return fallback_page
|
||||
raise
|
||||
|
||||
def _acquire_token(self) -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
@@ -1309,7 +1445,7 @@ class SharepointConnector(
|
||||
access_token = self._get_graph_access_token()
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
continue
|
||||
response.raise_for_status()
|
||||
_log_and_raise_for_status(response)
|
||||
return response.json()
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
if attempt < GRAPH_API_MAX_RETRIES:
|
||||
|
||||
@@ -583,6 +583,67 @@ def get_latest_index_attempt_for_cc_pair_id(
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_latest_successful_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool = False,
|
||||
) -> IndexAttempt | None:
|
||||
"""Returns the most recent successful index attempt for the given cc pair,
|
||||
filtered to the current (or future) search settings.
|
||||
Uses MAX(id) semantics to match get_latest_index_attempts_by_status."""
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
|
||||
),
|
||||
)
|
||||
.join(SearchSettings)
|
||||
.where(SearchSettings.status == status)
|
||||
.order_by(desc(IndexAttempt.id))
|
||||
.limit(1)
|
||||
)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_latest_successful_index_attempts_parallel(
|
||||
secondary_index: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
"""Batch version: returns the latest successful index attempt per cc pair.
|
||||
Covers both SUCCESS and COMPLETED_WITH_ERRORS (matching is_successful())."""
|
||||
model_status = (
|
||||
IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
latest_ids = (
|
||||
select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_id"),
|
||||
)
|
||||
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
.where(
|
||||
SearchSettings.status == model_status,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
|
||||
),
|
||||
)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(IndexAttempt).join(
|
||||
latest_ids,
|
||||
(
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== latest_ids.c.connector_credential_pair_id
|
||||
)
|
||||
& (IndexAttempt.id == latest_ids.c.max_id),
|
||||
)
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def count_index_attempts_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
@@ -270,10 +271,35 @@ def upsert_llm_provider(
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of requested visibility by model name
|
||||
requested_visibility = {
|
||||
mc.name: mc.is_visible
|
||||
for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Delete removed models
|
||||
removed_ids = [
|
||||
mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
|
||||
]
|
||||
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
# Prevent removing and hiding the default model
|
||||
if default_model:
|
||||
for name, mc in existing_by_name.items():
|
||||
if mc.id == default_model.id:
|
||||
if default_model.id in removed_ids:
|
||||
raise ValueError(
|
||||
f"Cannot remove the default model '{name}'. "
|
||||
"Please change the default model before removing."
|
||||
)
|
||||
if not requested_visibility.get(name, True):
|
||||
raise ValueError(
|
||||
f"Cannot hide the default model '{name}'. "
|
||||
"Please change the default model before hiding."
|
||||
)
|
||||
break
|
||||
|
||||
if removed_ids:
|
||||
db_session.query(ModelConfiguration).filter(
|
||||
ModelConfiguration.id.in_(removed_ids)
|
||||
@@ -344,9 +370,9 @@ def upsert_llm_provider(
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[dict],
|
||||
models: list[SyncModelEntry],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
@@ -354,7 +380,7 @@ def sync_model_configurations(
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
@@ -368,21 +394,20 @@ def sync_model_configurations(
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
if model.name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model.get("supports_image_input", False):
|
||||
if model.supports_image_input:
|
||||
supported_flows.append(LLMModelFlowType.VISION)
|
||||
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_name,
|
||||
model_name=model.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
display_name=model.get("display_name"),
|
||||
max_input_tokens=model.max_input_tokens,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
@@ -538,7 +563,6 @@ def fetch_default_model(
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
@@ -814,44 +838,30 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
# Update the default if this provider currently holds the global CHAT default.
|
||||
# We flush (but don't commit) so that _update_default_model can see the new
|
||||
# model rows, then commit everything atomically to avoid a window where the
|
||||
# old default is invisible but still pointed-to.
|
||||
db_session.flush()
|
||||
|
||||
# Update the default if this provider currently holds the global CHAT default
|
||||
recommended_default = llm_recommendations.get_default_model(provider.provider)
|
||||
if recommended_default:
|
||||
current_default_name = db_session.scalar(
|
||||
select(ModelConfiguration.name)
|
||||
.join(
|
||||
LLMModelFlow,
|
||||
LLMModelFlow.model_configuration_id == ModelConfiguration.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.llm_provider_id == provider.id,
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
|
||||
if (
|
||||
current_default_name is not None
|
||||
and current_default_name != recommended_default.name
|
||||
current_default
|
||||
and current_default.llm_provider_id == provider.id
|
||||
and current_default.name != recommended_default.name
|
||||
):
|
||||
try:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Recommended default model '%s' not found "
|
||||
"for provider_id=%s; skipping default update.",
|
||||
recommended_default.name,
|
||||
provider.id,
|
||||
)
|
||||
_update_default_model__no_commit(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=recommended_default.name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -982,7 +992,7 @@ def update_model_configuration__no_commit(
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
def _update_default_model__no_commit(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
@@ -1020,6 +1030,14 @@ def _update_default_model(
|
||||
new_default.is_default = True
|
||||
model_config.is_visible = True
|
||||
|
||||
|
||||
def _update_default_model(
|
||||
db_session: Session,
|
||||
provider_id: int,
|
||||
model: str,
|
||||
flow_type: LLMModelFlowType,
|
||||
) -> None:
|
||||
_update_default_model__no_commit(db_session, provider_id, model, flow_type)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -144,6 +145,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
|
||||
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -698,41 +698,6 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -758,41 +723,53 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -503,20 +503,31 @@ def query_vespa(
|
||||
response = http_client.post(SEARCH_ENDPOINT, json=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = "Failed to query Vespa"
|
||||
logger.error(
|
||||
f"{error_base}:\n"
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
+ (
|
||||
f"\nResponse: {e.response.text}"
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
else ""
|
||||
)
|
||||
response_text = (
|
||||
e.response.text if isinstance(e, httpx.HTTPStatusError) else None
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
status_code = (
|
||||
e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None
|
||||
)
|
||||
yql_value = params.get("yql", "")
|
||||
yql_length = len(str(yql_value))
|
||||
|
||||
# Log each detail on its own line so log collectors capture them
|
||||
# as separate entries rather than truncating a single multiline msg
|
||||
logger.error(
|
||||
f"Failed to query Vespa | "
|
||||
f"status={status_code} | "
|
||||
f"yql_length={yql_length} | "
|
||||
f"exception={str(e)}"
|
||||
)
|
||||
if response_text:
|
||||
logger.error(f"Vespa error response: {response_text[:1000]}")
|
||||
logger.error(f"Vespa request URL: {e.request.url}")
|
||||
|
||||
# Re-raise with diagnostics so callers see what actually went wrong
|
||||
raise httpx.HTTPError(
|
||||
f"Failed to query Vespa (status={status_code}, " f"yql_length={yql_length})"
|
||||
) from e
|
||||
|
||||
response_json: dict[str, Any] = response.json()
|
||||
|
||||
|
||||
@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -37,30 +34,38 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_weighted_set_filter(key: str, vals: list[str] | None) -> str:
|
||||
"""Build a Vespa weightedSet filter for large value lists.
|
||||
|
||||
Uses Vespa's native weightedSet() operator instead of OR-chained
|
||||
'contains' clauses. This is critical for fields like
|
||||
access_control_list where a single user may have tens of thousands
|
||||
of ACL entries — OR clauses at that scale cause Vespa to reject
|
||||
the query with HTTP 400."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
filtered = [val for val in vals if val]
|
||||
if not filtered:
|
||||
return ""
|
||||
items = ", ".join(f'"{val}":1' for val in filtered)
|
||||
return f"weightedSet({key}, {{{items}}})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -73,16 +78,12 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -104,8 +105,7 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -114,16 +114,14 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -135,8 +133,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -147,8 +145,7 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -160,73 +157,99 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters
|
||||
# ACL filters — use weightedSet for efficient matching against the
|
||||
# access_control_list weightedset<string> field. OR-chaining thousands
|
||||
# of 'contains' clauses causes Vespa to reject the query (HTTP 400)
|
||||
# for users with large numbers of external permission groups.
|
||||
if filters.access_control_list is not None:
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_weighted_set_filter(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# filter_str += _build_kg_filter(
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# )
|
||||
# ))
|
||||
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# )
|
||||
# ))
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -91,11 +91,11 @@ class OnyxErrorCode(Enum):
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
{"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
If no message is supplied, the error code itself is used as the detail.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
"detail": message or self.code,
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
``{"error_code": "...", "detail": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -37,21 +37,21 @@ class OnyxError(Exception):
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
detail: Human-readable detail (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
message: str | None = None,
|
||||
detail: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
resolved_message = message or error_code.code
|
||||
super().__init__(resolved_message)
|
||||
resolved_detail = detail or error_code.code
|
||||
super().__init__(resolved_detail)
|
||||
self.error_code = error_code
|
||||
self.message = resolved_message
|
||||
self.detail = resolved_detail
|
||||
self._status_code_override = status_code_override
|
||||
|
||||
@property
|
||||
@@ -73,11 +73,11 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.message),
|
||||
content=exc.error_code.detail(exc.detail),
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import chardet
|
||||
import openpyxl
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -176,6 +177,56 @@ def read_text_file(
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
|
||||
"""Return the number of embedded images in a PDF, short-circuiting at cap+1.
|
||||
|
||||
Used to reject PDFs whose image count would OOM the user-file-processing
|
||||
worker during indexing. Returns a value > cap as a sentinel once the count
|
||||
exceeds the cap, so callers do not iterate thousands of image objects just
|
||||
to report a number. Returns 0 if the PDF cannot be parsed.
|
||||
|
||||
Owner-password-only PDFs (permission restrictions but no open password) are
|
||||
counted normally — they decrypt with an empty string. Truly password-locked
|
||||
PDFs are skipped (return 0) since we can't inspect them; the caller should
|
||||
ensure the password-protected check runs first.
|
||||
|
||||
Always restores the file pointer to its original position before returning.
|
||||
"""
|
||||
from pypdf import PdfReader
|
||||
|
||||
try:
|
||||
start_pos = file.tell()
|
||||
except Exception:
|
||||
start_pos = None
|
||||
try:
|
||||
if start_pos is not None:
|
||||
file.seek(0)
|
||||
reader = PdfReader(file)
|
||||
if reader.is_encrypted:
|
||||
# Try empty password first (owner-password-only PDFs); give up if that fails.
|
||||
try:
|
||||
if reader.decrypt("") == 0:
|
||||
return 0
|
||||
except Exception:
|
||||
return 0
|
||||
count = 0
|
||||
for page in reader.pages:
|
||||
for _ in page.images:
|
||||
count += 1
|
||||
if count > cap:
|
||||
return count
|
||||
return count
|
||||
except Exception:
|
||||
logger.warning("Failed to count embedded images in PDF", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
if start_pos is not None:
|
||||
try:
|
||||
file.seek(start_pos)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
"""
|
||||
Extract text from a PDF. For embedded images, a more complex approach is needed.
|
||||
@@ -231,8 +282,27 @@ def read_pdf_file(
|
||||
)
|
||||
|
||||
if extract_images:
|
||||
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
images_processed = 0
|
||||
cap_reached = False
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
if cap_reached:
|
||||
break
|
||||
for image_file_object in page.images:
|
||||
if images_processed >= image_cap:
|
||||
# Defense-in-depth backstop. Upload-time validation
|
||||
# should have rejected files exceeding the cap, but
|
||||
# we also break here so a single oversized file can
|
||||
# never pin a worker.
|
||||
logger.warning(
|
||||
"PDF embedded image cap reached (%d). "
|
||||
"Skipping remaining images on page %d and beyond.",
|
||||
image_cap,
|
||||
page_num + 1,
|
||||
)
|
||||
cap_reached = True
|
||||
break
|
||||
|
||||
image = Image.open(io.BytesIO(image_file_object.data))
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format=image.format)
|
||||
@@ -245,6 +315,7 @@ def read_pdf_file(
|
||||
image_callback(img_bytes, image_name)
|
||||
else:
|
||||
extracted_images.append((img_bytes, image_name))
|
||||
images_processed += 1
|
||||
|
||||
return text, metadata, extracted_images
|
||||
|
||||
|
||||
@@ -19,12 +19,16 @@ class OnyxMimeTypes:
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-log",
|
||||
"text/x-config",
|
||||
"text/tab-separated-values",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"text/yaml",
|
||||
"text/x-yaml",
|
||||
}
|
||||
DOCUMENT_MIME_TYPES = {
|
||||
PDF_MIME_TYPE,
|
||||
|
||||
@@ -88,9 +88,13 @@ def summarize_image_with_error_handling(
|
||||
try:
|
||||
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
||||
except UnsupportedImageFormatError:
|
||||
magic_hex = image_data[:8].hex() if image_data else "empty"
|
||||
logger.info(
|
||||
"Skipping image summarization due to unsupported MIME type for %s",
|
||||
"Skipping image summarization due to unsupported MIME type "
|
||||
"for %s (magic_bytes=%s, size=%d bytes)",
|
||||
context_name,
|
||||
magic_hex,
|
||||
len(image_data),
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -134,9 +138,23 @@ def _summarize_image(
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Summarization failed. Messages: {messages}"
|
||||
error_msg = error_msg[:1024]
|
||||
raise ValueError(error_msg) from e
|
||||
# Extract structured details from LiteLLM exceptions when available,
|
||||
# rather than dumping the full messages payload (which contains base64
|
||||
# image data and produces enormous, unreadable error logs).
|
||||
str_e = str(e)
|
||||
if len(str_e) > 512:
|
||||
str_e = str_e[:512] + "... (truncated)"
|
||||
parts = [f"Summarization failed: {type(e).__name__}: {str_e}"]
|
||||
status_code = getattr(e, "status_code", None)
|
||||
llm_provider = getattr(e, "llm_provider", None)
|
||||
model = getattr(e, "model", None)
|
||||
if status_code is not None:
|
||||
parts.append(f"status_code={status_code}")
|
||||
if llm_provider is not None:
|
||||
parts.append(f"llm_provider={llm_provider}")
|
||||
if model is not None:
|
||||
parts.append(f"model={model}")
|
||||
raise ValueError(" | ".join(parts)) from e
|
||||
|
||||
|
||||
def _encode_image_for_llm_prompt(image_data: bytes) -> str:
|
||||
|
||||
@@ -43,6 +43,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
]
|
||||
|
||||
|
||||
@@ -59,6 +60,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -109,6 +111,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -1512,6 +1512,14 @@
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-6": {
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-7": {
|
||||
"display_name": "Claude Opus 4.7",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1526,6 +1534,10 @@
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"display_name": "Claude Sonnet 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-5-20250929": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -46,6 +46,15 @@ ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = {
|
||||
ReasoningEffort.HIGH: 4096,
|
||||
}
|
||||
|
||||
# Newer Anthropic models (Claude Opus 4.7+) use adaptive thinking with
|
||||
# output_config.effort instead of thinking.type.enabled + budget_tokens.
|
||||
ANTHROPIC_ADAPTIVE_REASONING_EFFORT: dict[ReasoningEffort, str] = {
|
||||
ReasoningEffort.AUTO: "medium",
|
||||
ReasoningEffort.LOW: "low",
|
||||
ReasoningEffort.MEDIUM: "medium",
|
||||
ReasoningEffort.HIGH: "high",
|
||||
}
|
||||
|
||||
|
||||
# Content part structures for multimodal messages
|
||||
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.model_response import Usage
|
||||
from onyx.llm.models import ANTHROPIC_ADAPTIVE_REASONING_EFFORT
|
||||
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
|
||||
from onyx.llm.models import OPENAI_REASONING_EFFORT
|
||||
from onyx.llm.request_context import get_llm_mock_response
|
||||
@@ -67,8 +68,13 @@ STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
)
|
||||
|
||||
# Anthropic models that require the adaptive thinking API (thinking.type.adaptive
|
||||
# + output_config.effort) instead of the legacy thinking.type.enabled + budget_tokens.
|
||||
_ANTHROPIC_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",)
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
"""
|
||||
@@ -193,6 +199,29 @@ def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
|
||||
"""Check if the prompt contains any assistant messages with tool_calls.
|
||||
|
||||
When Anthropic's extended thinking is enabled, the API requires every
|
||||
assistant message to start with a thinking block before any tool_use
|
||||
blocks. Since we don't preserve thinking_blocks (they carry
|
||||
cryptographic signatures that can't be reconstructed), we must skip
|
||||
the thinking param whenever history contains prior tool-calling turns.
|
||||
"""
|
||||
from onyx.llm.models import AssistantMessage
|
||||
|
||||
msgs = prompt if isinstance(prompt, list) else [prompt]
|
||||
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
|
||||
|
||||
|
||||
def _anthropic_uses_adaptive_thinking(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
adaptive_model in normalized_model_name
|
||||
for adaptive_model in _ANTHROPIC_ADAPTIVE_THINKING_MODELS
|
||||
)
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
|
||||
See https://python.langchain.com/docs/integrations/chat/litellm"""
|
||||
@@ -445,22 +474,42 @@ class LitellmLLM(LLM):
|
||||
}
|
||||
|
||||
elif is_claude_model:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
# across turns, so skip thinking when the history already
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
has_tool_call_history = _prompt_contains_tool_call_history(prompt)
|
||||
|
||||
if budget_tokens is not None:
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
if _anthropic_uses_adaptive_thinking(self.config.model_name):
|
||||
# Newer Anthropic models (Claude Opus 4.7+) reject
|
||||
# thinking.type.enabled — they require the adaptive
|
||||
# thinking config with output_config.effort.
|
||||
if not has_tool_call_history:
|
||||
optional_kwargs["thinking"] = {"type": "adaptive"}
|
||||
optional_kwargs["output_config"] = {
|
||||
"effort": ANTHROPIC_ADAPTIVE_REASONING_EFFORT[
|
||||
reasoning_effort
|
||||
],
|
||||
}
|
||||
else:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
if budget_tokens is not None and not has_tool_call_history:
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
|
||||
# LiteLLM just does some mapping like this anyway but is incomplete for Anthropic
|
||||
optional_kwargs.pop("reasoning_effort", None)
|
||||
|
||||
@@ -1,37 +1,8 @@
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
def _fallback_bedrock_regions() -> list[str]:
|
||||
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-gov-east-1",
|
||||
"us-gov-west-1",
|
||||
"us-west-2",
|
||||
"ap-northeast-1",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-east-1",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-west-2",
|
||||
]
|
||||
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama_chat"
|
||||
@@ -40,6 +11,8 @@ OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
@@ -51,13 +24,6 @@ OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
@@ -65,13 +31,6 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
AWS_REGION_NAME_KWARG = "aws_region_name"
|
||||
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
@@ -47,6 +48,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
}
|
||||
|
||||
|
||||
@@ -331,6 +333,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"version": "1.2",
|
||||
"updated_at": "2026-04-16T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
@@ -10,12 +10,20 @@
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
"default_model": "claude-opus-4-6",
|
||||
"default_model": "claude-opus-4-7",
|
||||
"additional_visible_models": [
|
||||
{
|
||||
"name": "claude-opus-4-7",
|
||||
"display_name": "Claude Opus 4.7"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-6",
|
||||
"display_name": "Claude Sonnet 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-5",
|
||||
"display_name": "Claude Opus 4.5"
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -26,6 +27,14 @@ async def search_indexed_documents(
|
||||
Use this tool for information that is not public knowledge and specific to the user,
|
||||
their team, their work, or their organization/company.
|
||||
|
||||
Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM
|
||||
on every call, consuming tokens and adding latency.
|
||||
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk,
|
||||
but this should still be sufficient for most use cases. CE mode functionality should be swapped
|
||||
when a dedicated CE search endpoint is implemented.
|
||||
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
@@ -111,48 +120,73 @@ async def search_indexed_documents(
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
|
||||
# Build the search request using the new SendSearchQueryRequest format
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
|
||||
search_request: dict[str, Any]
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
|
||||
# Call the API server using the new send-search-message route
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
|
||||
endpoint,
|
||||
json=search_request,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get("error"):
|
||||
if result.get(error_key):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get("error"),
|
||||
"error": result.get(error_key),
|
||||
}
|
||||
|
||||
# Return simplified format for MCP clients
|
||||
fields_to_return = [
|
||||
"semantic_identifier",
|
||||
"content",
|
||||
"source_type",
|
||||
"link",
|
||||
"score",
|
||||
]
|
||||
documents = [
|
||||
{key: doc.get(key) for key in fields_to_return}
|
||||
for doc in result.get("search_docs", [])
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
# backend retrieves fewer documents than requested.
|
||||
documents = documents[:limit]
|
||||
|
||||
logger.info(
|
||||
f"Onyx MCP Server: Internal search returned {len(documents)} results"
|
||||
)
|
||||
@@ -160,7 +194,6 @@ async def search_indexed_documents(
|
||||
"documents": documents,
|
||||
"total_results": len(documents),
|
||||
"query": query,
|
||||
"executed_queries": result.get("all_executed_queries", [query]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
|
||||
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
|
||||
@@ -18,15 +18,18 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.models import ThreadMessage
|
||||
from onyx.onyxbot.slack.utils import get_channel_from_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
@@ -41,6 +44,51 @@ srl = SlackRateLimiter()
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def resolve_channel_references(
|
||||
message: str,
|
||||
client: WebClient,
|
||||
logger: OnyxLoggingAdapter,
|
||||
) -> tuple[str, list[Tag]]:
|
||||
"""Parse Slack channel references from a message, resolve IDs to names,
|
||||
replace the raw markup with readable #channel-name, and return channel tags
|
||||
for search filtering."""
|
||||
tags: list[Tag] = []
|
||||
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
|
||||
seen_channel_ids: set[str] = set()
|
||||
|
||||
for channel_id, channel_name_from_markup in channel_matches:
|
||||
if channel_id in seen_channel_ids:
|
||||
continue
|
||||
seen_channel_ids.add(channel_id)
|
||||
|
||||
channel_name = channel_name_from_markup or None
|
||||
|
||||
if not channel_name:
|
||||
try:
|
||||
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
|
||||
channel_name = channel_info.get("name") or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
|
||||
|
||||
if not channel_name:
|
||||
continue
|
||||
|
||||
# Replace raw Slack markup with readable channel name
|
||||
if channel_name_from_markup:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name_from_markup}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
|
||||
|
||||
return message, tags
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
@@ -157,6 +205,20 @@ def handle_regular_answer(
|
||||
user_message = messages[-1]
|
||||
history_messages = messages[:-1]
|
||||
|
||||
# Resolve any <#CHANNEL_ID> references in the user message to readable
|
||||
# channel names and extract channel tags for search filtering
|
||||
resolved_message, channel_tags = resolve_channel_references(
|
||||
message=user_message.message,
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
user_message = ThreadMessage(
|
||||
message=resolved_message,
|
||||
sender=user_message.sender,
|
||||
role=user_message.role,
|
||||
)
|
||||
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client,
|
||||
channel_id=channel,
|
||||
@@ -207,6 +269,7 @@ def handle_regular_answer(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
tags=channel_tags if channel_tags else None,
|
||||
)
|
||||
|
||||
new_message_request = SendMessageRequest(
|
||||
@@ -231,6 +294,16 @@ def handle_regular_answer(
|
||||
slack_context_str=slack_context_str,
|
||||
)
|
||||
|
||||
# If a channel filter was applied but no results were found, override
|
||||
# the LLM response to avoid hallucinated answers about unindexed channels
|
||||
if channel_tags and not answer.citation_info and not answer.top_documents:
|
||||
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
|
||||
answer.answer = (
|
||||
f"No indexed data found for {channel_names}. "
|
||||
"This channel may not be indexed, or there may be no messages "
|
||||
"matching your query within it."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -285,6 +358,7 @@ def handle_regular_answer(
|
||||
only_respond_if_citations
|
||||
and not answer.citation_info
|
||||
and not message_info.bypass_filters
|
||||
and not channel_tags
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
|
||||
@@ -43,6 +43,9 @@ from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import count_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import (
|
||||
get_latest_successful_index_attempt_for_cc_pair_id,
|
||||
)
|
||||
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -190,6 +193,11 @@ def get_cc_pair_full_info(
|
||||
only_finished=False,
|
||||
)
|
||||
|
||||
latest_successful_attempt = get_latest_successful_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# Get latest permission sync attempt for status
|
||||
latest_permission_sync_attempt = None
|
||||
if cc_pair.access_type == AccessType.SYNC:
|
||||
@@ -207,6 +215,11 @@ def get_cc_pair_full_info(
|
||||
cc_pair_id=cc_pair_id,
|
||||
),
|
||||
last_index_attempt=latest_attempt,
|
||||
last_successful_index_time=(
|
||||
latest_successful_attempt.time_started
|
||||
if latest_successful_attempt
|
||||
else None
|
||||
),
|
||||
latest_deletion_attempt=get_deletion_attempt_snapshot(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
|
||||
@@ -3,6 +3,7 @@ import math
|
||||
import mimetypes
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -109,6 +110,9 @@ from onyx.db.federated import fetch_all_federated_connectors_parallel
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_parallel
|
||||
from onyx.db.index_attempt import (
|
||||
get_latest_successful_index_attempts_parallel,
|
||||
)
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import FederatedConnector
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -479,7 +483,9 @@ def is_zip_file(file: UploadFile) -> bool:
|
||||
|
||||
|
||||
def upload_files(
|
||||
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
|
||||
files: list[UploadFile],
|
||||
file_origin: FileOrigin = FileOrigin.CONNECTOR,
|
||||
unzip: bool = True,
|
||||
) -> FileUploadResponse:
|
||||
|
||||
# Skip directories and known macOS metadata entries
|
||||
@@ -502,31 +508,46 @@ def upload_files(
|
||||
if seen_zip:
|
||||
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
|
||||
seen_zip = True
|
||||
|
||||
# Validate the zip by opening it (catches corrupt/non-zip files)
|
||||
with zipfile.ZipFile(file.file, "r") as zf:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
zf, file_store
|
||||
)
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=file_origin,
|
||||
file_type=mime_type,
|
||||
if unzip:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
zf, file_store
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=file_origin,
|
||||
file_type=mime_type,
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
continue
|
||||
|
||||
# Store the zip as-is (unzip=False)
|
||||
file.file.seek(0)
|
||||
file_id = file_store.save_file(
|
||||
content=file.file,
|
||||
display_name=file.filename,
|
||||
file_origin=file_origin,
|
||||
file_type=file.content_type or "application/zip",
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(file.filename)
|
||||
continue
|
||||
|
||||
# Since we can't render docx files in the UI,
|
||||
@@ -613,9 +634,10 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
unzip: bool = True,
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files, FileOrigin.OTHER)
|
||||
return upload_files(files, FileOrigin.OTHER, unzip=unzip)
|
||||
|
||||
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
@@ -1140,21 +1162,26 @@ def get_connector_indexing_status(
|
||||
),
|
||||
(),
|
||||
),
|
||||
# Get most recent successful index attempts
|
||||
(
|
||||
lambda: get_latest_successful_index_attempts_parallel(
|
||||
request.secondary_index,
|
||||
),
|
||||
(),
|
||||
),
|
||||
]
|
||||
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
# For Admin users, we already got all the cc pair in editable_cc_pairs
|
||||
# its not needed to get them again
|
||||
(
|
||||
editable_cc_pairs,
|
||||
federated_connectors,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
latest_successful_index_attempts,
|
||||
) = run_functions_tuples_in_parallel(parallel_functions)
|
||||
non_editable_cc_pairs = []
|
||||
else:
|
||||
parallel_functions.append(
|
||||
# Get non-editable connector/credential pairs
|
||||
(
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, False, None, True, True, False, True, request.source
|
||||
@@ -1168,6 +1195,7 @@ def get_connector_indexing_status(
|
||||
federated_connectors,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
latest_successful_index_attempts,
|
||||
non_editable_cc_pairs,
|
||||
) = run_functions_tuples_in_parallel(parallel_functions)
|
||||
|
||||
@@ -1179,6 +1207,9 @@ def get_connector_indexing_status(
|
||||
latest_finished_index_attempts = cast(
|
||||
list[IndexAttempt], latest_finished_index_attempts
|
||||
)
|
||||
latest_successful_index_attempts = cast(
|
||||
list[IndexAttempt], latest_successful_index_attempts
|
||||
)
|
||||
|
||||
document_count_info = get_document_counts_for_all_cc_pairs(db_session)
|
||||
|
||||
@@ -1188,42 +1219,48 @@ def get_connector_indexing_status(
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
|
||||
cc_pair_to_latest_index_attempt: dict[tuple[int, int], IndexAttempt] = {
|
||||
(
|
||||
attempt.connector_credential_pair.connector_id,
|
||||
attempt.connector_credential_pair.credential_id,
|
||||
): attempt
|
||||
for attempt in latest_index_attempts
|
||||
}
|
||||
def _attempt_lookup(
|
||||
attempts: list[IndexAttempt],
|
||||
) -> dict[int, IndexAttempt]:
|
||||
return {attempt.connector_credential_pair_id: attempt for attempt in attempts}
|
||||
|
||||
cc_pair_to_latest_finished_index_attempt: dict[tuple[int, int], IndexAttempt] = {
|
||||
(
|
||||
attempt.connector_credential_pair.connector_id,
|
||||
attempt.connector_credential_pair.credential_id,
|
||||
): attempt
|
||||
for attempt in latest_finished_index_attempts
|
||||
}
|
||||
cc_pair_to_latest_index_attempt = _attempt_lookup(latest_index_attempts)
|
||||
cc_pair_to_latest_finished_index_attempt = _attempt_lookup(
|
||||
latest_finished_index_attempts
|
||||
)
|
||||
cc_pair_to_latest_successful_index_attempt = _attempt_lookup(
|
||||
latest_successful_index_attempts
|
||||
)
|
||||
|
||||
def build_connector_indexing_status(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
is_editable: bool,
|
||||
) -> ConnectorIndexingStatusLite | None:
|
||||
# TODO remove this to enable ingestion API
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
return None
|
||||
|
||||
latest_attempt = cc_pair_to_latest_index_attempt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id)
|
||||
)
|
||||
latest_attempt = cc_pair_to_latest_index_attempt.get(cc_pair.id)
|
||||
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id)
|
||||
cc_pair.id
|
||||
)
|
||||
latest_successful_attempt = cc_pair_to_latest_successful_index_attempt.get(
|
||||
cc_pair.id
|
||||
)
|
||||
doc_count = cc_pair_to_document_cnt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id), 0
|
||||
)
|
||||
|
||||
return _get_connector_indexing_status_lite(
|
||||
cc_pair, latest_attempt, latest_finished_attempt, is_editable, doc_count
|
||||
cc_pair,
|
||||
latest_attempt,
|
||||
latest_finished_attempt,
|
||||
(
|
||||
latest_successful_attempt.time_started
|
||||
if latest_successful_attempt
|
||||
else None
|
||||
),
|
||||
is_editable,
|
||||
doc_count,
|
||||
)
|
||||
|
||||
# Process editable cc_pairs
|
||||
@@ -1384,6 +1421,7 @@ def _get_connector_indexing_status_lite(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
latest_index_attempt: IndexAttempt | None,
|
||||
latest_finished_index_attempt: IndexAttempt | None,
|
||||
last_successful_index_time: datetime | None,
|
||||
is_editable: bool,
|
||||
document_cnt: int,
|
||||
) -> ConnectorIndexingStatusLite | None:
|
||||
@@ -1417,7 +1455,7 @@ def _get_connector_indexing_status_lite(
|
||||
else None
|
||||
),
|
||||
last_status=latest_index_attempt.status if latest_index_attempt else None,
|
||||
last_success=cc_pair.last_successful_index_time,
|
||||
last_success=last_successful_index_time,
|
||||
docs_indexed=document_cnt,
|
||||
latest_index_attempt_docs_indexed=(
|
||||
latest_index_attempt.total_docs_indexed if latest_index_attempt else None
|
||||
|
||||
@@ -330,6 +330,7 @@ class CCPairFullInfo(BaseModel):
|
||||
num_docs_indexed: int, # not ideal, but this must be computed separately
|
||||
is_editable_for_current_user: bool,
|
||||
indexing: bool,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
last_permission_sync_attempt_status: PermissionSyncStatus | None = None,
|
||||
permission_syncing: bool = False,
|
||||
last_permission_sync_attempt_finished: datetime | None = None,
|
||||
@@ -382,9 +383,7 @@ class CCPairFullInfo(BaseModel):
|
||||
creator_email=(
|
||||
cc_pair_model.creator.email if cc_pair_model.creator else None
|
||||
),
|
||||
last_indexed=(
|
||||
last_index_attempt.time_started if last_index_attempt else None
|
||||
),
|
||||
last_indexed=last_successful_index_time,
|
||||
last_pruned=cc_pair_model.last_pruned,
|
||||
last_full_permission_sync=cls._get_last_full_permission_sync(cc_pair_model),
|
||||
overall_indexing_speed=overall_indexing_speed,
|
||||
|
||||
@@ -40,6 +40,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.client import app as celery_app
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -50,6 +52,9 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
|
||||
@@ -127,6 +132,49 @@ class DeleteFileResponse(BaseModel):
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
|
||||
"""True if either the filename or the content-type indicates a PDF.
|
||||
|
||||
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
|
||||
``Content-Type: application/octet-stream``), so we also fall back to
|
||||
extension-based detection via ``mimetypes.guess_type`` on the filename.
|
||||
"""
|
||||
if content_type == "application/pdf":
|
||||
return True
|
||||
guessed, _ = mimetypes.guess_type(filename)
|
||||
return guessed == "application/pdf"
|
||||
|
||||
|
||||
def _check_pdf_image_caps(
|
||||
filename: str, content: bytes, content_type: str | None, batch_total: int
|
||||
) -> int:
|
||||
"""Enforce per-file and per-batch embedded-image caps for PDFs.
|
||||
|
||||
Returns the number of embedded images in this file (0 for non-PDFs) so
|
||||
callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
|
||||
if either cap is exceeded.
|
||||
"""
|
||||
if not _looks_like_pdf(filename, content_type):
|
||||
return 0
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Short-circuit at the larger cap so we get a useful count for both checks.
|
||||
count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
|
||||
if count > file_cap:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"PDF '{filename}' contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting the document into smaller files.",
|
||||
)
|
||||
if batch_total + count > batch_cap:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"Upload would exceed the {batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading fewer image-heavy files at once.",
|
||||
)
|
||||
return count
|
||||
|
||||
|
||||
def _sanitize_path(path: str) -> str:
|
||||
"""Sanitize a file path, removing traversal attempts and normalizing.
|
||||
|
||||
@@ -356,6 +404,7 @@ async def upload_files(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Sanitize the base path
|
||||
@@ -375,6 +424,14 @@ async def upload_files(
|
||||
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024*1024)}MB",
|
||||
)
|
||||
|
||||
# Reject PDFs with an unreasonable per-file or per-batch image count
|
||||
batch_image_total += _check_pdf_image_caps(
|
||||
filename=file.filename or "unnamed",
|
||||
content=content,
|
||||
content_type=file.content_type,
|
||||
batch_total=batch_image_total,
|
||||
)
|
||||
|
||||
# Validate cumulative storage (existing + this upload batch)
|
||||
total_size += file_size
|
||||
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
|
||||
@@ -473,6 +530,7 @@ async def upload_zip(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
|
||||
# Extract zip contents into a subfolder named after the zip file
|
||||
zip_name = api_sanitize_filename(file.filename or "upload")
|
||||
@@ -511,6 +569,36 @@ async def upload_zip(
|
||||
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
|
||||
continue
|
||||
|
||||
# Skip PDFs that would trip the per-file or per-batch image
|
||||
# cap (would OOM the user-file-processing worker). Matches
|
||||
# /upload behavior but uses skip-and-warn to stay consistent
|
||||
# with the zip path's handling of oversized files.
|
||||
zip_file_name = zip_info.filename.split("/")[-1]
|
||||
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
|
||||
if zip_content_type == "application/pdf":
|
||||
image_count = count_pdf_embedded_images(
|
||||
BytesIO(file_content),
|
||||
max(
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
),
|
||||
)
|
||||
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
|
||||
logger.warning(
|
||||
"Skipping '%s' - exceeds %d per-file embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
)
|
||||
continue
|
||||
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
|
||||
logger.warning(
|
||||
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
)
|
||||
continue
|
||||
batch_image_total += image_count
|
||||
|
||||
total_size += file_size
|
||||
|
||||
# Validate cumulative storage
|
||||
|
||||
@@ -10,7 +10,10 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -156,6 +159,11 @@ def categorize_uploaded_files(
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
|
||||
# Running total of embedded images across PDFs in this batch. Once the
|
||||
# aggregate cap is reached, subsequent PDFs in the same upload are
|
||||
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
|
||||
batch_image_total = 0
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
@@ -204,6 +212,47 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject PDFs with an unreasonable number of embedded images
|
||||
# (either per-file or accumulated across this upload batch).
|
||||
# A PDF with thousands of embedded images can OOM the
|
||||
# user-file-processing celery worker because every image is
|
||||
# decoded with PIL and then sent to the vision LLM.
|
||||
if extension == ".pdf":
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Use the larger of the two caps as the short-circuit
|
||||
# threshold so we get a useful count for both checks.
|
||||
# count_pdf_embedded_images restores the stream position.
|
||||
count = count_pdf_embedded_images(
|
||||
upload.file, max(file_cap, batch_cap)
|
||||
)
|
||||
if count > file_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"PDF contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting "
|
||||
f"the document into smaller files."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
if batch_image_total + count > batch_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"Upload would exceed the "
|
||||
f"{batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading "
|
||||
f"fewer image-heavy files at once."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
batch_image_total += count
|
||||
|
||||
text_content = extract_file_text(
|
||||
file=upload.file,
|
||||
file_name=filename,
|
||||
|
||||
@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
@@ -65,17 +68,20 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_embedding_model
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import ModelMetadata
|
||||
@@ -97,6 +103,71 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
|
||||
api_key: str | None,
|
||||
provider_name: str | None,
|
||||
api_base: str | None,
|
||||
db_session: Session,
|
||||
) -> str | None:
|
||||
"""Return the real API key for model-fetch endpoints.
|
||||
|
||||
When editing an existing provider the form value is masked (e.g.
|
||||
``sk-a****b1c2``). If *provider_name* is supplied we can look up
|
||||
the unmasked key from the database so the external request succeeds.
|
||||
|
||||
The stored key is only returned when the request's *api_base*
|
||||
matches the value stored in the database.
|
||||
"""
|
||||
if not provider_name:
|
||||
return api_key
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.api_key:
|
||||
# Normalise both URLs before comparing so trailing-slash
|
||||
# differences don't cause a false mismatch.
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
request_base = (api_base or "").strip().rstrip("/")
|
||||
if stored_base != request_base:
|
||||
return api_key
|
||||
|
||||
stored_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
# Only resolve when the incoming value is the masked form of the
|
||||
# stored key — i.e. the user hasn't typed a new key.
|
||||
if api_key and api_key == _mask_string(stored_key):
|
||||
return stored_key
|
||||
return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
source_label: str,
|
||||
) -> None:
|
||||
"""Sync fetched models to DB for the given provider.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
|
||||
"""
|
||||
try:
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=provider_name,
|
||||
models=models,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
|
||||
|
||||
|
||||
# Keys in custom_config that contain sensitive credentials
|
||||
_SENSITIVE_CONFIG_KEYS = {
|
||||
"vertex_credentials",
|
||||
@@ -445,16 +516,17 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# Before the upsert, check if this provider currently owns the global
|
||||
# CHAT default. The upsert may cascade-delete model_configurations
|
||||
# (and their flow mappings), so we need to remember this beforehand.
|
||||
was_default_provider = False
|
||||
if existing_provider and transitioning_to_auto_mode:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
was_default_provider = (
|
||||
current_default is not None
|
||||
and current_default.llm_provider_id == existing_provider.id
|
||||
)
|
||||
# When transitioning to auto mode, preserve existing model configurations
|
||||
# so the upsert doesn't try to delete them (which would trip the default
|
||||
# model protection guard). sync_auto_mode_models will handle the model
|
||||
# lifecycle afterward — adding new models, hiding removed ones, and
|
||||
# updating the default. This is safe even if sync fails: the provider
|
||||
# keeps its old models and default rather than losing them.
|
||||
if transitioning_to_auto_mode and existing_provider:
|
||||
llm_provider_upsert_request.model_configurations = [
|
||||
ModelConfigurationUpsertRequest.from_model(mc)
|
||||
for mc in existing_provider.model_configurations
|
||||
]
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
@@ -468,7 +540,6 @@ def put_llm_provider(
|
||||
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
@@ -478,20 +549,6 @@ def put_llm_provider(
|
||||
updated_provider,
|
||||
config,
|
||||
)
|
||||
|
||||
# If this provider was the default before the transition,
|
||||
# restore the default using the recommended model.
|
||||
if was_default_provider:
|
||||
recommended = config.get_default_model(
|
||||
llm_provider_upsert_request.provider
|
||||
)
|
||||
if recommended:
|
||||
update_default_provider(
|
||||
provider_id=updated_provider.id,
|
||||
model_name=recommended.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Refresh result with synced models
|
||||
result = LLMProviderView.from_model(updated_provider)
|
||||
|
||||
@@ -976,27 +1033,20 @@ def get_bedrock_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
|
||||
for r in results
|
||||
],
|
||||
source_label="Bedrock",
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -1114,41 +1164,35 @@ def get_ollama_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Ollama",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
headers: dict[str, str] = {
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1171,8 +1215,12 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1223,27 +1271,20 @@ def get_openrouter_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenRouter",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1272,13 +1313,23 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
stored_base = (
|
||||
(existing_provider.api_base or "")
|
||||
.strip()
|
||||
.rstrip("/")
|
||||
.removesuffix("/v1")
|
||||
)
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1337,26 +1388,128 @@ def get_lm_studio_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LM Studio",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/litellm/available-models")
|
||||
def get_litellm_available_models(
|
||||
request: LitellmModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Litellm endpoint",
|
||||
)
|
||||
|
||||
results: list[LitellmFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_details.id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
model_name=model_details.id,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Litellm model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Litellm",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.model_name,
|
||||
display_name=r.model_name,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LiteLLM",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
headers = {
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. "
|
||||
"Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
|
||||
@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
class SyncModelEntry(BaseModel):
|
||||
"""Typed model for syncing fetched models to the DB."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool = False
|
||||
|
||||
|
||||
class LitellmModelsRequest(BaseModel):
|
||||
api_key: str
|
||||
api_base: str
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LitellmModelDetails(BaseModel):
|
||||
"""Response model for Litellm proxy /api/v1/models endpoint"""
|
||||
|
||||
id: str # Model ID (e.g. "gpt-4o")
|
||||
object: str # "model"
|
||||
created: int # Unix timestamp in seconds
|
||||
owned_by: str # Provider name (e.g. "openai")
|
||||
|
||||
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
@@ -366,3 +366,18 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_embedding_model(model_name: str) -> bool:
|
||||
"""Checks for if a model is an embedding model"""
|
||||
from litellm import get_model_info
|
||||
|
||||
try:
|
||||
# get_model_info raises on unknown models
|
||||
# default to False
|
||||
model_info = get_model_info(model_name)
|
||||
except Exception:
|
||||
return False
|
||||
is_embedding_mode = model_info.get("mode") == "embedding"
|
||||
|
||||
return is_embedding_mode
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.citation_utils import extract_citation_order_from_text
|
||||
@@ -20,7 +22,9 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderResult
|
||||
from onyx.server.query_and_chat.streaming_models import FileReaderStart
|
||||
@@ -180,24 +184,37 @@ def create_custom_tool_packets(
|
||||
tab_index: int = 0,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
error: CustomToolErrorInfo | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
tool_id: int | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
obj=CustomToolStart(tool_name=tool_name, tool_id=tool_id),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolArgs(tool_name=tool_name, tool_args=tool_args),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
tool_id=tool_id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -657,13 +674,55 @@ def translate_assistant_message_to_packets(
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
# Try to parse as structured CustomToolCallSummary JSON
|
||||
custom_data: dict | list | str | int | float | bool | None = (
|
||||
tool_call.tool_call_response
|
||||
)
|
||||
custom_error: CustomToolErrorInfo | None = None
|
||||
custom_response_type = "text"
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_call.tool_call_response)
|
||||
if isinstance(parsed, dict) and "tool_name" in parsed:
|
||||
custom_data = parsed.get("tool_result")
|
||||
custom_response_type = parsed.get(
|
||||
"response_type", "text"
|
||||
)
|
||||
if parsed.get("error"):
|
||||
custom_error = CustomToolErrorInfo(
|
||||
**parsed["error"]
|
||||
)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
ValidationError,
|
||||
):
|
||||
pass
|
||||
|
||||
custom_file_ids: list[str] | None = None
|
||||
if custom_response_type in ("image", "csv") and isinstance(
|
||||
custom_data, dict
|
||||
):
|
||||
custom_file_ids = custom_data.get("file_ids")
|
||||
custom_data = None
|
||||
|
||||
custom_args = {
|
||||
k: v
|
||||
for k, v in (tool_call.tool_call_arguments or {}).items()
|
||||
if k != "requestBody"
|
||||
}
|
||||
turn_tool_packets.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool.display_name or tool.name,
|
||||
response_type="text",
|
||||
response_type=custom_response_type,
|
||||
turn_index=turn_num,
|
||||
tab_index=tool_call.tab_index,
|
||||
data=tool_call.tool_call_response,
|
||||
data=custom_data,
|
||||
file_ids=custom_file_ids,
|
||||
error=custom_error,
|
||||
tool_args=custom_args if custom_args else None,
|
||||
tool_id=tool_call.tool_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class StreamingType(Enum):
|
||||
PYTHON_TOOL_START = "python_tool_start"
|
||||
PYTHON_TOOL_DELTA = "python_tool_delta"
|
||||
CUSTOM_TOOL_START = "custom_tool_start"
|
||||
CUSTOM_TOOL_ARGS = "custom_tool_args"
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta"
|
||||
FILE_READER_START = "file_reader_start"
|
||||
FILE_READER_RESULT = "file_reader_result"
|
||||
@@ -245,6 +246,20 @@ class CustomToolStart(BaseObj):
|
||||
type: Literal["custom_tool_start"] = StreamingType.CUSTOM_TOOL_START.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
|
||||
|
||||
class CustomToolArgs(BaseObj):
|
||||
type: Literal["custom_tool_args"] = StreamingType.CUSTOM_TOOL_ARGS.value
|
||||
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class CustomToolErrorInfo(BaseModel):
|
||||
is_auth_error: bool = False
|
||||
status_code: int
|
||||
message: str
|
||||
|
||||
|
||||
# The allowed streamed packets for a custom tool
|
||||
@@ -252,11 +267,13 @@ class CustomToolDelta(BaseObj):
|
||||
type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value
|
||||
|
||||
tool_name: str
|
||||
tool_id: int | None = None
|
||||
response_type: str
|
||||
# For non-file responses
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
# For file-based responses like image/csv
|
||||
file_ids: list[str] | None = None
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
################################################
|
||||
@@ -366,6 +383,7 @@ PacketObj = Union[
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
CustomToolStart,
|
||||
CustomToolArgs,
|
||||
CustomToolDelta,
|
||||
FileReaderStart,
|
||||
FileReaderResult,
|
||||
|
||||
@@ -8,8 +8,6 @@ from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
@@ -165,39 +163,6 @@ def create_image_generation_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_custom_tool_packets(
|
||||
tool_name: str,
|
||||
response_type: str,
|
||||
turn_index: int,
|
||||
data: dict | list | str | int | float | bool | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_fetch_packets(
|
||||
fetch_docs: list[SavedSearchDoc],
|
||||
urls: list[str],
|
||||
|
||||
@@ -275,9 +275,13 @@ def setup_postgres(db_session: Session) -> None:
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
try:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to upsert LLM provider during setup: %s", e)
|
||||
return
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
|
||||
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
@@ -61,6 +62,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
|
||||
tool_result: Any # The response data
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
class ToolCallKickoff(BaseModel):
|
||||
|
||||
@@ -15,7 +15,9 @@ from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolArgs
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
@@ -139,7 +141,7 @@ class CustomTool(Tool[None]):
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=CustomToolStart(tool_name=self._name),
|
||||
obj=CustomToolStart(tool_name=self._name, tool_id=self._id),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -149,10 +151,8 @@ class CustomTool(Tool[None]):
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
request_body = llm_kwargs.get(REQUEST_BODY)
|
||||
|
||||
# Build path params
|
||||
path_params = {}
|
||||
|
||||
for path_param_schema in self._method_spec.get_path_param_schemas():
|
||||
param_name = path_param_schema["name"]
|
||||
if param_name not in llm_kwargs:
|
||||
@@ -165,6 +165,7 @@ class CustomTool(Tool[None]):
|
||||
)
|
||||
path_params[param_name] = llm_kwargs[param_name]
|
||||
|
||||
# Build query params
|
||||
query_params = {}
|
||||
for query_param_schema in self._method_spec.get_query_param_schemas():
|
||||
if query_param_schema["name"] in llm_kwargs:
|
||||
@@ -172,6 +173,20 @@ class CustomTool(Tool[None]):
|
||||
query_param_schema["name"]
|
||||
]
|
||||
|
||||
# Emit args packet (path + query params only, no request body)
|
||||
tool_args = {**path_params, **query_params}
|
||||
if tool_args:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=CustomToolArgs(
|
||||
tool_name=self._name,
|
||||
tool_args=tool_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
request_body = llm_kwargs.get(REQUEST_BODY)
|
||||
url = self._method_spec.build_url(self._base_url, path_params, query_params)
|
||||
method = self._method_spec.method
|
||||
|
||||
@@ -180,6 +195,18 @@ class CustomTool(Tool[None]):
|
||||
)
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
||||
# Detect HTTP errors — only 401/403 are flagged as auth errors
|
||||
error_info: CustomToolErrorInfo | None = None
|
||||
if response.status_code in (401, 403):
|
||||
error_info = CustomToolErrorInfo(
|
||||
is_auth_error=True,
|
||||
status_code=response.status_code,
|
||||
message=f"{self._name} action failed because of authentication error",
|
||||
)
|
||||
logger.warning(
|
||||
f"Auth error from custom tool '{self._name}': HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
tool_result: Any
|
||||
response_type: str
|
||||
file_ids: List[str] | None = None
|
||||
@@ -222,9 +249,11 @@ class CustomTool(Tool[None]):
|
||||
placement=placement,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
tool_id=self._id,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
error=error_info,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -236,6 +265,7 @@ class CustomTool(Tool[None]):
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
tool_result=tool_result,
|
||||
error=error_info,
|
||||
),
|
||||
llm_facing_response=llm_facing_response,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import mimetypes
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
@@ -85,6 +86,14 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
def __init__(self, tool_id: int, emitter: Emitter) -> None:
|
||||
super().__init__(emitter=emitter)
|
||||
self._id = tool_id
|
||||
# Cache of (filename, content_hash) -> ci_file_id to avoid re-uploading
|
||||
# the same file on every tool call iteration within the same agent session.
|
||||
# Filename is included in the key so two files with identical bytes but
|
||||
# different names each get their own upload slot.
|
||||
# TTL assumption: code-interpreter file TTLs (typically hours) greatly
|
||||
# exceed the lifetime of a single agent session (at most MAX_LLM_CYCLES
|
||||
# iterations, typically a few minutes), so stale-ID eviction is not needed.
|
||||
self._uploaded_file_cache: dict[tuple[str, str], str] = {}
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -184,8 +193,13 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
content_hash = hashlib.sha256(chat_file.content).hexdigest()
|
||||
cache_key = (file_name, content_hash)
|
||||
ci_file_id = self._uploaded_file_cache.get(cache_key)
|
||||
if ci_file_id is None:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
self._uploaded_file_cache[cache_key] = ci_file_id
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
@@ -303,15 +317,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
f"file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged "
|
||||
f"file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
# Note: staged input files are intentionally not deleted here because
|
||||
# _uploaded_file_cache reuses their file_ids across iterations. They are
|
||||
# orphaned when the session ends, but the code interpreter cleans up
|
||||
# stale files on its own TTL.
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
|
||||
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
jsonriver - A streaming JSON parser for Python
|
||||
|
||||
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
|
||||
Gives you a sequence of increasingly complete values.
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from .parse import _Parser as Parser
|
||||
from .parse import JsonObject
|
||||
from .parse import JsonValue
|
||||
|
||||
__all__ = ["Parser", "JsonValue", "JsonObject"]
|
||||
__version__ = "0.0.1"
|
||||
427
backend/onyx/utils/jsonriver/parse.py
Normal file
427
backend/onyx/utils/jsonriver/parse.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
JSON parser for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from .tokenize import _Input
|
||||
from .tokenize import json_token_type_to_string
|
||||
from .tokenize import JsonTokenType
|
||||
from .tokenize import Tokenizer
|
||||
|
||||
|
||||
# Type definitions for JSON values
|
||||
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
|
||||
JsonObject = dict[str, JsonValue]
|
||||
|
||||
|
||||
class _StateEnum(IntEnum):
|
||||
"""Parser state machine states"""
|
||||
|
||||
Initial = 0
|
||||
InString = 1
|
||||
InArray = 2
|
||||
InObjectExpectingKey = 3
|
||||
InObjectExpectingValue = 4
|
||||
|
||||
|
||||
class _State:
|
||||
"""Base class for parser states"""
|
||||
|
||||
type: _StateEnum
|
||||
value: JsonValue | tuple[str, JsonObject] | None
|
||||
|
||||
|
||||
class _InitialState(_State):
|
||||
"""Initial state before any parsing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.Initial
|
||||
self.value = None
|
||||
|
||||
|
||||
class _InStringState(_State):
|
||||
"""State while parsing a string"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InString
|
||||
self.value = ""
|
||||
|
||||
|
||||
class _InArrayState(_State):
|
||||
"""State while parsing an array"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InArray
|
||||
self.value: list[JsonValue] = []
|
||||
|
||||
|
||||
class _InObjectExpectingKeyState(_State):
|
||||
"""State while parsing an object, expecting a key"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingKey
|
||||
self.value: JsonObject = {}
|
||||
|
||||
|
||||
class _InObjectExpectingValueState(_State):
|
||||
"""State while parsing an object, expecting a value"""
|
||||
|
||||
def __init__(self, key: str, obj: JsonObject) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingValue
|
||||
self.value = (key, obj)
|
||||
|
||||
|
||||
# Sentinel value to distinguish "not set" from "set to None/null"
|
||||
class _Unset:
|
||||
pass
|
||||
|
||||
|
||||
_UNSET = _Unset()
|
||||
|
||||
|
||||
class _Parser:
|
||||
"""
|
||||
Incremental JSON parser
|
||||
|
||||
Feed chunks of JSON text via feed() and get back progressively
|
||||
more complete JSON values.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_stack: list[_State] = [_InitialState()]
|
||||
self._toplevel_value: JsonValue | _Unset = _UNSET
|
||||
self._input = _Input()
|
||||
self.tokenizer = Tokenizer(self._input, self)
|
||||
self._finished = False
|
||||
self._progressed = False
|
||||
self._prev_snapshot: JsonValue | _Unset = _UNSET
|
||||
|
||||
def feed(self, chunk: str) -> list[JsonValue]:
|
||||
"""
|
||||
Feed a chunk of JSON text and return deltas from the previous state.
|
||||
|
||||
Each element in the returned list represents what changed since the
|
||||
last yielded value. For dicts, only changed/new keys are included,
|
||||
with string values containing only the newly appended characters.
|
||||
"""
|
||||
if self._finished:
|
||||
return []
|
||||
|
||||
self._input.feed(chunk)
|
||||
return self._collect_deltas()
|
||||
|
||||
@staticmethod
|
||||
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
|
||||
if prev is None:
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(prev, dict):
|
||||
result: JsonObject = {}
|
||||
for key in current:
|
||||
cur_val = current[key]
|
||||
prev_val = prev.get(key)
|
||||
if key not in prev:
|
||||
result[key] = cur_val
|
||||
elif isinstance(cur_val, str) and isinstance(prev_val, str):
|
||||
if cur_val != prev_val:
|
||||
result[key] = cur_val[len(prev_val) :]
|
||||
elif isinstance(cur_val, list) and isinstance(prev_val, list):
|
||||
if cur_val != prev_val:
|
||||
new_items = cur_val[len(prev_val) :]
|
||||
# check if the last existing element was updated
|
||||
if (
|
||||
prev_val
|
||||
and len(cur_val) >= len(prev_val)
|
||||
and cur_val[len(prev_val) - 1] != prev_val[-1]
|
||||
):
|
||||
result[key] = [cur_val[len(prev_val) - 1]] + new_items
|
||||
elif new_items:
|
||||
result[key] = new_items
|
||||
elif cur_val != prev_val:
|
||||
result[key] = cur_val
|
||||
return result if result else None
|
||||
|
||||
if isinstance(current, str) and isinstance(prev, str):
|
||||
delta = current[len(prev) :]
|
||||
return delta if delta else None
|
||||
|
||||
if isinstance(current, list) and isinstance(prev, list):
|
||||
if current != prev:
|
||||
new_items = current[len(prev) :]
|
||||
if (
|
||||
prev
|
||||
and len(current) >= len(prev)
|
||||
and current[len(prev) - 1] != prev[-1]
|
||||
):
|
||||
return [current[len(prev) - 1]] + new_items
|
||||
return new_items if new_items else None
|
||||
return None
|
||||
|
||||
if current != prev:
|
||||
return current
|
||||
return None
|
||||
|
||||
def finish(self) -> list[JsonValue]:
|
||||
"""Signal that no more chunks will be fed. Validates trailing content.
|
||||
|
||||
Returns any final deltas produced by flushing pending tokens (e.g.
|
||||
numbers, which have no terminator and wait for more input).
|
||||
"""
|
||||
self._input.mark_complete()
|
||||
# Pump once more so the tokenizer can emit tokens that were waiting
|
||||
# for more input (e.g. numbers need buffer_complete to finalize).
|
||||
results = self._collect_deltas()
|
||||
self._input.expect_end_of_content()
|
||||
return results
|
||||
|
||||
def _collect_deltas(self) -> list[JsonValue]:
|
||||
"""Run one pump cycle and return any deltas produced."""
|
||||
results: list[JsonValue] = []
|
||||
while True:
|
||||
self._progressed = False
|
||||
self.tokenizer.pump()
|
||||
|
||||
if self._progressed:
|
||||
if self._toplevel_value is _UNSET:
|
||||
raise RuntimeError(
|
||||
"Internal error: toplevel_value should not be unset "
|
||||
"after progressing"
|
||||
)
|
||||
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
|
||||
if isinstance(self._prev_snapshot, _Unset):
|
||||
results.append(current)
|
||||
else:
|
||||
delta = self._compute_delta(self._prev_snapshot, current)
|
||||
if delta is not None:
|
||||
results.append(delta)
|
||||
self._prev_snapshot = current
|
||||
else:
|
||||
if not self._state_stack:
|
||||
self._finished = True
|
||||
break
|
||||
return results
|
||||
|
||||
# TokenHandler protocol implementation
|
||||
|
||||
def handle_null(self) -> None:
|
||||
"""Handle null token"""
|
||||
self._handle_value_token(JsonTokenType.Null, None)
|
||||
|
||||
def handle_boolean(self, value: bool) -> None:
|
||||
"""Handle boolean token"""
|
||||
self._handle_value_token(JsonTokenType.Boolean, value)
|
||||
|
||||
def handle_number(self, value: float) -> None:
|
||||
"""Handle number token"""
|
||||
self._handle_value_token(JsonTokenType.Number, value)
|
||||
|
||||
def handle_string_start(self) -> None:
|
||||
"""Handle string start token"""
|
||||
state = self._current_state()
|
||||
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(JsonTokenType.StringStart, None)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
self._state_stack.append(_InStringState())
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
sv = self._progress_value(JsonTokenType.StringStart, None)
|
||||
obj[key] = sv
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
def handle_string_middle(self, value: str) -> None:
|
||||
"""Handle string middle token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
if len(self._state_stack) >= 2:
|
||||
prev = self._state_stack[-2]
|
||||
if prev.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
else:
|
||||
self._progressed = True
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
assert isinstance(state.value, str)
|
||||
state.value += value
|
||||
|
||||
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_string_end(self) -> None:
|
||||
"""Handle string end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
self._state_stack.pop()
|
||||
parent_state = self._state_stack[-1] if self._state_stack else None
|
||||
assert isinstance(state.value, str)
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_array_start(self) -> None:
|
||||
"""Handle array start token"""
|
||||
self._handle_value_token(JsonTokenType.ArrayStart, None)
|
||||
|
||||
def handle_array_end(self) -> None:
|
||||
"""Handle array end token"""
|
||||
state = self._current_state()
|
||||
if state.type != _StateEnum.InArray:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
|
||||
)
|
||||
self._state_stack.pop()
|
||||
|
||||
def handle_object_start(self) -> None:
|
||||
"""Handle object start token"""
|
||||
self._handle_value_token(JsonTokenType.ObjectStart, None)
|
||||
|
||||
def handle_object_end(self) -> None:
|
||||
"""Handle object end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type in (
|
||||
_StateEnum.InObjectExpectingKey,
|
||||
_StateEnum.InObjectExpectingValue,
|
||||
):
|
||||
self._state_stack.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
|
||||
)
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _current_state(self) -> _State:
|
||||
"""Get current parser state"""
|
||||
if not self._state_stack:
|
||||
raise ValueError("Unexpected trailing input")
|
||||
return self._state_stack[-1]
|
||||
|
||||
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
|
||||
"""Handle a complete value token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(token_type, value)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(token_type, value)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
if token_type != JsonTokenType.StringStart:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
v = self._progress_value(token_type, value)
|
||||
obj[key] = v
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of object expecting key"
|
||||
)
|
||||
|
||||
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
|
||||
"""Update parent container with updated string value"""
|
||||
if parent_state is None:
|
||||
self._toplevel_value = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InArray:
|
||||
arr = cast(list[JsonValue], parent_state.value)
|
||||
arr[-1] = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], parent_state.value)
|
||||
obj[key] = updated
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingKey:
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
obj = cast(JsonObject, parent_state.value)
|
||||
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
|
||||
|
||||
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
|
||||
"""Create initial value for a token and push appropriate state"""
|
||||
if token_type == JsonTokenType.Null:
|
||||
return None
|
||||
|
||||
elif token_type == JsonTokenType.Boolean:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.Number:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.StringStart:
|
||||
string_state = _InStringState()
|
||||
self._state_stack.append(string_state)
|
||||
return ""
|
||||
|
||||
elif token_type == JsonTokenType.ArrayStart:
|
||||
array_state = _InArrayState()
|
||||
self._state_stack.append(array_state)
|
||||
return array_state.value
|
||||
|
||||
elif token_type == JsonTokenType.ObjectStart:
|
||||
object_state = _InObjectExpectingKeyState()
|
||||
self._state_stack.append(object_state)
|
||||
return object_state.value
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected token type: {json_token_type_to_string(token_type)}"
|
||||
)
|
||||
514
backend/onyx/utils/jsonriver/tokenize.py
Normal file
514
backend/onyx/utils/jsonriver/tokenize.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
JSON tokenizer for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenHandler(Protocol):
|
||||
"""Protocol for handling JSON tokens"""
|
||||
|
||||
def handle_null(self) -> None: ...
|
||||
def handle_boolean(self, value: bool) -> None: ...
|
||||
def handle_number(self, value: float) -> None: ...
|
||||
def handle_string_start(self) -> None: ...
|
||||
def handle_string_middle(self, value: str) -> None: ...
|
||||
def handle_string_end(self) -> None: ...
|
||||
def handle_array_start(self) -> None: ...
|
||||
def handle_array_end(self) -> None: ...
|
||||
def handle_object_start(self) -> None: ...
|
||||
def handle_object_end(self) -> None: ...
|
||||
|
||||
|
||||
class JsonTokenType(IntEnum):
|
||||
"""Types of JSON tokens"""
|
||||
|
||||
Null = 0
|
||||
Boolean = 1
|
||||
Number = 2
|
||||
StringStart = 3
|
||||
StringMiddle = 4
|
||||
StringEnd = 5
|
||||
ArrayStart = 6
|
||||
ArrayEnd = 7
|
||||
ObjectStart = 8
|
||||
ObjectEnd = 9
|
||||
|
||||
|
||||
def json_token_type_to_string(token_type: JsonTokenType) -> str:
|
||||
"""Convert token type to readable string"""
|
||||
names = {
|
||||
JsonTokenType.Null: "null",
|
||||
JsonTokenType.Boolean: "boolean",
|
||||
JsonTokenType.Number: "number",
|
||||
JsonTokenType.StringStart: "string start",
|
||||
JsonTokenType.StringMiddle: "string middle",
|
||||
JsonTokenType.StringEnd: "string end",
|
||||
JsonTokenType.ArrayStart: "array start",
|
||||
JsonTokenType.ArrayEnd: "array end",
|
||||
JsonTokenType.ObjectStart: "object start",
|
||||
JsonTokenType.ObjectEnd: "object end",
|
||||
}
|
||||
return names[token_type]
|
||||
|
||||
|
||||
class _State(IntEnum):
|
||||
"""Internal tokenizer states"""
|
||||
|
||||
ExpectingValue = 0
|
||||
InString = 1
|
||||
StartArray = 2
|
||||
AfterArrayValue = 3
|
||||
StartObject = 4
|
||||
AfterObjectKey = 5
|
||||
AfterObjectValue = 6
|
||||
BeforeObjectKey = 7
|
||||
|
||||
|
||||
# Regex for validating JSON numbers
|
||||
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
|
||||
|
||||
|
||||
def _parse_json_number(s: str) -> float:
|
||||
"""Parse a JSON number string, validating format"""
|
||||
if not _JSON_NUMBER_PATTERN.match(s):
|
||||
raise ValueError("Invalid number")
|
||||
return float(s)
|
||||
|
||||
|
||||
class _Input:
|
||||
"""
|
||||
Input buffer for chunk-based JSON parsing
|
||||
|
||||
Manages buffering of input chunks and provides methods for
|
||||
consuming and inspecting the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._start_index = 0
|
||||
self.buffer_complete = False
|
||||
|
||||
def feed(self, chunk: str) -> None:
|
||||
"""Add a chunk of data to the buffer"""
|
||||
self._buffer += chunk
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Signal that no more chunks will be fed"""
|
||||
self.buffer_complete = True
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""Number of characters remaining in buffer"""
|
||||
return len(self._buffer) - self._start_index
|
||||
|
||||
def advance(self, length: int) -> None:
|
||||
"""Advance the start position by length characters"""
|
||||
self._start_index += length
|
||||
|
||||
def peek(self, offset: int) -> str | None:
|
||||
"""Peek at character at offset, or None if not available"""
|
||||
idx = self._start_index + offset
|
||||
if idx < len(self._buffer):
|
||||
return self._buffer[idx]
|
||||
return None
|
||||
|
||||
def peek_char_code(self, offset: int) -> int:
|
||||
"""Get character code at offset"""
|
||||
return ord(self._buffer[self._start_index + offset])
|
||||
|
||||
def slice(self, start: int, end: int) -> str:
|
||||
"""Slice buffer from start to end (relative to current position)"""
|
||||
return self._buffer[self._start_index + start : self._start_index + end]
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Commit consumed content, removing it from buffer"""
|
||||
if self._start_index > 0:
|
||||
self._buffer = self._buffer[self._start_index :]
|
||||
self._start_index = 0
|
||||
|
||||
def remaining(self) -> str:
|
||||
"""Get all remaining content in buffer"""
|
||||
return self._buffer[self._start_index :]
|
||||
|
||||
def expect_end_of_content(self) -> None:
|
||||
"""Verify no non-whitespace content remains"""
|
||||
self.commit()
|
||||
self.skip_past_whitespace()
|
||||
if self.length != 0:
|
||||
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
|
||||
|
||||
def skip_past_whitespace(self) -> None:
|
||||
"""Skip whitespace characters"""
|
||||
i = self._start_index
|
||||
while i < len(self._buffer):
|
||||
c = ord(self._buffer[i])
|
||||
if c in (32, 9, 10, 13): # space, tab, \n, \r
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
self._start_index = i
|
||||
|
||||
def try_to_take_prefix(self, prefix: str) -> bool:
|
||||
"""Try to consume prefix from buffer, return True if successful"""
|
||||
if self._buffer.startswith(prefix, self._start_index):
|
||||
self._start_index += len(prefix)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_to_take(self, length: int) -> str | None:
|
||||
"""Try to take length characters, or None if not enough available"""
|
||||
if self.length < length:
|
||||
return None
|
||||
result = self._buffer[self._start_index : self._start_index + length]
|
||||
self._start_index += length
|
||||
return result
|
||||
|
||||
def try_to_take_char_code(self) -> int | None:
|
||||
"""Try to take a single character as char code, or None if buffer empty"""
|
||||
if self.length == 0:
|
||||
return None
|
||||
code = ord(self._buffer[self._start_index])
|
||||
self._start_index += 1
|
||||
return code
|
||||
|
||||
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
|
||||
"""
|
||||
Consume input up to first quote or backslash
|
||||
|
||||
Returns tuple of (consumed_content, pattern_found)
|
||||
"""
|
||||
buf = self._buffer
|
||||
i = self._start_index
|
||||
while i < len(buf):
|
||||
c = ord(buf[i])
|
||||
if c <= 0x1F:
|
||||
raise ValueError("Unescaped control character in string")
|
||||
if c == 34 or c == 92: # " or \
|
||||
result = buf[self._start_index : i]
|
||||
self._start_index = i
|
||||
return (result, True)
|
||||
i += 1
|
||||
|
||||
result = buf[self._start_index :]
|
||||
self._start_index = len(buf)
|
||||
return (result, False)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer for chunk-based JSON parsing
|
||||
|
||||
Processes chunks fed into its input buffer and calls handler methods
|
||||
as JSON tokens are recognized.
|
||||
"""
|
||||
|
||||
def __init__(self, input: _Input, handler: TokenHandler) -> None:
|
||||
self.input = input
|
||||
self._handler = handler
|
||||
self._stack: list[_State] = [_State.ExpectingValue]
|
||||
self._emitted_tokens = 0
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if tokenization is complete"""
|
||||
return len(self._stack) == 0 and self.input.length == 0
|
||||
|
||||
def pump(self) -> None:
|
||||
"""Process all available tokens in the buffer"""
|
||||
while True:
|
||||
before = self._emitted_tokens
|
||||
self._tokenize_more()
|
||||
if self._emitted_tokens == before:
|
||||
self.input.commit()
|
||||
return
|
||||
|
||||
def _tokenize_more(self) -> None:
|
||||
"""Process one step of tokenization based on current state"""
|
||||
if not self._stack:
|
||||
return
|
||||
|
||||
state = self._stack[-1]
|
||||
|
||||
if state == _State.ExpectingValue:
|
||||
self._tokenize_value()
|
||||
elif state == _State.InString:
|
||||
self._tokenize_string()
|
||||
elif state == _State.StartArray:
|
||||
self._tokenize_array_start()
|
||||
elif state == _State.AfterArrayValue:
|
||||
self._tokenize_after_array_value()
|
||||
elif state == _State.StartObject:
|
||||
self._tokenize_object_start()
|
||||
elif state == _State.AfterObjectKey:
|
||||
self._tokenize_after_object_key()
|
||||
elif state == _State.AfterObjectValue:
|
||||
self._tokenize_after_object_value()
|
||||
elif state == _State.BeforeObjectKey:
|
||||
self._tokenize_before_object_key()
|
||||
|
||||
def _tokenize_value(self) -> None:
|
||||
"""Tokenize a JSON value"""
|
||||
self.input.skip_past_whitespace()
|
||||
|
||||
if self.input.try_to_take_prefix("null"):
|
||||
self._handler.handle_null()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("true"):
|
||||
self._handler.handle_boolean(True)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("false"):
|
||||
self._handler.handle_boolean(False)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.length > 0:
|
||||
ch = self.input.peek_char_code(0)
|
||||
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
|
||||
# Scan for end of number
|
||||
i = 0
|
||||
while i < self.input.length:
|
||||
c = self.input.peek_char_code(i)
|
||||
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if i == self.input.length and not self.input.buffer_complete:
|
||||
# Need more input (numbers have no terminator)
|
||||
return
|
||||
|
||||
number_chars = self.input.slice(0, i)
|
||||
self.input.advance(i)
|
||||
number = _parse_json_number(number_chars)
|
||||
self._handler.handle_number(number)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix('"'):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("["):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartArray)
|
||||
self._handler.handle_array_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_array_start()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("{"):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartObject)
|
||||
self._handler.handle_object_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_object_start()
|
||||
return
|
||||
|
||||
def _tokenize_string(self) -> None:
|
||||
"""Tokenize string content"""
|
||||
while True:
|
||||
chunk, interrupted = self.input.take_until_quote_or_backslash()
|
||||
if chunk:
|
||||
self._handler.handle_string_middle(chunk)
|
||||
self._emitted_tokens += 1
|
||||
elif not interrupted:
|
||||
return
|
||||
|
||||
if interrupted:
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
next_char = self.input.peek(0)
|
||||
if next_char == '"':
|
||||
self.input.advance(1)
|
||||
self._handler.handle_string_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
# Handle escape sequences
|
||||
next_char2 = self.input.peek(1)
|
||||
if next_char2 is None:
|
||||
return
|
||||
|
||||
value: str
|
||||
if next_char2 == "u":
|
||||
# Unicode escape: need 4 hex digits
|
||||
if self.input.length < 6:
|
||||
return
|
||||
|
||||
code = 0
|
||||
for j in range(2, 6):
|
||||
c = self.input.peek_char_code(j)
|
||||
if 48 <= c <= 57: # 0-9
|
||||
digit = c - 48
|
||||
elif 65 <= c <= 70: # A-F
|
||||
digit = c - 55
|
||||
elif 97 <= c <= 102: # a-f
|
||||
digit = c - 87
|
||||
else:
|
||||
raise ValueError("Bad Unicode escape in JSON")
|
||||
code = (code << 4) | digit
|
||||
|
||||
self.input.advance(6)
|
||||
self._handler.handle_string_middle(chr(code))
|
||||
self._emitted_tokens += 1
|
||||
continue
|
||||
|
||||
elif next_char2 == "n":
|
||||
value = "\n"
|
||||
elif next_char2 == "r":
|
||||
value = "\r"
|
||||
elif next_char2 == "t":
|
||||
value = "\t"
|
||||
elif next_char2 == "b":
|
||||
value = "\b"
|
||||
elif next_char2 == "f":
|
||||
value = "\f"
|
||||
elif next_char2 == "\\":
|
||||
value = "\\"
|
||||
elif next_char2 == "/":
|
||||
value = "/"
|
||||
elif next_char2 == '"':
|
||||
value = '"'
|
||||
else:
|
||||
raise ValueError("Bad escape in string")
|
||||
|
||||
self.input.advance(2)
|
||||
self._handler.handle_string_middle(value)
|
||||
self._emitted_tokens += 1
|
||||
|
||||
def _tokenize_array_start(self) -> None:
|
||||
"""Tokenize start of array (check for empty or first element)"""
|
||||
self.input.skip_past_whitespace()
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("]"):
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterArrayValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
|
||||
def _tokenize_after_array_value(self) -> None:
|
||||
"""Tokenize after an array value (expect , or ])"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x5D: # ]
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_object_start(self) -> None:
|
||||
"""Tokenize start of object (check for empty or first key)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_key(self) -> None:
|
||||
"""Tokenize after object key (expect :)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x3A: # :
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_value(self) -> None:
|
||||
"""Tokenize after object value (expect , or })"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.BeforeObjectKey)
|
||||
self._tokenize_before_object_key()
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected , or }} after object value, got {chr(next_char)!r}"
|
||||
)
|
||||
|
||||
def _tokenize_before_object_key(self) -> None:
|
||||
"""Tokenize before object key (after comma)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
@@ -24,6 +24,9 @@ class OnyxVersion:
|
||||
def set_ee(self) -> None:
|
||||
self._is_ee = True
|
||||
|
||||
def unset_ee(self) -> None:
|
||||
self._is_ee = False
|
||||
|
||||
def is_ee_version(self) -> bool:
|
||||
return self._is_ee
|
||||
|
||||
|
||||
@@ -229,7 +229,9 @@ distro==1.9.0
|
||||
dnspython==2.8.0
|
||||
# via email-validator
|
||||
docstring-parser==0.17.0
|
||||
# via cyclopts
|
||||
# via
|
||||
# cyclopts
|
||||
# google-cloud-aiplatform
|
||||
docutils==0.22.3
|
||||
# via rich-rst
|
||||
dropbox==12.0.2
|
||||
@@ -294,7 +296,13 @@ gitdb==4.0.12
|
||||
gitpython==3.1.45
|
||||
# via braintrust
|
||||
google-api-core==2.28.1
|
||||
# via google-api-python-client
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-api-python-client==2.86.0
|
||||
# via onyx
|
||||
google-auth==2.48.0
|
||||
@@ -303,6 +311,11 @@ google-auth==2.48.0
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
# google-auth-oauthlib
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-auth-httplib2==0.1.0
|
||||
@@ -311,16 +324,51 @@ google-auth-httplib2==0.1.0
|
||||
# onyx
|
||||
google-auth-oauthlib==1.0.0
|
||||
# via onyx
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
greenlet==3.2.4
|
||||
# via
|
||||
# playwright
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -621,6 +669,8 @@ packaging==24.2
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# kombu
|
||||
@@ -670,12 +720,19 @@ propcache==0.4.1
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via google-api-core
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
@@ -713,6 +770,7 @@ pydantic==2.11.7
|
||||
# exa-py
|
||||
# fastapi
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# langchain-core
|
||||
# langfuse
|
||||
@@ -776,6 +834,7 @@ python-dateutil==2.8.2
|
||||
# botocore
|
||||
# celery
|
||||
# dateparser
|
||||
# google-cloud-bigquery
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
@@ -867,6 +926,8 @@ requests==2.32.5
|
||||
# dropbox
|
||||
# exa-py
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# hubspot-api-client
|
||||
# huggingface-hub
|
||||
@@ -1054,7 +1115,9 @@ typing-extensions==4.15.0
|
||||
# exa-py
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
|
||||
@@ -115,6 +115,8 @@ distlib==0.4.0
|
||||
# via virtualenv
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
execnet==2.1.2
|
||||
@@ -143,14 +145,65 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -263,7 +316,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.3
|
||||
onyx-devtools==0.7.0
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -276,6 +329,8 @@ openapi-generator-cli==7.17.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# hatchling
|
||||
# huggingface-hub
|
||||
# ipykernel
|
||||
@@ -318,6 +373,20 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via ipykernel
|
||||
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
@@ -339,6 +408,7 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -379,6 +449,7 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# jupyter-client
|
||||
# kubernetes
|
||||
# matplotlib
|
||||
@@ -406,13 +477,16 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.4.3
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -525,7 +599,9 @@ typing-extensions==4.15.0
|
||||
# celery-types
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mcp
|
||||
|
||||
@@ -86,6 +86,8 @@ discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.133.1
|
||||
@@ -102,12 +104,63 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -178,7 +231,10 @@ openai==2.14.0
|
||||
# litellm
|
||||
# onyx
|
||||
packaging==24.2
|
||||
# via huggingface-hub
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
posthog==3.7.4
|
||||
@@ -193,6 +249,20 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.2
|
||||
@@ -208,6 +278,7 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -224,6 +295,7 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
# posthog
|
||||
python-dotenv==1.1.1
|
||||
@@ -247,6 +319,9 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -306,7 +381,9 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
@@ -102,6 +102,8 @@ discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
@@ -127,12 +129,63 @@ fsspec==2025.10.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -262,6 +315,8 @@ openai==2.14.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# kombu
|
||||
# transformers
|
||||
@@ -281,6 +336,20 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via accelerate
|
||||
py==1.11.0
|
||||
@@ -298,6 +367,7 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -315,6 +385,7 @@ python-dateutil==2.8.2
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# celery
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
@@ -341,6 +412,9 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
@@ -433,7 +507,9 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
86
backend/tests/README.md
Normal file
86
backend/tests/README.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Backend Tests
|
||||
|
||||
## Test Types
|
||||
|
||||
There are four test categories, ordered by increasing scope:
|
||||
|
||||
### Unit Tests (`tests/unit/`)
|
||||
|
||||
No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated
|
||||
logic (e.g. citation processing, encryption).
|
||||
|
||||
```bash
|
||||
pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests (`tests/external_dependency_unit/`)
|
||||
|
||||
External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx
|
||||
application containers are not. Tests call functions directly and can mock selectively.
|
||||
|
||||
Use when you need a real database or real API calls but want control over setup.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests (`tests/integration/`)
|
||||
|
||||
Full Onyx deployment running. No mocking. Prefer this over other test types when possible.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright / E2E Tests (`web/tests/e2e/`)
|
||||
|
||||
Full stack including web server. Use for frontend-backend coordination.
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
## Shared Fixtures
|
||||
|
||||
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
|
||||
their own `conftest.py` for directory-scoped fixtures.
|
||||
|
||||
## Running Tests Repeatedly (`pytest-repeat`)
|
||||
|
||||
Use `pytest-repeat` to catch flaky tests by running them multiple times:
|
||||
|
||||
```bash
|
||||
# Run a specific test 50 times
|
||||
pytest --count=50 backend/tests/unit/path/to/test.py::test_name
|
||||
|
||||
# Stop on first failure with -x
|
||||
pytest --count=50 -x backend/tests/unit/path/to/test.py::test_name
|
||||
|
||||
# Repeat an entire test file
|
||||
pytest --count=10 backend/tests/unit/path/to/test_file.py
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use `enable_ee` fixture instead of inlining
|
||||
|
||||
Enables EE mode for a test, with proper teardown and cache clearing.
|
||||
|
||||
```python
|
||||
# Whole file (in a test module, NOT in conftest.py)
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Whole directory — add an autouse wrapper to the directory's conftest.py
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
|
||||
# Single test
|
||||
def test_something(enable_ee: None) -> None: ...
|
||||
```
|
||||
|
||||
**Note:** `pytestmark` in a `conftest.py` does NOT apply markers to tests in that
|
||||
directory — it only affects tests defined in the conftest itself (which is none).
|
||||
Use the autouse fixture wrapper pattern shown above instead.
|
||||
|
||||
Do NOT inline `global_version.set_ee()` — always use the fixture.
|
||||
24
backend/tests/conftest.py
Normal file
24
backend/tests/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Root conftest — shared fixtures available to all test directories."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def enable_ee() -> Generator[None, None, None]:
|
||||
"""Temporarily enable EE mode for a single test.
|
||||
|
||||
Restores the previous EE state and clears the versioned-implementation
|
||||
cache on teardown so state doesn't leak between tests.
|
||||
"""
|
||||
was_ee = global_version.is_ee_version()
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
if not was_ee:
|
||||
global_version.unset_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
@@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
def test_confluence_connector_permissions(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
confluence_connector: ConfluenceConnector,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
@@ -93,7 +93,7 @@ def test_confluence_connector_permissions(
|
||||
def test_confluence_connector_restriction_handling(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
mock_db_provider_class: MagicMock,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Test space key
|
||||
test_space_key = "DailyPermS"
|
||||
|
||||
@@ -4,8 +4,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
@@ -14,14 +12,3 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
return_value=None,
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
|
||||
@@ -98,7 +98,7 @@ def _build_connector(
|
||||
|
||||
def test_gdrive_perm_sync_with_real_data(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive.
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
|
||||
@@ -19,16 +17,7 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.teams.models import TeamsThread
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
@@ -168,18 +166,9 @@ def test_slim_docs_retrieval_from_teams_connector(
|
||||
_assert_is_valid_external_access(external_access=slim_doc.external_access)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for perm sync tests to work since
|
||||
perm syncing is an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
yield
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_perm_sync(
|
||||
teams_connector: TeamsConnector,
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
enable_ee: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Test that load_from_checkpoint_with_perm_sync returns documents with external_access.
|
||||
|
||||
|
||||
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
External dependency unit tests for user file delete queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_for_user_file_delete work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in DELETING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that delete_user_file_impl clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_delete_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_delete_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_delete,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_delete,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_deleting_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in DELETING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.DELETING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "del_bp_user")
|
||||
_create_deleting_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=USER_FILE_DELETE_MAX_QUEUE_DEPTH + 1),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestDeletePerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "del_guard_user")
|
||||
uf = _create_deleting_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "del_guard_set_user")
|
||||
uf = _create_deleting_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert (
|
||||
0 < ttl <= CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
), f"Guard key TTL {ttl}s is outside the expected range (0, {CELERY_USER_FILE_DELETE_TASK_EXPIRES}]"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestDeleteTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "del_expires_user")
|
||||
uf = _create_deleting_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.DELETE_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_DELETE
|
||||
assert (
|
||||
call.kwargs.get("expires") == CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
), "Task must be submitted with the correct expires value to prevent stale task accumulation"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestDeleteWorkerClearsGuardKey:
|
||||
"""process_single_user_file_delete removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so delete_user_file_impl returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file delete lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_delete_lock_key(user_file_id)
|
||||
delete_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = delete_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the delete lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file_delete.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if delete_lock.owned():
|
||||
delete_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -14,13 +15,14 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
# In order to get these tests to run, use the credentials from Bitwarden.
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
class DocExternalAccessSet(BaseModel):
|
||||
"""A version of DocExternalAccess that uses sets for comparison."""
|
||||
@@ -52,9 +54,6 @@ def test_jira_doc_sync(
|
||||
This test uses the AS project which has applicationRole permission,
|
||||
meaning all documents should be marked as public.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use AS project specifically for this test
|
||||
connector_config = {
|
||||
@@ -150,9 +149,6 @@ def test_jira_doc_sync_with_specific_permissions(
|
||||
This test uses a project that has specific user permissions to verify
|
||||
that specific users are correctly extracted.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use SUP project which has specific user permissions
|
||||
connector_config = {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
|
||||
@@ -18,6 +19,8 @@ from tests.daily.connectors.confluence.models import ExternalUserGroupSet
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Expected groups from the danswerai.atlassian.net Jira instance
|
||||
# Note: These groups are shared with Confluence since they're both Atlassian products
|
||||
# App accounts (bots, integrations) are filtered out
|
||||
|
||||
@@ -158,7 +158,7 @@ class TestLLMConfigurationEndpoint:
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
assert exc_info.value.detail == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -540,7 +540,7 @@ class TestDefaultProviderEndpoint:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "No LLM Provider setup" in exc_info.value.message
|
||||
assert "No LLM Provider setup" in exc_info.value.detail
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -585,7 +585,7 @@ class TestDefaultProviderEndpoint:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
assert exc_info.value.detail == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -111,7 +111,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -247,7 +247,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -350,7 +350,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -386,7 +386,7 @@ class TestLLMProviderChanges:
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -1152,3 +1152,179 @@ class TestAutoModeTransitionsAndResync:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_updates_default_when_recommended_default_changes(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and a sync arrives with a
|
||||
different recommended default model (both models still in config),
|
||||
the global default should be updated to the new recommendation.
|
||||
|
||||
Steps:
|
||||
1. Create auto-mode provider with config v1: default=gpt-4o.
|
||||
2. Set gpt-4o as the global CHAT default.
|
||||
3. Re-sync with config v2: default=gpt-4o-mini (gpt-4o still present).
|
||||
4. Verify the CHAT default switched to gpt-4o-mini and both models
|
||||
remain visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=["gpt-4o"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o as the global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Re-sync with config v2 (recommended default changed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0, "Sync should report changes when default switches"
|
||||
|
||||
# Both models should remain visible
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is True
|
||||
assert visibility["gpt-4o-mini"] is True
|
||||
|
||||
# The CHAT default should now be gpt-4o-mini
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert default_after is not None
|
||||
assert (
|
||||
default_after.name == "gpt-4o-mini"
|
||||
), f"Default should be updated to 'gpt-4o-mini', got '{default_after.name}'"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_idempotent_when_default_already_matches(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and it already matches the
|
||||
recommended default, re-syncing should report zero changes.
|
||||
|
||||
This is a regression test for the bug where changes was unconditionally
|
||||
incremented even when the default was already correct.
|
||||
"""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o (the recommended default) as global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# First sync to stabilize state
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Second sync — default already matches, should be a no-op
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert changes == 0, (
|
||||
f"Expected 0 changes when default already matches recommended, "
|
||||
f"got {changes}"
|
||||
)
|
||||
|
||||
# Default should still be gpt-4o
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == "gpt-4o"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
This should act as the main point of reference for testing that default model
|
||||
logic is consisten.
|
||||
|
||||
-
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
|
||||
|
||||
def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
models: list[ModelConfigurationUpsertRequest] | None = None,
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider with multiple models."""
|
||||
if models is None:
|
||||
models = [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True, supports_image_input=False
|
||||
),
|
||||
]
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=models,
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_provider(db_session: Session, name: str) -> None:
|
||||
"""Helper to clean up a test provider by name."""
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if provider:
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_name(db_session: Session) -> Generator[str, None, None]:
|
||||
"""Generate a unique provider name for each test, with automatic cleanup."""
|
||||
name = f"test-provider-{uuid4().hex[:8]}"
|
||||
yield name
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, name)
|
||||
|
||||
|
||||
class TestDefaultModelProtection:
|
||||
"""Tests that the default model cannot be removed or hidden."""
|
||||
|
||||
def test_cannot_remove_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default text model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to update the provider without the default model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_hide_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Setting is_visible=False on the default text model should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to hide the default model
|
||||
with pytest.raises(ValueError, match="Cannot hide the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=False
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_remove_default_vision_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default vision model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
# Set gpt-4o as both the text and vision default
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
update_default_vision_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to remove the default vision model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_can_remove_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Remove gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_names = {mc.name for mc in updated.model_configurations}
|
||||
assert "gpt-4o" in model_names
|
||||
assert "gpt-4o-mini" not in model_names
|
||||
|
||||
def test_can_hide_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Hiding a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Hide gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=False
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in updated.model_configurations
|
||||
}
|
||||
assert model_visibility["gpt-4o"] is True
|
||||
assert model_visibility["gpt-4o-mini"] is False
|
||||
@@ -1218,15 +1218,16 @@ def test_code_interpreter_receives_chat_files(
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
|
||||
# Verify: file uploaded, code executed via streaming, staged file cleaned up
|
||||
# Verify: file uploaded and code executed via streaming.
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
|
||||
delete_requests = mock_ci_server.get_requests(method="DELETE")
|
||||
assert len(delete_requests) == 1
|
||||
assert delete_requests[0].path.startswith("/v1/files/")
|
||||
# Staged input files are intentionally NOT deleted — PythonTool caches their
|
||||
# file IDs across agent-loop iterations to avoid re-uploading on every call.
|
||||
# The code interpreter cleans them up via its own TTL.
|
||||
assert len(mock_ci_server.get_requests(method="DELETE")) == 0
|
||||
|
||||
execute_body = mock_ci_server.get_requests(
|
||||
method="POST", path="/v1/execute/stream"
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Integration tests for the "Last Indexed" time displayed on both the
|
||||
per-connector detail page and the all-connectors listing page.
|
||||
|
||||
Expected behavior: "Last Indexed" = time_started of the most recent
|
||||
successful index attempt for the cc pair, regardless of pagination.
|
||||
|
||||
Edge cases:
|
||||
1. First page of index attempts is entirely errors — last_indexed should
|
||||
still reflect the older successful attempt beyond page 1.
|
||||
2. Credential swap — successful attempts, then failures after a
|
||||
"credential change"; last_indexed should reflect the most recent
|
||||
successful attempt.
|
||||
3. Mix of statuses — only the most recent successful attempt matters.
|
||||
4. COMPLETED_WITH_ERRORS counts as a success for last_indexed purposes.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.server.documents.models import CCPairFullInfo
|
||||
from onyx.server.documents.models import ConnectorIndexingStatusLite
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _wait_for_real_success(
|
||||
cc_pair: DATestCCPair,
|
||||
admin: DATestUser,
|
||||
) -> None:
|
||||
"""Wait for the initial index attempt to complete successfully."""
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair,
|
||||
after=datetime(2000, 1, 1, tzinfo=timezone.utc),
|
||||
user_performing_action=admin,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
|
||||
def _get_detail(cc_pair_id: int, admin: DATestUser) -> CCPairFullInfo:
|
||||
result = CCPairManager.get_single(cc_pair_id, admin)
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
|
||||
def _get_listing(cc_pair_id: int, admin: DATestUser) -> ConnectorIndexingStatusLite:
|
||||
result = CCPairManager.get_indexing_status_by_id(cc_pair_id, admin)
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
|
||||
def test_last_indexed_first_page_all_errors(reset: None) -> None: # noqa: ARG001
|
||||
"""When the first page of index attempts is entirely errors but an
|
||||
older successful attempt exists, both the detail page and the listing
|
||||
page should still show the time of that successful attempt.
|
||||
|
||||
The detail page UI uses page size 8. We insert 10 failed attempts
|
||||
more recent than the initial success to push the success off page 1.
|
||||
"""
|
||||
admin = UserManager.create(name="admin_first_page_errors")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
|
||||
_wait_for_real_success(cc_pair, admin)
|
||||
|
||||
# Baseline: last_success should be set from the initial successful run
|
||||
listing_before = _get_listing(cc_pair.id, admin)
|
||||
assert listing_before.last_success is not None
|
||||
|
||||
# 10 recent failures push the success off page 1
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=10,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="simulated failure",
|
||||
base_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair.id, admin)
|
||||
listing = _get_listing(cc_pair.id, admin)
|
||||
|
||||
assert (
|
||||
detail.last_indexed is not None
|
||||
), "Detail page last_indexed is None even though a successful attempt exists"
|
||||
assert (
|
||||
listing.last_success is not None
|
||||
), "Listing page last_success is None even though a successful attempt exists"
|
||||
|
||||
# Both surfaces must agree
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
|
||||
|
||||
def test_last_indexed_credential_swap_scenario(reset: None) -> None: # noqa: ARG001
|
||||
"""Perform an actual credential swap: create connector + cred1 (cc_pair_1),
|
||||
wait for success, then associate a new cred2 with the same connector
|
||||
(cc_pair_2), wait for that to succeed, and inject failures on cc_pair_2.
|
||||
|
||||
cc_pair_2's last_indexed must reflect cc_pair_2's own success, not
|
||||
cc_pair_1's older one. Both the detail page and listing page must agree.
|
||||
"""
|
||||
admin = UserManager.create(name="admin_cred_swap")
|
||||
|
||||
connector = ConnectorManager.create(user_performing_action=admin)
|
||||
cred1 = CredentialManager.create(user_performing_action=admin)
|
||||
cc_pair_1 = CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=cred1.id,
|
||||
user_performing_action=admin,
|
||||
)
|
||||
_wait_for_real_success(cc_pair_1, admin)
|
||||
|
||||
cred2 = CredentialManager.create(user_performing_action=admin, name="swapped-cred")
|
||||
cc_pair_2 = CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=cred2.id,
|
||||
user_performing_action=admin,
|
||||
)
|
||||
_wait_for_real_success(cc_pair_2, admin)
|
||||
|
||||
listing_after_swap = _get_listing(cc_pair_2.id, admin)
|
||||
assert listing_after_swap.last_success is not None
|
||||
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=10,
|
||||
cc_pair_id=cc_pair_2.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="credential expired",
|
||||
base_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair_2.id, admin)
|
||||
listing = _get_listing(cc_pair_2.id, admin)
|
||||
|
||||
assert detail.last_indexed is not None
|
||||
assert listing.last_success is not None
|
||||
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
|
||||
|
||||
def test_last_indexed_mixed_statuses(reset: None) -> None: # noqa: ARG001
|
||||
"""Mix of in_progress, failed, and successful attempts. Only the most
|
||||
recent successful attempt's time matters."""
|
||||
admin = UserManager.create(name="admin_mixed")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
|
||||
_wait_for_real_success(cc_pair, admin)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Success 5 hours ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=1,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.SUCCESS,
|
||||
base_time=now - timedelta(hours=5),
|
||||
)
|
||||
|
||||
# Failures 3 hours ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=3,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="transient failure",
|
||||
base_time=now - timedelta(hours=3),
|
||||
)
|
||||
|
||||
# In-progress 1 hour ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=1,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.IN_PROGRESS,
|
||||
base_time=now - timedelta(hours=1),
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair.id, admin)
|
||||
listing = _get_listing(cc_pair.id, admin)
|
||||
|
||||
assert detail.last_indexed is not None
|
||||
assert listing.last_success is not None
|
||||
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
|
||||
|
||||
def test_last_indexed_completed_with_errors(reset: None) -> None: # noqa: ARG001
|
||||
"""COMPLETED_WITH_ERRORS is treated as a successful attempt (matching
|
||||
IndexingStatus.is_successful()). When it is the most recent "success"
|
||||
and later attempts all failed, both surfaces should reflect its time."""
|
||||
admin = UserManager.create(name="admin_completed_errors")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
|
||||
_wait_for_real_success(cc_pair, admin)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# COMPLETED_WITH_ERRORS 2 hours ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=1,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.COMPLETED_WITH_ERRORS,
|
||||
base_time=now - timedelta(hours=2),
|
||||
)
|
||||
|
||||
# 10 failures after — push everything else off page 1
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=10,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="post-partial failure",
|
||||
base_time=now,
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair.id, admin)
|
||||
listing = _get_listing(cc_pair.id, admin)
|
||||
|
||||
assert (
|
||||
detail.last_indexed is not None
|
||||
), "COMPLETED_WITH_ERRORS should count as a success for last_indexed"
|
||||
assert (
|
||||
listing.last_success is not None
|
||||
), "COMPLETED_WITH_ERRORS should count as a success for last_success"
|
||||
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
@@ -427,7 +427,7 @@ def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["message"]
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]
|
||||
|
||||
# Verify provider still exists
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
@@ -674,7 +674,7 @@ def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
json=base_payload,
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["message"]
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
@@ -711,7 +711,7 @@ def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "not currently supported" in response.json()["message"]
|
||||
assert "not currently supported" in response.json()["detail"]
|
||||
|
||||
# Verify no duplicate was created — only the original provider should exist
|
||||
provider = _get_provider_by_id(admin_user, provider_id)
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_unauthorized_persona_access_returns_403(
|
||||
|
||||
# Should return 403 Forbidden
|
||||
assert response.status_code == 403
|
||||
assert "don't have access to this assistant" in response.json()["message"]
|
||||
assert "don't have access to this assistant" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_authorized_persona_access_returns_filtered_providers(
|
||||
@@ -245,4 +245,4 @@ def test_nonexistent_persona_returns_404(
|
||||
|
||||
# Should return 404
|
||||
assert response.status_code == 404
|
||||
assert "Persona not found" in response.json()["message"]
|
||||
assert "Persona not found" in response.json()["detail"]
|
||||
|
||||
8
backend/tests/unit/ee/conftest.py
Normal file
8
backend/tests/unit/ee/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Auto-enable EE mode for all tests under tests/unit/ee/."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
@@ -107,7 +107,7 @@ class TestCreateCheckoutSession:
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Stripe error"
|
||||
assert exc_info.value.detail == "Stripe error"
|
||||
|
||||
|
||||
class TestCreateCustomerPortalSession:
|
||||
@@ -137,7 +137,7 @@ class TestCreateCustomerPortalSession:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
assert exc_info.value.detail == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.create_portal_service")
|
||||
@@ -243,7 +243,7 @@ class TestUpdateSeats:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
assert exc_info.value.detail == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.get_used_seats")
|
||||
@@ -317,7 +317,7 @@ class TestUpdateSeats:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Cannot reduce below 10 seats"
|
||||
assert exc_info.value.detail == "Cannot reduce below 10 seats"
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@@ -346,7 +346,7 @@ class TestCircuitBreaker:
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.error_code is OnyxErrorCode.SERVICE_UNAVAILABLE
|
||||
assert "Connect to Stripe" in exc_info.value.message
|
||||
assert "Connect to Stripe" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.MULTI_TENANT", False)
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestMakeBillingRequest:
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Bad request" in exc_info.value.message
|
||||
assert "Bad request" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.service._get_headers")
|
||||
@@ -152,7 +152,7 @@ class TestMakeBillingRequest:
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Failed to connect" in exc_info.value.message
|
||||
assert "Failed to connect" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCreateCheckoutSession:
|
||||
|
||||
@@ -72,7 +72,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
assert exc_info.value.detail == "Invalid Stripe publishable key format"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -97,7 +97,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
assert exc_info.value.detail == "Invalid Stripe publishable key format"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -118,7 +118,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Failed to fetch Stripe publishable key"
|
||||
assert exc_info.value.detail == "Failed to fetch Stripe publishable key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -132,7 +132,7 @@ class TestGetStripePublishableKey:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert "not configured" in exc_info.value.message
|
||||
assert "not configured" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
|
||||
@@ -9,6 +9,8 @@ from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jira_cc_pair(
|
||||
|
||||
@@ -0,0 +1,325 @@
|
||||
"""Unit tests for SharepointConnector._fetch_site_pages error handling.
|
||||
|
||||
Covers 404 handling (classic sites / no modern pages) and 400
|
||||
canvasLayout fallback (corrupt pages causing $expand=canvasLayout to
|
||||
fail on the LIST endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_INVALID_REQUEST_CODE
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
SITE_URL = "https://tenant.sharepoint.com/sites/ClassicSite"
|
||||
FAKE_SITE_ID = "tenant.sharepoint.com,abc123,def456"
|
||||
PAGES_COLLECTION = f"https://graph.microsoft.com/v1.0/sites/{FAKE_SITE_ID}/pages"
|
||||
SITE_PAGES_BASE = f"{PAGES_COLLECTION}/microsoft.graph.sitePage"
|
||||
|
||||
|
||||
def _site_descriptor() -> SiteDescriptor:
|
||||
return SiteDescriptor(url=SITE_URL, drive_name=None, folder_path=None)
|
||||
|
||||
|
||||
def _make_http_error(
|
||||
status_code: int,
|
||||
error_code: str = "itemNotFound",
|
||||
message: str = "Item not found",
|
||||
) -> HTTPError:
|
||||
body = {"error": {"code": error_code, "message": message}}
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response._content = json.dumps(body).encode()
|
||||
response.headers["Content-Type"] = "application/json"
|
||||
return HTTPError(response=response)
|
||||
|
||||
|
||||
def _setup_connector(
|
||||
monkeypatch: pytest.MonkeyPatch, # noqa: ARG001
|
||||
) -> SharepointConnector:
|
||||
"""Create a connector with the graph client and site resolution mocked."""
|
||||
connector = SharepointConnector(sites=[SITE_URL])
|
||||
connector.graph_api_base = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
mock_sites = type(
|
||||
"FakeSites",
|
||||
(),
|
||||
{
|
||||
"get_by_url": staticmethod(
|
||||
lambda url: type( # noqa: ARG005
|
||||
"Q",
|
||||
(),
|
||||
{
|
||||
"execute_query": lambda self: None, # noqa: ARG005
|
||||
"id": FAKE_SITE_ID,
|
||||
},
|
||||
)()
|
||||
),
|
||||
},
|
||||
)()
|
||||
connector._graph_client = type("FakeGraphClient", (), {"sites": mock_sites})()
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
def _patch_graph_api_get_json(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
fake_fn: Any,
|
||||
) -> None:
|
||||
monkeypatch.setattr(SharepointConnector, "_graph_api_get_json", fake_fn)
|
||||
|
||||
|
||||
class TestFetchSitePages404:
|
||||
def test_404_yields_no_pages(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A 404 from the Pages API should result in zero yielded pages."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert pages == []
|
||||
|
||||
def test_404_does_not_raise(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A 404 must not propagate as an exception."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
for _ in connector._fetch_site_pages(_site_descriptor()):
|
||||
pass
|
||||
|
||||
def test_non_404_http_error_still_raises(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Non-404 HTTP errors (e.g. 403) must still propagate."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(403)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
def test_successful_fetch_yields_pages(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the API succeeds, pages should be yielded normally."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
fake_page = {
|
||||
"id": "page-1",
|
||||
"title": "Hello World",
|
||||
"webUrl": f"{SITE_URL}/SitePages/Hello.aspx",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
return {"value": [fake_page]}
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
|
||||
def test_404_on_second_page_stops_pagination(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the first API page succeeds but a nextLink returns 404,
|
||||
already-yielded pages are kept and iteration stops cleanly."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
first_page = {
|
||||
"id": "page-1",
|
||||
"title": "First",
|
||||
"webUrl": f"{SITE_URL}/SitePages/First.aspx",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return {
|
||||
"value": [first_page],
|
||||
"@odata.nextLink": "https://graph.microsoft.com/next",
|
||||
}
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
|
||||
|
||||
class TestFetchSitePages400Fallback:
|
||||
"""When $expand=canvasLayout on the LIST endpoint returns 400
|
||||
invalidRequest, _fetch_site_pages should fall back to listing
|
||||
without expansion, then expanding each page individually."""
|
||||
|
||||
GOOD_PAGE: dict[str, Any] = {
|
||||
"id": "good-1",
|
||||
"name": "Good.aspx",
|
||||
"title": "Good Page",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
BAD_PAGE: dict[str, Any] = {
|
||||
"id": "bad-1",
|
||||
"name": "Bad.aspx",
|
||||
"title": "Bad Page",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
GOOD_PAGE_EXPANDED: dict[str, Any] = {
|
||||
**GOOD_PAGE,
|
||||
"canvasLayout": {"horizontalSections": []},
|
||||
}
|
||||
|
||||
def test_fallback_expands_good_pages_individually(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""On 400 from the LIST expand, the connector should list without
|
||||
expand, then GET each page individually with $expand=canvasLayout."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
good_page = self.GOOD_PAGE
|
||||
bad_page = self.BAD_PAGE
|
||||
good_page_expanded = self.GOOD_PAGE_EXPANDED
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str,
|
||||
params: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
if url == SITE_PAGES_BASE and params is None:
|
||||
return {"value": [good_page, bad_page]}
|
||||
expand_params = {"$expand": "canvasLayout"}
|
||||
if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
return good_page_expanded
|
||||
if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
raise AssertionError(f"Unexpected call: {url} {params}")
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
assert len(pages) == 2
|
||||
assert pages[0].get("canvasLayout") is not None
|
||||
assert pages[1].get("canvasLayout") is None
|
||||
assert pages[1]["id"] == "bad-1"
|
||||
|
||||
def test_mid_pagination_400_does_not_duplicate(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the first paginated batch succeeds but a later nextLink
|
||||
returns 400, pages from the first batch must not be re-yielded
|
||||
by the fallback."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
good_page = self.GOOD_PAGE
|
||||
good_page_expanded = self.GOOD_PAGE_EXPANDED
|
||||
bad_page = self.BAD_PAGE
|
||||
second_page = {
|
||||
"id": "page-2",
|
||||
"name": "Second.aspx",
|
||||
"title": "Second Page",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
next_link = "https://graph.microsoft.com/v1.0/next-page-link"
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str,
|
||||
params: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
|
||||
return {
|
||||
"value": [good_page],
|
||||
"@odata.nextLink": next_link,
|
||||
}
|
||||
if url == next_link:
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
if url == SITE_PAGES_BASE and params is None:
|
||||
return {"value": [good_page, bad_page, second_page]}
|
||||
expand_params = {"$expand": "canvasLayout"}
|
||||
if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
return good_page_expanded
|
||||
if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
if url == f"{PAGES_COLLECTION}/page-2/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
return {**second_page, "canvasLayout": {"horizontalSections": []}}
|
||||
raise AssertionError(f"Unexpected call: {url} {params}")
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
ids = [p["id"] for p in pages]
|
||||
assert ids == ["good-1", "bad-1", "page-2"]
|
||||
|
||||
def test_non_invalid_request_400_still_raises(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A 400 with a different error code (not invalidRequest) should
|
||||
propagate, not trigger the fallback."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(400, "badRequest", "Something else went wrong")
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
list(connector._fetch_site_pages(_site_descriptor()))
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
|
||||
|
||||
class TestSyncModelConfigurations:
|
||||
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4",
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Existing - should be skipped
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o", # New - should be inserted
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Existing - should be skipped
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o", # New - should be inserted
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Already exists
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Already exists
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
|
||||
sync_model_configurations(
|
||||
db_session=mock_session,
|
||||
provider_name="nonexistent",
|
||||
models=[{"name": "model", "display_name": "Model"}],
|
||||
models=[SyncModelEntry(name="model", display_name="Model")],
|
||||
)
|
||||
|
||||
def test_handles_missing_optional_fields(self) -> None:
|
||||
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
|
||||
with patch(
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
# Model with only required fields
|
||||
# Model with only required fields (max_input_tokens and supports_image_input default)
|
||||
models = [
|
||||
{
|
||||
"name": "model-1",
|
||||
# No display_name, max_input_tokens, or supports_image_input
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="model-1",
|
||||
display_name="Model 1",
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Unit test verifying that the upload API path sends tasks with expires=.
|
||||
|
||||
The upload_files_to_user_files_with_indexing function must include expires=
|
||||
on every send_task call to prevent phantom task accumulation if the worker
|
||||
is down or slow.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
|
||||
|
||||
def _make_mock_user_file() -> MagicMock:
|
||||
uf = MagicMock(spec=UserFile)
|
||||
uf.id = str(uuid4())
|
||||
return uf
|
||||
|
||||
|
||||
@patch("onyx.db.projects.get_current_tenant_id", return_value="test_tenant")
|
||||
@patch("onyx.db.projects.create_user_files")
|
||||
@patch(
|
||||
"onyx.background.celery.versioned_apps.client.app",
|
||||
new_callable=MagicMock,
|
||||
)
|
||||
def test_send_task_includes_expires(
|
||||
mock_client_app: MagicMock,
|
||||
mock_create: MagicMock,
|
||||
mock_tenant: MagicMock, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Every send_task call from the upload path must include expires=."""
|
||||
user_files = [_make_mock_user_file(), _make_mock_user_file()]
|
||||
mock_create.return_value = MagicMock(
|
||||
user_files=user_files,
|
||||
rejected_files=[],
|
||||
id_to_temp_id={},
|
||||
)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_db_session = MagicMock()
|
||||
|
||||
upload_files_to_user_files_with_indexing(
|
||||
files=[],
|
||||
project_id=None,
|
||||
user=mock_user,
|
||||
temp_id_map=None,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert mock_client_app.send_task.call_count == len(user_files)
|
||||
|
||||
for call in mock_client_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires") == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), "send_task must include expires= to prevent phantom task accumulation"
|
||||
@@ -15,12 +15,12 @@ class TestOnyxError:
|
||||
def test_basic_construction(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
assert err.error_code is OnyxErrorCode.NOT_FOUND
|
||||
assert err.message == "Session not found"
|
||||
assert err.detail == "Session not found"
|
||||
assert err.status_code == 404
|
||||
|
||||
def test_message_defaults_to_code(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
assert err.message == "UNAUTHENTICATED"
|
||||
assert err.detail == "UNAUTHENTICATED"
|
||||
assert str(err) == "UNAUTHENTICATED"
|
||||
|
||||
def test_status_code_override(self) -> None:
|
||||
@@ -73,18 +73,18 @@ class TestExceptionHandler:
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "NOT_FOUND"
|
||||
assert body["message"] == "Thing not found"
|
||||
assert body["detail"] == "Thing not found"
|
||||
|
||||
def test_status_code_override_in_response(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-override")
|
||||
assert resp.status_code == 503
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "BAD_GATEWAY"
|
||||
assert body["message"] == "upstream 503"
|
||||
assert body["detail"] == "upstream 503"
|
||||
|
||||
def test_default_message(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-default-msg")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "UNAUTHENTICATED"
|
||||
assert body["message"] == "UNAUTHENTICATED"
|
||||
assert body["detail"] == "UNAUTHENTICATED"
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Unit tests for image summarization error handling.
|
||||
|
||||
Verifies that:
|
||||
1. LLM errors produce actionable error messages (not base64 dumps)
|
||||
2. Unsupported MIME type logs include the magic bytes and size
|
||||
3. The ValueError raised on LLM failure preserves the original exception
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing.image_summarization import _summarize_image
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
|
||||
|
||||
class TestSummarizeImageErrorMessage:
|
||||
"""_summarize_image must not dump base64 image data into error messages."""
|
||||
|
||||
def test_error_message_contains_exception_type_not_base64(self) -> None:
|
||||
"""The ValueError should contain the original exception info, not message payloads."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = RuntimeError("Connection timeout")
|
||||
|
||||
# A fake base64-encoded image string (should NOT appear in the error)
|
||||
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."
|
||||
|
||||
with pytest.raises(ValueError, match="RuntimeError: Connection timeout"):
|
||||
_summarize_image(fake_encoded, mock_llm, query="test")
|
||||
|
||||
def test_error_message_does_not_contain_base64(self) -> None:
|
||||
"""Ensure base64 data is never included in the error message."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = RuntimeError("API error")
|
||||
|
||||
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image(fake_encoded, mock_llm)
|
||||
|
||||
error_str = str(exc_info.value)
|
||||
assert "base64" not in error_str
|
||||
assert "iVBOR" not in error_str
|
||||
|
||||
def test_original_exception_is_chained(self) -> None:
|
||||
"""The ValueError should chain the original exception via __cause__."""
|
||||
mock_llm = MagicMock()
|
||||
original = RuntimeError("upstream failure")
|
||||
mock_llm.invoke.side_effect = original
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
assert exc_info.value.__cause__ is original
|
||||
|
||||
|
||||
class TestUnsupportedMimeTypeLogging:
|
||||
"""summarize_image_with_error_handling should log useful info for unsupported formats."""
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.image_summarization.summarize_image_pipeline",
|
||||
side_effect=__import__(
|
||||
"onyx.file_processing.image_summarization",
|
||||
fromlist=["UnsupportedImageFormatError"],
|
||||
).UnsupportedImageFormatError("unsupported"),
|
||||
)
|
||||
def test_logs_magic_bytes_and_size(
|
||||
self, mock_pipeline: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
"""The info log should include magic bytes hex and image size."""
|
||||
mock_llm = MagicMock()
|
||||
# TIFF magic bytes (not in the supported list)
|
||||
image_data = b"\x49\x49\x2a\x00" + b"\x00" * 100
|
||||
|
||||
with patch("onyx.file_processing.image_summarization.logger") as mock_logger:
|
||||
result = summarize_image_with_error_handling(
|
||||
llm=mock_llm,
|
||||
image_data=image_data,
|
||||
context_name="test_image.tiff",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_logger.info.assert_called_once()
|
||||
log_args = mock_logger.info.call_args
|
||||
# Check the format string args contain magic bytes and size
|
||||
assert "49492a00" in str(log_args)
|
||||
assert "104" in str(log_args) # 4 + 100 bytes
|
||||
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Unit tests verifying that LiteLLM error details are extracted and surfaced
|
||||
in image summarization error messages.
|
||||
|
||||
When the LLM call fails, the error handler should include the status_code,
|
||||
llm_provider, and model from LiteLLM exceptions so operators can diagnose
|
||||
the root cause (rate limit, content filter, unsupported vision, etc.)
|
||||
without needing to dig through LiteLLM internals.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing.image_summarization import _summarize_image
|
||||
|
||||
|
||||
def _make_litellm_style_error(
|
||||
*,
|
||||
message: str = "API error",
|
||||
status_code: int | None = None,
|
||||
llm_provider: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> RuntimeError:
|
||||
"""Create an exception with LiteLLM-style attributes."""
|
||||
exc = RuntimeError(message)
|
||||
if status_code is not None:
|
||||
exc.status_code = status_code # type: ignore[attr-defined]
|
||||
if llm_provider is not None:
|
||||
exc.llm_provider = llm_provider # type: ignore[attr-defined]
|
||||
if model is not None:
|
||||
exc.model = model # type: ignore[attr-defined]
|
||||
return exc
|
||||
|
||||
|
||||
class TestLiteLLMErrorExtraction:
|
||||
"""Verify that LiteLLM error attributes are included in the ValueError."""
|
||||
|
||||
def test_status_code_included(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Content filter triggered",
|
||||
status_code=400,
|
||||
llm_provider="azure",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="status_code=400"):
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
def test_llm_provider_included(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Bad request",
|
||||
status_code=400,
|
||||
llm_provider="azure",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="llm_provider=azure"):
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
def test_model_included(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Bad request",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="model=gpt-4o"):
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
def test_all_fields_in_single_message(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Rate limit exceeded",
|
||||
status_code=429,
|
||||
llm_provider="azure",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "status_code=429" in msg
|
||||
assert "llm_provider=azure" in msg
|
||||
assert "model=gpt-4o" in msg
|
||||
assert "Rate limit exceeded" in msg
|
||||
|
||||
def test_plain_exception_without_litellm_attrs(self) -> None:
|
||||
"""Non-LiteLLM exceptions should still produce a useful message."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = ConnectionError("Connection refused")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "ConnectionError" in msg
|
||||
assert "Connection refused" in msg
|
||||
# Should not contain status_code/llm_provider/model
|
||||
assert "status_code" not in msg
|
||||
assert "llm_provider" not in msg
|
||||
|
||||
def test_no_base64_in_error(self) -> None:
|
||||
"""Error messages must not contain the full base64 image payload.
|
||||
|
||||
Some LiteLLM exceptions echo the request body (including base64 images)
|
||||
in their message. The truncation guard ensures the bulk of such a
|
||||
payload is stripped from the re-raised ValueError.
|
||||
"""
|
||||
mock_llm = MagicMock()
|
||||
# Build a long base64-like payload that exceeds the 512-char truncation
|
||||
fake_b64_payload = "iVBORw0KGgo" * 100 # ~1100 chars
|
||||
fake_b64 = f"data:image/png;base64,{fake_b64_payload}"
|
||||
|
||||
mock_llm.invoke.side_effect = RuntimeError(
|
||||
f"Request failed for payload: {fake_b64}"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image(fake_b64, mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
# The full payload must not appear (truncation should have kicked in)
|
||||
assert fake_b64_payload not in msg
|
||||
assert "truncated" in msg
|
||||
|
||||
def test_long_error_message_truncated(self) -> None:
|
||||
"""Exception messages longer than 512 chars are truncated."""
|
||||
mock_llm = MagicMock()
|
||||
long_msg = "x" * 1000
|
||||
mock_llm.invoke.side_effect = RuntimeError(long_msg)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "truncated" in msg
|
||||
# The full 1000-char string should not appear
|
||||
assert long_msg not in msg
|
||||
@@ -28,6 +28,7 @@ from onyx.llm.utils import get_max_input_tokens
|
||||
VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
|
||||
"claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -26,14 +26,6 @@ class TestIsTrueOpenAIModel:
|
||||
"""Test that real OpenAI GPT-4o-mini model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "gpt-4o-mini") is True
|
||||
|
||||
def test_real_openai_o1_preview(self) -> None:
|
||||
"""Test that real OpenAI o1-preview reasoning model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-preview") is True
|
||||
|
||||
def test_real_openai_o1_mini(self) -> None:
|
||||
"""Test that real OpenAI o1-mini reasoning model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-mini") is True
|
||||
|
||||
def test_openai_with_provider_prefix(self) -> None:
|
||||
"""Test that OpenAI model with provider prefix is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "openai/gpt-4") is False
|
||||
|
||||
204
backend/tests/unit/onyx/onyxbot/test_handle_regular_answer.py
Normal file
204
backend/tests/unit/onyx/onyxbot/test_handle_regular_answer.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Tests for Slack channel reference resolution and tag filtering
|
||||
in handle_regular_answer.py."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_client_with_channels(
|
||||
channel_map: dict[str, str],
|
||||
) -> MagicMock:
|
||||
"""Return a mock WebClient where conversations_info resolves IDs to names."""
|
||||
client = MagicMock()
|
||||
|
||||
def _conversations_info(channel: str) -> MagicMock:
|
||||
if channel in channel_map:
|
||||
resp = MagicMock()
|
||||
resp.validate = MagicMock()
|
||||
resp.__getitem__ = lambda _self, key: {
|
||||
"channel": {
|
||||
"name": channel_map[channel],
|
||||
"is_im": False,
|
||||
"is_mpim": False,
|
||||
}
|
||||
}[key]
|
||||
return resp
|
||||
raise SlackApiError("channel_not_found", response=MagicMock())
|
||||
|
||||
client.conversations_info = _conversations_info
|
||||
return client
|
||||
|
||||
|
||||
def _mock_logger() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SLACK_CHANNEL_REF_PATTERN regex tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlackChannelRefPattern:
|
||||
def test_matches_bare_channel_id(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
|
||||
assert matches == [("C097NBWMY8Y", "")]
|
||||
|
||||
def test_matches_channel_id_with_name(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
|
||||
assert matches == [("C097NBWMY8Y", "eng-infra")]
|
||||
|
||||
def test_matches_multiple_channels(self) -> None:
|
||||
msg = "compare <#C111AAA> and <#C222BBB|general>"
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
|
||||
assert len(matches) == 2
|
||||
assert ("C111AAA", "") in matches
|
||||
assert ("C222BBB", "general") in matches
|
||||
|
||||
def test_no_match_on_plain_text(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
|
||||
assert matches == []
|
||||
|
||||
def test_no_match_on_user_mention(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
|
||||
assert matches == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_channel_references tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveChannelReferences:
|
||||
def test_resolves_bare_channel_id_via_api(self) -> None:
|
||||
client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summary of <#C097NBWMY8Y> this week",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summary of #eng-infra this week"
|
||||
assert len(tags) == 1
|
||||
assert tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")
|
||||
|
||||
def test_uses_name_from_pipe_format_without_api_call(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#C097NBWMY8Y|eng-infra> for updates",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "check #eng-infra for updates"
|
||||
assert tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
|
||||
# Should NOT have called the API since name was in the markup
|
||||
client.conversations_info.assert_not_called()
|
||||
|
||||
def test_multiple_channels(self) -> None:
|
||||
client = _mock_client_with_channels(
|
||||
{
|
||||
"C111AAA": "eng-infra",
|
||||
"C222BBB": "eng-general",
|
||||
}
|
||||
)
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#eng-general" in message
|
||||
assert "<#" not in message
|
||||
assert len(tags) == 2
|
||||
tag_values = {t.tag_value for t in tags}
|
||||
assert tag_values == {"eng-infra", "eng-general"}
|
||||
|
||||
def test_no_channel_references_returns_unchanged(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="just a normal message with no channels",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "just a normal message with no channels"
|
||||
assert tags == []
|
||||
|
||||
def test_api_failure_skips_channel_gracefully(self) -> None:
|
||||
# Client that fails for all channel lookups
|
||||
client = _mock_client_with_channels({})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Message should remain unchanged for the failed channel
|
||||
assert "<#CBADID123>" in message
|
||||
assert tags == []
|
||||
logger.warning.assert_called_once()
|
||||
|
||||
def test_partial_failure_resolves_what_it_can(self) -> None:
|
||||
# Only one of two channels resolves
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "<#CBADID123>" in message # failed one stays raw
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_duplicate_channel_produces_single_tag(self) -> None:
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summarize <#C111AAA> and compare with <#C111AAA>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summarize #eng-infra and compare with #eng-infra"
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_mixed_pipe_and_bare_formats(self) -> None:
|
||||
client = _mock_client_with_channels({"C222BBB": "random"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="see <#C111AAA|eng-infra> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#random" in message
|
||||
assert len(tags) == 2
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Tests for LLM model fetch endpoints.
|
||||
|
||||
These tests verify the full request/response flow for fetching models
|
||||
from dynamic providers (Ollama, OpenRouter), including the
|
||||
from dynamic providers (Ollama, OpenRouter, Litellm), including the
|
||||
sync-to-DB behavior when provider_name is specified.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
@@ -499,6 +503,7 @@ class TestGetLMStudioAvailableModels:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.api_base = "http://localhost:1234"
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
@@ -614,3 +619,283 @@ class TestGetLMStudioAvailableModels:
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetLitellmAvailableModels:
|
||||
"""Tests for the Litellm proxy model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_response(self) -> dict:
|
||||
"""Mock response from Litellm /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1700000001,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "gemini-pro",
|
||||
"object": "model",
|
||||
"created": 1700000002,
|
||||
"owned_by": "google",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted model list."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
|
||||
|
||||
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that provider_name and model_name are correctly extracted."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
gpt = next(r for r in results if r.model_name == "gpt-4o")
|
||||
assert gpt.provider_name == "openai"
|
||||
|
||||
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
|
||||
assert claude.provider_name == "anthropic"
|
||||
|
||||
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that results are alphabetically sorted by model_name."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
model_names = [r.model_name for r in results]
|
||||
assert model_names == sorted(model_names, key=str.lower)
|
||||
|
||||
def test_empty_data_raises_onyx_error(self) -> None:
|
||||
"""Test that empty model list raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No models found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_missing_data_key_raises_onyx_error(self) -> None:
|
||||
"""Test that response without 'data' key raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_skips_unparseable_entries(self) -> None:
|
||||
"""Test that malformed model entries are skipped without failing."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_with_bad_entry = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
# Missing required fields
|
||||
{"bad_field": "bad_value"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_with_bad_entry
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].model_name == "gpt-4o"
|
||||
|
||||
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
|
||||
"""Test that OnyxError is raised when all entries fail to parse."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_all_bad = {
|
||||
"data": [
|
||||
{"bad_field": "bad_value"},
|
||||
{"another_bad": 123},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_all_bad
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No compatible models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_api_base_trailing_slash_handled(self) -> None:
|
||||
"""Test that trailing slashes in api_base are handled correctly."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_litellm_response = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000/",
|
||||
api_key="test-key",
|
||||
)
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Should call /v1/models without double slashes
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[0][0] == "http://localhost:4000/v1/models"
|
||||
|
||||
def test_connection_failure_raises_onyx_error(self) -> None:
|
||||
"""Test that connection failures are wrapped in OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
"""Test that a 401 response raises OnyxError with authentication message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="bad-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Authentication failed"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_404_raises_not_found_error(self) -> None:
|
||||
"""Test that a 404 response raises OnyxError with endpoint not found message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_embedding_model
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
|
||||
@@ -209,3 +210,35 @@ class TestIsReasoningModel:
|
||||
is_reasoning_model("anthropic/claude-3-5-sonnet", "Claude 3.5 Sonnet")
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
class TestIsEmbeddingModel:
|
||||
"""Tests for embedding model detection."""
|
||||
|
||||
def test_openai_embedding_ada(self) -> None:
|
||||
assert is_embedding_model("text-embedding-ada-002") is True
|
||||
|
||||
def test_openai_embedding_3_small(self) -> None:
|
||||
assert is_embedding_model("text-embedding-3-small") is True
|
||||
|
||||
def test_openai_embedding_3_large(self) -> None:
|
||||
assert is_embedding_model("text-embedding-3-large") is True
|
||||
|
||||
def test_cohere_embed_model(self) -> None:
|
||||
assert is_embedding_model("embed-english-v3.0") is True
|
||||
|
||||
def test_bedrock_titan_embed(self) -> None:
|
||||
assert is_embedding_model("amazon.titan-embed-text-v1") is True
|
||||
|
||||
def test_gpt4o_not_embedding(self) -> None:
|
||||
assert is_embedding_model("gpt-4o") is False
|
||||
|
||||
def test_gpt4_not_embedding(self) -> None:
|
||||
assert is_embedding_model("gpt-4") is False
|
||||
|
||||
def test_dall_e_not_embedding(self) -> None:
|
||||
assert is_embedding_model("dall-e-3") is False
|
||||
|
||||
def test_unknown_custom_model_not_embedding(self) -> None:
|
||||
"""Custom/local models not in litellm's model DB should default to False."""
|
||||
assert is_embedding_model("my-custom-local-model-v1") is False
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user