mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-06 16:15:46 +00:00
Compare commits
46 Commits
table-prim
...
nikg/std-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16eca436ad | ||
|
|
5176fd7386 | ||
|
|
92538084e9 | ||
|
|
2d996e05a4 | ||
|
|
b2956f795b | ||
|
|
b272085543 | ||
|
|
8193aa4fd0 | ||
|
|
52db41a00b | ||
|
|
f1cf3c4589 | ||
|
|
5322aeed90 | ||
|
|
5da8870fd2 | ||
|
|
57d3ab3b40 | ||
|
|
649c7fe8b9 | ||
|
|
e5e2bc6149 | ||
|
|
b148065e1d | ||
|
|
367808951c | ||
|
|
0f74da3302 | ||
|
|
96f7cbd25a | ||
|
|
c627cea17d | ||
|
|
a8cdc3965d | ||
|
|
60891b2f44 | ||
|
|
d2f35e1fae | ||
|
|
7a7350f387 | ||
|
|
8ef504acd5 | ||
|
|
0dbabfe445 | ||
|
|
50575d0f6b | ||
|
|
9862fbd4a6 | ||
|
|
003d94546a | ||
|
|
01d3473974 | ||
|
|
19c7809a43 | ||
|
|
98e6346152 | ||
|
|
c63fdf1c13 | ||
|
|
49b509a0a7 | ||
|
|
2b1f1fe311 | ||
|
|
3e67ea9df7 | ||
|
|
98e3602dd6 | ||
|
|
4fded5b0a1 | ||
|
|
328c305d26 | ||
|
|
f902727215 | ||
|
|
69c8aa08b3 | ||
|
|
c98aa486e4 | ||
|
|
03553114c5 | ||
|
|
6532c94230 | ||
|
|
1b32a7d94e | ||
|
|
5fd0fe192b | ||
|
|
1de522f9ae |
@@ -15,6 +15,7 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
52
.github/workflows/pr-integration-tests.yml
vendored
52
.github/workflows/pr-integration-tests.yml
vendored
@@ -335,7 +335,6 @@ jobs:
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
@@ -471,13 +470,13 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
no-vectordb-tests:
|
||||
onyx-lite-tests:
|
||||
needs: [build-backend-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-no-vectordb-tests",
|
||||
"run-id=${{ github.run_id }}-onyx-lite-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
@@ -495,13 +494,12 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create .env file for no-vectordb Docker Compose
|
||||
- name: Create .env file for Onyx Lite Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
@@ -509,28 +507,23 @@ jobs:
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
DISABLE_VECTOR_DB=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
|
||||
EOF
|
||||
|
||||
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
|
||||
- name: Start Docker containers (no-vectordb)
|
||||
# Start only the services needed for Onyx Lite (Postgres + API server)
|
||||
- name: Start Docker containers (onyx-lite)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_no_vectordb
|
||||
id: start_docker_onyx_lite
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script (no-vectordb)..."
|
||||
echo "Starting wait-for-service script (onyx-lite)..."
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
@@ -552,14 +545,14 @@ jobs:
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Run No-VectorDB Integration Tests
|
||||
- name: Run Onyx Lite Integration Tests
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running no-vectordb integration tests..."
|
||||
echo "Running onyx-lite integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -570,39 +563,38 @@ jobs:
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/tests/no_vectordb
|
||||
|
||||
- name: Dump API server logs (no-vectordb)
|
||||
- name: Dump API server logs (onyx-lite)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true
|
||||
|
||||
- name: Dump all-container logs (no-vectordb)
|
||||
- name: Dump all-container logs (onyx-lite)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true
|
||||
|
||||
- name: Upload logs (no-vectordb)
|
||||
- name: Upload logs (onyx-lite)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-no-vectordb
|
||||
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
|
||||
name: docker-all-logs-onyx-lite
|
||||
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
|
||||
|
||||
- name: Stop Docker containers (no-vectordb)
|
||||
- name: Stop Docker containers (onyx-lite)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
@@ -744,7 +736,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
|
||||
needs: [integration-tests, onyx-lite-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
108
.github/workflows/pr-playwright-tests.yml
vendored
108
.github/workflows/pr-playwright-tests.yml
vendored
@@ -268,10 +268,11 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
@@ -279,6 +280,7 @@ jobs:
|
||||
run: npm ci
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
@@ -590,6 +592,108 @@ jobs:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
playwright-tests-lite:
|
||||
needs: [build-web-image, build-backend-image]
|
||||
name: Playwright Tests (lite)
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-playwright-tests-lite"
|
||||
- "extras=ecr-cache"
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-playwright-npm-
|
||||
|
||||
- name: Install playwright browsers
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
MOCK_LLM_RESPONSE=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
|
||||
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
|
||||
EOF
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Start Docker containers (lite)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Run Playwright tests (lite)
|
||||
working-directory: ./web
|
||||
run: npx playwright test --project lite
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-test-results-lite-${{ github.run_id }}
|
||||
path: ./web/output/playwright/
|
||||
retention-days: 30
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
env:
|
||||
WORKSPACE: ${{ github.workspace }}
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs > docker-compose.log
|
||||
mv docker-compose.log ${WORKSPACE}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-lite-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
@@ -686,7 +790,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [playwright-tests]
|
||||
needs: [playwright-tests, playwright-tests-lite]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
43
.vscode/launch.json
vendored
43
.vscode/launch.json
vendored
@@ -40,19 +40,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (lightweight mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery background",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (standard mode)",
|
||||
"name": "Celery",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
@@ -253,35 +241,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery background",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=20",
|
||||
"--prefetch-multiplier=4",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery background Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
|
||||
70
AGENTS.md
70
AGENTS.md
@@ -86,37 +86,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Light worker tasks (Vespa operations, permissions sync, deletion)
|
||||
- Document processing (indexing pipeline)
|
||||
- Document fetching (connector data retrieval)
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (fewer worker processes)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
|
||||
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
@@ -617,6 +586,45 @@ Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Error Handling
|
||||
|
||||
**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# ✅ Good
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
# ✅ Good — no extra message needed
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
# ✅ Good — upstream service with dynamic status code
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
|
||||
|
||||
# ❌ Bad — using HTTPException directly
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# ❌ Bad — starlette constant
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
```
|
||||
|
||||
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
|
||||
category is needed, add it there first — do not invent ad-hoc codes.
|
||||
|
||||
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
|
||||
status code is dynamic (comes from the upstream response), use `status_code_override`:
|
||||
|
||||
```python
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -11,11 +11,10 @@ from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -142,7 +141,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
|
||||
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata from Redis cache.
|
||||
Get license metadata from cache.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
@@ -150,38 +149,34 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
|
||||
Returns:
|
||||
LicenseMetadata if cached, None otherwise
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_replica_client(tenant_id=tenant)
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
cached = cache.get(LICENSE_METADATA_KEY)
|
||||
if not cached:
|
||||
return None
|
||||
|
||||
cached = redis_client.get(LICENSE_METADATA_KEY)
|
||||
if cached:
|
||||
try:
|
||||
cached_str: str
|
||||
if isinstance(cached, bytes):
|
||||
cached_str = cached.decode("utf-8")
|
||||
else:
|
||||
cached_str = str(cached)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
return None
|
||||
try:
|
||||
cached_str = (
|
||||
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
|
||||
)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def invalidate_license_cache(tenant_id: str | None = None) -> None:
|
||||
"""
|
||||
Invalidate the license metadata cache (not the license itself).
|
||||
|
||||
This deletes the cached LicenseMetadata from Redis. The actual license
|
||||
in the database is not affected. Redis delete is idempotent - if the
|
||||
key doesn't exist, this is a no-op.
|
||||
Deletes the cached LicenseMetadata. The actual license in the database
|
||||
is not affected. Delete is idempotent — if the key doesn't exist, this
|
||||
is a no-op.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
redis_client.delete(LICENSE_METADATA_KEY)
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
cache.delete(LICENSE_METADATA_KEY)
|
||||
logger.info("License cache invalidated")
|
||||
|
||||
|
||||
@@ -192,7 +187,7 @@ def update_license_cache(
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata:
|
||||
"""
|
||||
Update the Redis cache with license metadata.
|
||||
Update the cache with license metadata.
|
||||
|
||||
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
|
||||
1. Frontend needs status to show appropriate UI/banners
|
||||
@@ -211,7 +206,7 @@ def update_license_cache(
|
||||
from ee.onyx.utils.license import get_license_status
|
||||
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
|
||||
used_seats = get_used_seats(tenant)
|
||||
status = get_license_status(payload, grace_period_end)
|
||||
@@ -230,7 +225,7 @@ def update_license_cache(
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
redis_client.set(
|
||||
cache.set(
|
||||
LICENSE_METADATA_KEY,
|
||||
metadata.model_dump_json(),
|
||||
ex=LICENSE_CACHE_TTL_SECONDS,
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from operator import and_
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
@@ -15,6 +14,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -36,6 +36,8 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -165,18 +167,12 @@ def validate_object_creation_for_user(
|
||||
if object_is_public and user.role == UserRole.BASIC:
|
||||
detail = "User does not have permission to create public objects"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INSUFFICIENT_PERMISSIONS, detail)
|
||||
|
||||
if not target_group_ids:
|
||||
detail = "Curators must specify 1+ groups"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, detail)
|
||||
|
||||
user_curated_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
@@ -189,10 +185,7 @@ def validate_object_creation_for_user(
|
||||
if not target_group_ids_set.issubset(user_curated_group_ids):
|
||||
detail = "Curators cannot control groups they don't curate"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INSUFFICIENT_PERMISSIONS, detail)
|
||||
|
||||
|
||||
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
|
||||
@@ -471,7 +464,9 @@ def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(
|
||||
name=user_group.name, time_last_modified_by_user=func.now()
|
||||
name=user_group.name,
|
||||
time_last_modified_by_user=func.now(),
|
||||
is_up_to_date=DISABLE_VECTOR_DB,
|
||||
)
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
@@ -774,8 +769,7 @@ def update_user_group(
|
||||
cc_pair_ids=user_group_update.cc_pair_ids,
|
||||
)
|
||||
|
||||
# only needs to sync with Vespa if the cc_pairs have been updated
|
||||
if cc_pairs_updated:
|
||||
if cc_pairs_updated and not DISABLE_VECTOR_DB:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
removed_users = db_session.scalars(
|
||||
|
||||
@@ -26,7 +26,6 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -42,7 +41,6 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_checkout_session as create_checkout_service,
|
||||
)
|
||||
@@ -58,6 +56,8 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -169,26 +169,23 @@ async def create_checkout_session(
|
||||
if seats is not None:
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if seats < used_seats:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot subscribe with fewer seats than current usage. "
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot subscribe with fewer seats than current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {seats} seats.",
|
||||
)
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
|
||||
|
||||
try:
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
@@ -206,18 +203,15 @@ async def create_customer_portal_session(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
@@ -240,9 +234,9 @@ async def get_billing_information(
|
||||
|
||||
# Check circuit breaker (self-hosted only)
|
||||
if _is_billing_circuit_open():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE,
|
||||
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -250,11 +244,11 @@ async def get_billing_information(
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
except OnyxError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
_open_billing_circuit()
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
@@ -274,31 +268,25 @@ async def update_seats(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
|
||||
# Validate that new seat count is not less than current used seats
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if request.new_seat_count < used_seats:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot reduce seats below current usage. "
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot reduce seats below current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
|
||||
return result
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
return await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -329,18 +317,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -351,17 +339,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -31,15 +33,6 @@ logger = setup_logger()
|
||||
_REQUEST_TIMEOUT = 30.0
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Exception raised for billing service errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Build headers for proxy requests (self-hosted).
|
||||
|
||||
@@ -101,7 +94,7 @@ async def _make_billing_request(
|
||||
Response JSON as dict
|
||||
|
||||
Raises:
|
||||
BillingServiceError: If request fails
|
||||
OnyxError: If request fails
|
||||
"""
|
||||
|
||||
base_url = _get_base_url()
|
||||
@@ -128,11 +121,17 @@ async def _make_billing_request(
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
|
||||
raise BillingServiceError(detail, e.response.status_code)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=e.response.status_code,
|
||||
)
|
||||
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to billing service")
|
||||
raise BillingServiceError("Failed to connect to billing service", 502)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
|
||||
)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
|
||||
@@ -223,6 +223,15 @@ def get_active_scim_token(
|
||||
token = dal.get_active_token()
|
||||
if not token:
|
||||
raise HTTPException(status_code=404, detail="No active SCIM token")
|
||||
|
||||
# Derive the IdP domain from the first synced user as a heuristic.
|
||||
idp_domain: str | None = None
|
||||
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
|
||||
if mappings:
|
||||
user = dal.get_user(mappings[0].user_id)
|
||||
if user and "@" in user.email:
|
||||
idp_domain = user.email.rsplit("@", 1)[1]
|
||||
|
||||
return ScimTokenResponse(
|
||||
id=token.id,
|
||||
name=token.name,
|
||||
@@ -230,6 +239,7 @@ def get_active_scim_token(
|
||||
is_active=token.is_active,
|
||||
created_at=token.created_at,
|
||||
last_used_at=token.last_used_at,
|
||||
idp_domain=idp_domain,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -35,6 +34,8 @@ from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -127,9 +128,9 @@ async def claim_license(
|
||||
2. Without session_id: Re-claim using existing license for auth
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License claiming is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -146,15 +147,16 @@ async def claim_license(
|
||||
# Re-claim using existing license for auth
|
||||
metadata = get_license_metadata(db_session)
|
||||
if not metadata or not metadata.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No license found. Provide session_id after checkout.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No license found. Provide session_id after checkout.",
|
||||
)
|
||||
|
||||
license_row = get_license(db_session)
|
||||
if not license_row or not license_row.license_data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No license found in database"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No license found in database",
|
||||
)
|
||||
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
|
||||
@@ -173,7 +175,7 @@ async def claim_license(
|
||||
license_data = data.get("license")
|
||||
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
@@ -199,12 +201,14 @@ async def claim_license(
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
except requests.RequestException:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
|
||||
)
|
||||
|
||||
|
||||
@@ -221,9 +225,9 @@ async def upload_license(
|
||||
The license file must be cryptographically signed by Onyx.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License upload is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -234,14 +238,14 @@ async def upload_license(
|
||||
# Remove any stray whitespace/newlines from user input
|
||||
license_data = license_data.strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
@@ -297,9 +301,9 @@ async def delete_license(
|
||||
Admin only - removes license from database and invalidates cache.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License deletion is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -46,7 +46,6 @@ from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
@@ -56,6 +55,7 @@ from ee.onyx.configs.license_enforcement_config import (
|
||||
)
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
|
||||
"[license_enforcement] No license, allowing community features"
|
||||
)
|
||||
is_gated = False
|
||||
except RedisError as e:
|
||||
except CACHE_TRANSIENT_ERRORS as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
# Fail open - don't block users due to cache connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
|
||||
@@ -365,6 +365,7 @@ class ScimTokenResponse(BaseModel):
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
idp_domain: str | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
@@ -125,7 +126,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# syncing) means indexed data may need protection.
|
||||
settings.application_status = _BLOCKING_STATUS
|
||||
settings.ee_features_enabled = False
|
||||
except RedisError as e:
|
||||
except CACHE_TRANSIENT_ERRORS as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
# Fail closed - disable EE features if we can't verify license
|
||||
settings.ee_features_enabled = False
|
||||
|
||||
@@ -21,7 +21,6 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
@@ -43,6 +42,8 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -116,9 +117,14 @@ async def create_customer_portal_session(
|
||||
try:
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"stripe_customer_portal_url": portal_url}
|
||||
except Exception as e:
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create customer portal session",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
@@ -134,9 +140,14 @@ async def create_checkout_session(
|
||||
try:
|
||||
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
|
||||
return {"stripe_checkout_url": checkout_url}
|
||||
except Exception as e:
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create checkout session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create checkout session",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
@@ -147,15 +158,20 @@ async def create_subscription_session(
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create subscription session",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -186,18 +202,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -208,15 +224,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -5,6 +5,8 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.user_group import add_users_to_user_group
|
||||
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_groups
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
@@ -20,6 +22,7 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
@@ -153,3 +156,8 @@ def delete_user_group(
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
if user_group:
|
||||
db_delete_user_group(db_session, user_group)
|
||||
|
||||
@@ -120,7 +120,6 @@ from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -201,13 +200,14 @@ def user_needs_to_be_verified() -> bool:
|
||||
|
||||
|
||||
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.background")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received for consolidated background worker.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
|
||||
|
||||
# Initialize Vespa httpx pool (needed for light worker tasks)
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -1,23 +0,0 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
|
||||
# This allows the worker to prefetch multiple tasks per thread
|
||||
worker_prefetch_multiplier = 4
|
||||
@@ -30,6 +30,7 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
transform_vespa_chunks_to_opensearch_chunks,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -47,6 +48,7 @@ from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -146,7 +148,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with (
|
||||
get_session_with_current_tenant() as db_session,
|
||||
get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client,
|
||||
):
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
@@ -161,6 +168,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=vespa_client,
|
||||
)
|
||||
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"onyx.background.celery.apps.background",
|
||||
"celery_app",
|
||||
)
|
||||
11
backend/onyx/cache/interface.py
vendored
11
backend/onyx/cache/interface.py
vendored
@@ -1,9 +1,20 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)
|
||||
"""Exception types that represent transient cache connectivity / operational
|
||||
failures. Callers that want to fail-open (or fail-closed) on cache errors
|
||||
should catch this tuple instead of bare ``Exception``.
|
||||
|
||||
When adding a new ``CacheBackend`` implementation, add its transient error
|
||||
base class(es) here so all call-sites pick it up automatically."""
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
|
||||
@@ -52,6 +52,7 @@ from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -966,6 +967,13 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
# Extract generated_files if this is a code interpreter response
|
||||
generated_files = None
|
||||
if isinstance(tool_response.rich_response, PythonToolRichResponse):
|
||||
generated_files = (
|
||||
tool_response.rich_response.generated_files or None
|
||||
)
|
||||
|
||||
# Persist memory if this is a memory tool response
|
||||
memory_snapshot: MemoryToolResponseSnapshot | None = None
|
||||
if isinstance(tool_response.rich_response, MemoryToolResponse):
|
||||
@@ -1017,6 +1025,7 @@ def run_llm_loop(
|
||||
tool_call_response=saved_response,
|
||||
search_docs=displayed_docs or search_docs,
|
||||
generated_images=generated_images,
|
||||
generated_files=generated_files,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import mimetypes
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,14 +13,41 @@ from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.tools import create_tool_call_no_commit
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_referenced_file_descriptors(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
message_text: str,
|
||||
) -> list[FileDescriptor]:
|
||||
"""Extract FileDescriptors for code interpreter files referenced in the message text."""
|
||||
descriptors: list[FileDescriptor] = []
|
||||
for tool_call_info in tool_calls:
|
||||
if not tool_call_info.generated_files:
|
||||
continue
|
||||
for gen_file in tool_call_info.generated_files:
|
||||
file_id = (
|
||||
gen_file.file_link.rsplit("/", 1)[-1] if gen_file.file_link else ""
|
||||
)
|
||||
if file_id and file_id in message_text:
|
||||
mime_type, _ = mimetypes.guess_type(gen_file.filename)
|
||||
descriptors.append(
|
||||
FileDescriptor(
|
||||
id=file_id,
|
||||
type=mime_type_to_chat_file_type(mime_type),
|
||||
name=gen_file.filename,
|
||||
)
|
||||
)
|
||||
return descriptors
|
||||
|
||||
|
||||
def _create_and_link_tool_calls(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
assistant_message: ChatMessage,
|
||||
@@ -297,5 +325,14 @@ def save_chat_turn(
|
||||
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
|
||||
)
|
||||
|
||||
# 8. Attach code interpreter generated files that the assistant actually
|
||||
# referenced in its response, so they are available via load_all_chat_files
|
||||
# on subsequent turns. Files not mentioned are intermediate artifacts.
|
||||
if message_text:
|
||||
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
|
||||
if referenced:
|
||||
existing_files = assistant_message.files or []
|
||||
assistant_message.files = existing_files + referenced
|
||||
|
||||
# Finally save the messages, tool calls, and docs
|
||||
db_session.commit()
|
||||
|
||||
@@ -495,14 +495,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
|
||||
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
|
||||
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
|
||||
# Individual worker concurrency settings
|
||||
CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
@@ -819,7 +812,9 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
@@ -900,6 +895,9 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
|
||||
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
|
||||
)
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
|
||||
@@ -84,7 +84,6 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import timezone
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
@@ -32,6 +31,8 @@ from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_best_persona_id_for_user
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
@@ -227,7 +228,9 @@ def duplicate_chat_session_for_user_from_slack(
|
||||
db_session=db_session,
|
||||
)
|
||||
if not chat_session:
|
||||
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.SESSION_NOT_FOUND, "Invalid Chat Session ID provided"
|
||||
)
|
||||
|
||||
# This enforces permissions and sets a default
|
||||
new_persona_id = get_best_persona_id_for_user(
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
@@ -32,6 +31,8 @@ from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -539,7 +540,7 @@ def add_credential_to_connector(
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, "Connector does not exist")
|
||||
|
||||
if access_type == AccessType.SYNC:
|
||||
if not fetch_ee_implementation_or_noop(
|
||||
@@ -547,9 +548,9 @@ def add_credential_to_connector(
|
||||
"check_if_valid_sync_source",
|
||||
noop_return_value=True,
|
||||
)(connector.source):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Connector of type {connector.source} does not support SYNC access type",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Connector of type {connector.source} does not support SYNC access type",
|
||||
)
|
||||
|
||||
if credential is None:
|
||||
@@ -557,9 +558,9 @@ def add_credential_to_connector(
|
||||
f"Credential {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=error_msg,
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
error_msg,
|
||||
)
|
||||
|
||||
existing_association = (
|
||||
@@ -622,12 +623,12 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, "Connector does not exist")
|
||||
|
||||
if credential is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Credential does not exist or does not belong to user",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
"Credential does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
association = get_connector_credential_pair_for_user(
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -246,6 +247,7 @@ def insert_document_set(
|
||||
description=document_set_creation_request.description,
|
||||
user_id=user_id,
|
||||
is_public=document_set_creation_request.is_public,
|
||||
is_up_to_date=DISABLE_VECTOR_DB,
|
||||
time_last_modified_by_user=func.now(),
|
||||
)
|
||||
db_session.add(new_document_set_row)
|
||||
@@ -336,7 +338,8 @@ def update_document_set(
|
||||
)
|
||||
|
||||
document_set_row.description = document_set_update_request.description
|
||||
document_set_row.is_up_to_date = False
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
document_set_row.is_public = document_set_update_request.is_public
|
||||
document_set_row.time_last_modified_by_user = func.now()
|
||||
versioned_private_doc_set_fn = fetch_versioned_implementation(
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any
|
||||
from typing import AsyncContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
@@ -28,6 +27,8 @@ from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.sql_engine import USE_IAM_AUTH
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -114,7 +115,7 @@ async def get_async_session(
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid tenant ID")
|
||||
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
@@ -27,6 +26,8 @@ from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.db.engine.iam_auth import provide_iam_token
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -344,7 +345,7 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid tenant ID")
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
@@ -371,7 +372,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid tenant ID")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
yield db_session
|
||||
@@ -390,7 +391,7 @@ def get_db_readonly_user_session_with_current_tenant() -> (
|
||||
readonly_engine = get_readonly_sqlalchemy_engine()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid tenant ID")
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy import delete
|
||||
@@ -26,6 +25,8 @@ from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -134,8 +135,9 @@ def update_document_boost_for_user(
|
||||
stmt = _add_user_filters(stmt, user, get_editable=True)
|
||||
result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Document is not editable by this user"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Document is not editable by this user",
|
||||
)
|
||||
|
||||
result.boost = boost
|
||||
@@ -156,8 +158,9 @@ def update_document_hidden_for_user(
|
||||
stmt = _add_user_filters(stmt, user, get_editable=True)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Document is not editable by this user"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Document is not editable by this user",
|
||||
)
|
||||
|
||||
result.hidden = hidden
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
@@ -11,6 +10,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.db.models import InputPrompt
|
||||
from onyx.db.models import InputPrompt__User
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.input_prompt.models import InputPromptSnapshot
|
||||
from onyx.server.manage.models import UserInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -54,9 +55,9 @@ def insert_input_prompt(
|
||||
input_prompt = result.scalar_one_or_none()
|
||||
|
||||
if input_prompt is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A prompt shortcut with the name '{prompt}' already exists",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"A prompt shortcut with the name '{prompt}' already exists",
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
@@ -78,7 +79,7 @@ def update_input_prompt(
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You don't own this prompt")
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHORIZED, "You don't own this prompt")
|
||||
|
||||
input_prompt.prompt = prompt
|
||||
input_prompt.content = content
|
||||
@@ -88,9 +89,9 @@ def update_input_prompt(
|
||||
db_session.commit()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A prompt shortcut with the name '{prompt}' already exists",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"A prompt shortcut with the name '{prompt}' already exists",
|
||||
)
|
||||
|
||||
return input_prompt
|
||||
@@ -121,7 +122,7 @@ def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> Non
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not input_prompt.is_public:
|
||||
raise HTTPException(status_code=400, detail="This prompt is not public")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "This prompt is not public")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
@@ -140,12 +141,13 @@ def remove_input_prompt(
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if input_prompt.is_public and not delete_public:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete public prompts with this method"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Cannot delete public prompts with this method",
|
||||
)
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You do not own this prompt")
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHORIZED, "You do not own this prompt")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
@@ -167,7 +169,7 @@ def fetch_input_prompt_by_id(
|
||||
result = db_session.scalar(query)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(422, "No input prompt found")
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No input prompt found")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -532,6 +532,7 @@ def fetch_default_model(
|
||||
) -> ModelConfiguration | None:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import not_
|
||||
@@ -38,6 +37,8 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import MinimalPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
@@ -144,9 +145,9 @@ def fetch_persona_by_id_for_user(
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
persona = db_session.scalars(stmt).one_or_none()
|
||||
if not persona:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Persona with ID {persona_id} does not exist or user is not authorized to access it",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
f"Persona with ID {persona_id} does not exist or user is not authorized to access it",
|
||||
)
|
||||
return persona
|
||||
|
||||
@@ -315,7 +316,7 @@ def create_update_persona(
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
|
||||
return FullPersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
@@ -20,6 +19,8 @@ from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.features.projects.projects_file_utils import categorize_uploaded_files
|
||||
from onyx.server.features.projects.projects_file_utils import RejectedFile
|
||||
@@ -52,7 +53,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
@@ -110,7 +111,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
) -> CategorizedFilesResult:
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Project not found")
|
||||
|
||||
categorized_files_result = create_user_files(
|
||||
files,
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
@@ -24,6 +23,8 @@ from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
@@ -44,22 +45,22 @@ def validate_user_role_update(
|
||||
"""
|
||||
|
||||
if current_role == UserRole.SLACK_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_role == UserRole.LIMITED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"To change a Limited User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if explicit_override:
|
||||
@@ -67,40 +68,34 @@ def validate_user_role_update(
|
||||
|
||||
if requested_role == UserRole.CURATOR:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Curator role must be set via the User Group Menu",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Curator role must be set via the User Group Menu",
|
||||
)
|
||||
|
||||
if requested_role == UserRole.LIMITED:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"A user cannot be set to a Limited User role. "
|
||||
"This role is automatically assigned to users through certain endpoints in the API."
|
||||
),
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"A user cannot be set to a Limited User role. "
|
||||
"This role is automatically assigned to users through certain endpoints in the API.",
|
||||
)
|
||||
|
||||
if requested_role == UserRole.SLACK_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"A user cannot be set to a Slack User role. "
|
||||
"This role is automatically assigned to users who only use Onyx via Slack."
|
||||
),
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"A user cannot be set to a Slack User role. "
|
||||
"This role is automatically assigned to users who only use Onyx via Slack.",
|
||||
)
|
||||
|
||||
if requested_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"A user cannot be set to an External Permissioned User role. "
|
||||
"This role is automatically assigned to users who have been "
|
||||
"pulled in to the system via an external permissions system."
|
||||
),
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"A user cannot be set to an External Permissioned User role. "
|
||||
"This role is automatically assigned to users who have been "
|
||||
"pulled in to the system via an external permissions system.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
@@ -18,6 +19,7 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
)
|
||||
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
@@ -338,12 +340,18 @@ def get_all_chunks_paginated(
|
||||
params["continuation"] = continuation_token
|
||||
|
||||
response: httpx.Response | None = None
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
with get_vespa_http_client() as http_client:
|
||||
with get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
|
||||
error_base = (
|
||||
f"Failed to get chunks from Vespa slice {slice_id} with continuation token "
|
||||
f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds."
|
||||
)
|
||||
logger.exception(
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
|
||||
@@ -52,7 +52,9 @@ def replace_invalid_doc_id_characters(text: str) -> str:
|
||||
return text.replace("'", "_")
|
||||
|
||||
|
||||
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
|
||||
def get_vespa_http_client(
|
||||
no_timeout: bool = False, http2: bool = True, timeout: int | None = None
|
||||
) -> httpx.Client:
|
||||
"""
|
||||
Configures and returns an HTTP client for communicating with Vespa,
|
||||
including authentication if needed.
|
||||
@@ -64,7 +66,7 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
|
||||
else None
|
||||
),
|
||||
verify=False if not MANAGED_VESPA else True,
|
||||
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
|
||||
timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
|
||||
http2=http2,
|
||||
)
|
||||
|
||||
|
||||
0
backend/onyx/error_handling/__init__.py
Normal file
0
backend/onyx/error_handling/__init__.py
Normal file
101
backend/onyx/error_handling/error_codes.py
Normal file
101
backend/onyx/error_handling/error_codes.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Standardized error codes for the Onyx backend.
|
||||
|
||||
Usage:
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired")
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OnyxErrorCode(Enum):
|
||||
"""
|
||||
Each member is a tuple of (error_code_string, http_status_code).
|
||||
|
||||
The error_code_string is a stable, machine-readable identifier that
|
||||
API consumers can match on. The http_status_code is the default HTTP
|
||||
status to return.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication (401)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
|
||||
INVALID_TOKEN = ("INVALID_TOKEN", 401)
|
||||
TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
|
||||
CSRF_FAILURE = ("CSRF_FAILURE", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authorization (403)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHORIZED = ("UNAUTHORIZED", 403)
|
||||
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
|
||||
ADMIN_ONLY = ("ADMIN_ONLY", 403)
|
||||
EE_REQUIRED = ("EE_REQUIRED", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation / Bad Request (400)
|
||||
# ------------------------------------------------------------------
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
# ------------------------------------------------------------------
|
||||
NOT_FOUND = ("NOT_FOUND", 404)
|
||||
CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
|
||||
CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
|
||||
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
|
||||
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
|
||||
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
|
||||
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conflict (409)
|
||||
# ------------------------------------------------------------------
|
||||
CONFLICT = ("CONFLICT", 409)
|
||||
DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rate Limiting / Quotas (429 / 402)
|
||||
# ------------------------------------------------------------------
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
|
||||
CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
|
||||
CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server Errors (5xx)
|
||||
# ------------------------------------------------------------------
|
||||
INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
|
||||
NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
|
||||
def detail(self, message: str | None = None) -> dict[str, str]:
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
}
|
||||
82
backend/onyx/error_handling/exceptions.py
Normal file
82
backend/onyx/error_handling/exceptions.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""OnyxError — the single exception type for all Onyx business errors.
|
||||
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
For upstream errors with a dynamic HTTP status (e.g. billing service),
|
||||
use ``status_code_override``::
|
||||
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=upstream_status,
|
||||
)
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxError(Exception):
|
||||
"""Structured error that maps to a specific ``OnyxErrorCode``.
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
message: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
self.error_code = error_code
|
||||
self.message = message or error_code.code
|
||||
self._status_code_override = status_code_override
|
||||
super().__init__(self.message)
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self._status_code_override or self.error_code.status_code
|
||||
|
||||
|
||||
def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register a global handler that converts ``OnyxError`` to JSON responses.
|
||||
|
||||
Must be called *after* the app is created but *before* it starts serving.
|
||||
The handler logs at WARNING for 4xx and ERROR for 5xx.
|
||||
"""
|
||||
|
||||
@app.exception_handler(OnyxError)
|
||||
async def _handle_onyx_error(
|
||||
request: Request, # noqa: ARG001
|
||||
exc: OnyxError,
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.message),
|
||||
)
|
||||
@@ -59,6 +59,7 @@ from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
|
||||
from onyx.db.engine.connection_warmup import warm_up_connections
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
@@ -444,6 +445,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error
|
||||
)
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
@@ -146,6 +146,11 @@ class SlackRenderer(HTMLRenderer):
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
@@ -218,5 +223,48 @@ class SlackRenderer(HTMLRenderer):
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -92,6 +92,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
@@ -572,6 +573,43 @@ def _normalize_file_names_for_backwards_compatibility(
|
||||
return file_names + file_locations[len(file_names) :]
|
||||
|
||||
|
||||
def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
require_editable: bool,
|
||||
) -> ConnectorCredentialPair:
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
has_requested_access = verify_user_has_access_to_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=require_editable,
|
||||
)
|
||||
if has_requested_access:
|
||||
return cc_pair
|
||||
|
||||
# Special case: global curators should be able to manage files
|
||||
# for public file connectors even when they are not the creator.
|
||||
if (
|
||||
require_editable
|
||||
and user.role == UserRole.GLOBAL_CURATOR
|
||||
and cc_pair.access_type == AccessType.PUBLIC
|
||||
):
|
||||
return cc_pair
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied. User cannot manage files for this connector.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
@@ -583,7 +621,7 @@ def upload_files_api(
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
def list_connector_files(
|
||||
connector_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorFilesResponse:
|
||||
"""List all files in a file connector."""
|
||||
@@ -596,6 +634,13 @@ def list_connector_files(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
_ = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=False,
|
||||
)
|
||||
|
||||
file_locations = connector.connector_specific_config.get("file_locations", [])
|
||||
file_names = connector.connector_specific_config.get("file_names", [])
|
||||
|
||||
@@ -645,7 +690,7 @@ def update_connector_files(
|
||||
connector_id: int,
|
||||
files: list[UploadFile] | None = File(None),
|
||||
file_ids_to_remove: str = Form("[]"),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
@@ -663,12 +708,13 @@ def update_connector_files(
|
||||
)
|
||||
|
||||
# Get the connector-credential pair for indexing/pruning triggers
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
# and validate user permissions for file management.
|
||||
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=True,
|
||||
)
|
||||
|
||||
# Parse file IDs to remove
|
||||
try:
|
||||
|
||||
@@ -7424,9 +7424,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.11.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
|
||||
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
|
||||
"version": "4.12.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
|
||||
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
|
||||
@@ -1133,7 +1133,8 @@ done
|
||||
# Already deleted
|
||||
service_deleted = True
|
||||
else:
|
||||
logger.warning(f"Error deleting Service {service_name}: {e}")
|
||||
logger.error(f"Error deleting Service {service_name}: {e}")
|
||||
raise
|
||||
|
||||
pod_deleted = False
|
||||
try:
|
||||
@@ -1148,7 +1149,8 @@ done
|
||||
# Already deleted
|
||||
pod_deleted = True
|
||||
else:
|
||||
logger.warning(f"Error deleting Pod {pod_name}: {e}")
|
||||
logger.error(f"Error deleting Pod {pod_name}: {e}")
|
||||
raise
|
||||
|
||||
# Wait for resources to be fully deleted to prevent 409 conflicts
|
||||
# on immediate re-provisioning
|
||||
|
||||
@@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
|
||||
# Prevent overlapping runs of this task
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.document_set import check_document_sets_are_public
|
||||
from onyx.db.document_set import delete_document_set as db_delete_document_set
|
||||
from onyx.db.document_set import fetch_all_document_sets_for_user
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import insert_document_set
|
||||
@@ -142,7 +143,10 @@ def delete_document_set(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if DISABLE_VECTOR_DB:
|
||||
db_session.refresh(document_set)
|
||||
db_delete_document_set(document_set, db_session)
|
||||
else:
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
|
||||
@@ -7,13 +7,14 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -116,7 +117,9 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -128,11 +131,11 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
llm = get_default_llm()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -11,6 +10,8 @@ from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.search_settings import get_current_db_embedding_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.indexing.models import EmbeddingModelDetail
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@@ -59,7 +60,7 @@ def test_embedding_configuration(
|
||||
except Exception as e:
|
||||
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
|
||||
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
|
||||
|
||||
@admin_router.get("", response_model=list[EmbeddingModelDetail])
|
||||
@@ -93,8 +94,9 @@ def delete_embedding_provider(
|
||||
embedding_provider is not None
|
||||
and provider_type == embedding_provider.provider_type
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete a currently active model"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"You can't delete a currently active model",
|
||||
)
|
||||
|
||||
remove_embedding_provider(db_session, provider_type=provider_type)
|
||||
|
||||
@@ -11,7 +11,6 @@ from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -38,6 +37,8 @@ from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.llm import validate_persona_ids_exist
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import user_can_access_persona
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
@@ -186,7 +187,7 @@ def _validate_llm_provider_change(
|
||||
Only enforced in MULTI_TENANT mode.
|
||||
|
||||
Raises:
|
||||
HTTPException: If api_base or custom_config changed without changing API key
|
||||
OnyxError: If api_base or custom_config changed without changing API key
|
||||
"""
|
||||
if not MULTI_TENANT or api_key_changed:
|
||||
return
|
||||
@@ -200,9 +201,9 @@ def _validate_llm_provider_change(
|
||||
)
|
||||
|
||||
if api_base_changed or custom_config_changed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API base and/or custom config cannot be changed without changing the API key",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base and/or custom config cannot be changed without changing the API key",
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +223,7 @@ def fetch_llm_provider_options(
|
||||
for well_known_llm in well_known_llms:
|
||||
if well_known_llm.name == provider_name:
|
||||
return well_known_llm
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found")
|
||||
|
||||
|
||||
@admin_router.post("/test")
|
||||
@@ -281,7 +282,7 @@ def test_llm_configuration(
|
||||
error_msg = test_llm(llm)
|
||||
|
||||
if error_msg:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
|
||||
|
||||
@admin_router.post("/test/default")
|
||||
@@ -292,11 +293,11 @@ def test_default_provider(
|
||||
llm = get_default_llm()
|
||||
except ValueError:
|
||||
logger.exception("Failed to fetch default LLM Provider")
|
||||
raise HTTPException(status_code=400, detail="No LLM Provider setup")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup")
|
||||
|
||||
error = test_llm(llm)
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=str(error))
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error))
|
||||
|
||||
|
||||
@admin_router.get("/provider")
|
||||
@@ -362,35 +363,31 @@ def put_llm_provider(
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists",
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist",
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -415,9 +412,9 @@ def put_llm_provider(
|
||||
db_session, persona_ids
|
||||
)
|
||||
if missing_personas:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
|
||||
)
|
||||
# Remove duplicates while preserving order
|
||||
seen: set[int] = set()
|
||||
@@ -473,7 +470,7 @@ def put_llm_provider(
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to upsert LLM Provider")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
@@ -483,19 +480,19 @@ def delete_llm_provider(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
try:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@@ -535,9 +532,9 @@ def get_auto_config(
|
||||
"""
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch configuration from GitHub",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Failed to fetch configuration from GitHub",
|
||||
)
|
||||
return config.model_dump()
|
||||
|
||||
@@ -694,13 +691,13 @@ def list_llm_providers_for_persona(
|
||||
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found")
|
||||
|
||||
# Verify user has access to this persona
|
||||
if not user_can_access_persona(db_session, persona_id, user, get_editable=False):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You don't have access to this assistant",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You don't have access to this assistant",
|
||||
)
|
||||
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
@@ -854,9 +851,9 @@ def get_bedrock_available_models(
|
||||
try:
|
||||
bedrock = session.client("bedrock")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
|
||||
)
|
||||
|
||||
# Build model info dict from foundation models (modelId -> metadata)
|
||||
@@ -975,14 +972,14 @@ def get_bedrock_available_models(
|
||||
return results
|
||||
|
||||
except (ClientError, NoCredentialsError, BotoCoreError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to connect to AWS Bedrock: {e}",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Failed to connect to AWS Bedrock: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Unexpected error fetching Bedrock models: {e}",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"Unexpected error fetching Bedrock models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@@ -994,9 +991,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]:
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch Ollama models: {e}",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch Ollama models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
@@ -1013,9 +1010,9 @@ def get_ollama_available_models(
|
||||
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
if not cleaned_api_base:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API base URL is required to fetch Ollama models.",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch Ollama models.",
|
||||
)
|
||||
|
||||
# NOTE: most people run Ollama locally, so we don't disallow internal URLs
|
||||
@@ -1024,9 +1021,9 @@ def get_ollama_available_models(
|
||||
# with the same response format
|
||||
model_names = _get_ollama_available_model_names(cleaned_api_base)
|
||||
if not model_names:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your Ollama server",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Ollama server",
|
||||
)
|
||||
|
||||
all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
|
||||
@@ -1128,9 +1125,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch OpenRouter models: {e}",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch OpenRouter models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@@ -1151,9 +1148,9 @@ def get_openrouter_available_models(
|
||||
|
||||
data = response_json.get("data", [])
|
||||
if not isinstance(data, list) or len(data) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your OpenRouter endpoint",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenRouter endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenRouterFinalModelResponse] = []
|
||||
@@ -1188,8 +1185,9 @@ def get_openrouter_available_models(
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No compatible models found from OpenRouter"
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenRouter",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
@@ -93,6 +93,8 @@ class ToolResponse(BaseModel):
|
||||
# | WebContentResponse
|
||||
# This comes from custom tools, tool result needs to be saved
|
||||
| CustomToolCallSummary
|
||||
# This comes from code interpreter, carries generated files
|
||||
| PythonToolRichResponse
|
||||
# If the rich response is a string, this is what's saved to the tool call in the DB
|
||||
| str
|
||||
| None # If nothing needs to be persisted outside of the string value passed to the LLM
|
||||
@@ -193,6 +195,12 @@ class ChatFile(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class PythonToolRichResponse(BaseModel):
|
||||
"""Rich response from the Python tool carrying generated files."""
|
||||
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
|
||||
|
||||
class PythonToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for the Python/Code Interpreter tool."""
|
||||
|
||||
@@ -245,6 +253,7 @@ class ToolCallInfo(BaseModel):
|
||||
tool_call_response: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
generated_files: list[PythonExecutionFile] | None = None
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
@@ -12,6 +15,9 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HEALTH_CACHE_TTL_SECONDS = 30
|
||||
_health_cache: dict[str, tuple[float, bool]] = {}
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
@@ -80,6 +86,19 @@ class CodeInterpreterClient:
|
||||
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.session = requests.Session()
|
||||
self._closed = False
|
||||
|
||||
def __enter__(self) -> CodeInterpreterClient:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self.session.close()
|
||||
self._closed = True
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
@@ -98,16 +117,32 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
def health(self, use_cache: bool = False) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy
|
||||
|
||||
Args:
|
||||
use_cache: When True, return a cached result if available and
|
||||
within the TTL window. The cache is always populated
|
||||
after a live request regardless of this flag.
|
||||
"""
|
||||
if use_cache:
|
||||
cached = _health_cache.get(self.base_url)
|
||||
if cached is not None:
|
||||
cached_at, cached_result = cached
|
||||
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
|
||||
return cached_result
|
||||
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json().get("status") == "ok"
|
||||
result = response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
return False
|
||||
result = False
|
||||
|
||||
_health_cache[self.base_url] = (time.monotonic(), result)
|
||||
return result
|
||||
|
||||
def execute(
|
||||
self,
|
||||
@@ -157,8 +192,11 @@ class CodeInterpreterClient:
|
||||
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def _parse_sse(
|
||||
self, response: requests.Response
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import LlmPythonExecutionResult
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
@@ -107,7 +108,11 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
return server.server_enabled
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
with CodeInterpreterClient() as client:
|
||||
return client.health(use_cache=True)
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
@@ -171,194 +176,203 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
# Create Code Interpreter client
|
||||
client = CodeInterpreterClient()
|
||||
# Create Code Interpreter client — context manager ensures
|
||||
# session.close() is called on every exit path.
|
||||
with CodeInterpreterClient() as client:
|
||||
# Stage chat files for execution
|
||||
files_to_stage: list[FileInput] = []
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
|
||||
# Stage chat files for execution
|
||||
files_to_stage: list[FileInput] = []
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
logger.debug(f"Executing code: {code}")
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
for event in client.execute_streaming(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=(
|
||||
event.data if event.stream == "stdout" else ""
|
||||
),
|
||||
stderr=(
|
||||
event.data if event.stream == "stderr" else ""
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
|
||||
for event in client.execute_streaming(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file "
|
||||
f"{workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (generated files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated "
|
||||
f"file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged "
|
||||
f"file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=event.data if event.stream == "stdout" else "",
|
||||
stderr=event.data if event.stream == "stderr" else "",
|
||||
),
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=(None if result_event.exit_code == 0 else truncated_stderr),
|
||||
)
|
||||
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
# Serialize result for LLM
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
return ToolResponse(
|
||||
rich_response=PythonToolRichResponse(
|
||||
generated_files=generated_files,
|
||||
),
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file {workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (generated files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
# Emit error delta
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=None if result_event.exit_code == 0 else truncated_stderr,
|
||||
)
|
||||
|
||||
# Serialize result for LLM
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=None, # No rich response needed for Python tool
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
|
||||
# Emit error delta
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
# Return error result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
)
|
||||
|
||||
# Return error result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
@@ -57,6 +56,30 @@ def _sanitize_query(query: str) -> str:
|
||||
return " ".join(sanitized.split())
|
||||
|
||||
|
||||
def _normalize_queries_input(raw: Any) -> list[str]:
|
||||
"""Coerce LLM output to a list of sanitized query strings.
|
||||
|
||||
Accepts a bare string or a list (possibly with non-string elements).
|
||||
Sanitizes each query (strip control chars, normalize whitespace) and
|
||||
drops empty or whitespace-only entries.
|
||||
"""
|
||||
if isinstance(raw, str):
|
||||
raw = raw.strip()
|
||||
if not raw:
|
||||
return []
|
||||
raw = [raw]
|
||||
elif not isinstance(raw, list):
|
||||
return []
|
||||
result: list[str] = []
|
||||
for q in raw:
|
||||
if q is None:
|
||||
continue
|
||||
sanitized = _sanitize_query(str(q))
|
||||
if sanitized:
|
||||
result.append(sanitized)
|
||||
return result
|
||||
|
||||
|
||||
class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
NAME = "web_search"
|
||||
DESCRIPTION = "Search the web for information."
|
||||
@@ -189,13 +212,7 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
f'like: {{"queries": ["your search query here"]}}'
|
||||
),
|
||||
)
|
||||
raw_queries = cast(list[str], llm_kwargs[QUERIES_FIELD])
|
||||
|
||||
# Normalize queries:
|
||||
# - remove control characters (null bytes, etc.) that LLMs sometimes produce
|
||||
# - collapse whitespace and strip
|
||||
# - drop empty/whitespace-only queries
|
||||
queries = [sanitized for q in raw_queries if (sanitized := _sanitize_query(q))]
|
||||
queries = _normalize_queries_input(llm_kwargs[QUERIES_FIELD])
|
||||
if not queries:
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
|
||||
@@ -596,7 +596,7 @@ mypy-extensions==1.0.0
|
||||
# typing-inspect
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
nltk==3.9.1
|
||||
nltk==3.9.3
|
||||
# via unstructured
|
||||
numpy==2.4.1
|
||||
# via
|
||||
|
||||
@@ -16,10 +16,6 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
|
||||
|
||||
|
||||
def run_jobs() -> None:
|
||||
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
|
||||
use_lightweight = True
|
||||
|
||||
# command setup
|
||||
cmd_worker_primary = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -74,6 +70,48 @@ def run_jobs() -> None:
|
||||
"--queues=connector_doc_fetching",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
|
||||
]
|
||||
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
|
||||
cmd_worker_user_file_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync,user_file_delete",
|
||||
]
|
||||
|
||||
cmd_beat = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -82,144 +120,31 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
|
||||
# Prepare background worker commands based on mode
|
||||
if use_lightweight:
|
||||
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
|
||||
cmd_worker_background = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration",
|
||||
]
|
||||
background_workers = [("BACKGROUND", cmd_worker_background)]
|
||||
else:
|
||||
print("Starting workers in STANDARD mode (separate background workers)")
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,sandbox",
|
||||
]
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
cmd_worker_user_file_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete",
|
||||
]
|
||||
background_workers = [
|
||||
("HEAVY", cmd_worker_heavy),
|
||||
("MONITORING", cmd_worker_monitoring),
|
||||
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
|
||||
]
|
||||
all_workers = [
|
||||
("PRIMARY", cmd_worker_primary),
|
||||
("LIGHT", cmd_worker_light),
|
||||
("DOCPROCESSING", cmd_worker_docprocessing),
|
||||
("DOCFETCHING", cmd_worker_docfetching),
|
||||
("HEAVY", cmd_worker_heavy),
|
||||
("MONITORING", cmd_worker_monitoring),
|
||||
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
|
||||
("BEAT", cmd_beat),
|
||||
]
|
||||
|
||||
# spawn processes
|
||||
worker_primary_process = subprocess.Popen(
|
||||
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_light_process = subprocess.Popen(
|
||||
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_docprocessing_process = subprocess.Popen(
|
||||
cmd_worker_docprocessing,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_docfetching_process = subprocess.Popen(
|
||||
cmd_worker_docfetching,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
beat_process = subprocess.Popen(
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
# Spawn background worker processes based on mode
|
||||
background_processes = []
|
||||
for name, cmd in background_workers:
|
||||
processes = []
|
||||
for name, cmd in all_workers:
|
||||
process = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
background_processes.append((name, process))
|
||||
processes.append((name, process))
|
||||
|
||||
# monitor threads
|
||||
worker_primary_thread = threading.Thread(
|
||||
target=monitor_process, args=("PRIMARY", worker_primary_process)
|
||||
)
|
||||
worker_light_thread = threading.Thread(
|
||||
target=monitor_process, args=("LIGHT", worker_light_process)
|
||||
)
|
||||
worker_docprocessing_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
|
||||
)
|
||||
worker_docfetching_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
# Create monitor threads for background workers
|
||||
background_threads = []
|
||||
for name, process in background_processes:
|
||||
threads = []
|
||||
for name, process in processes:
|
||||
thread = threading.Thread(target=monitor_process, args=(name, process))
|
||||
background_threads.append(thread)
|
||||
|
||||
# Start all threads
|
||||
worker_primary_thread.start()
|
||||
worker_light_thread.start()
|
||||
worker_docprocessing_thread.start()
|
||||
worker_docfetching_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
for thread in background_threads:
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads
|
||||
worker_primary_thread.join()
|
||||
worker_light_thread.join()
|
||||
worker_docprocessing_thread.join()
|
||||
worker_docfetching_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
for thread in background_threads:
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,5 @@
|
||||
#!/bin/sh
|
||||
# Entrypoint script for supervisord that sets environment variables
|
||||
# for controlling which celery workers to start
|
||||
|
||||
# Default to lightweight mode if not set
|
||||
if [ -z "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" ]; then
|
||||
export USE_LIGHTWEIGHT_BACKGROUND_WORKER="true"
|
||||
fi
|
||||
|
||||
# Set the complementary variable for supervisord
|
||||
# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax
|
||||
if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then
|
||||
export USE_SEPARATE_BACKGROUND_WORKERS="false"
|
||||
else
|
||||
export USE_SEPARATE_BACKGROUND_WORKERS="true"
|
||||
fi
|
||||
|
||||
echo "Worker mode configuration:"
|
||||
echo " USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER"
|
||||
echo " USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS"
|
||||
# Entrypoint script for supervisord
|
||||
|
||||
# Launch supervisord with environment variables available
|
||||
exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
@@ -39,7 +39,6 @@ autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
# Standard mode: Light worker for fast operations
|
||||
# NOTE: only allowing configuration here and not in the other celery workers,
|
||||
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
|
||||
# user group syncing, deletion, etc.)
|
||||
@@ -54,26 +53,7 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
# Lightweight mode: single consolidated background worker
|
||||
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default)
|
||||
# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing
|
||||
[program:celery_worker_background]
|
||||
command=celery -A onyx.background.celery.versioned_apps.background worker
|
||||
--loglevel=INFO
|
||||
--hostname=background@%%n
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,opensearch_migration
|
||||
stdout_logfile=/var/log/celery_worker_background.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s
|
||||
|
||||
# Standard mode: separate workers for different background tasks
|
||||
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
[program:celery_worker_heavy]
|
||||
command=celery -A onyx.background.celery.versioned_apps.heavy worker
|
||||
--loglevel=INFO
|
||||
@@ -85,9 +65,7 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
# Standard mode: Document processing worker
|
||||
[program:celery_worker_docprocessing]
|
||||
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
|
||||
--loglevel=INFO
|
||||
@@ -99,7 +77,6 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
[program:celery_worker_user_file_processing]
|
||||
command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker
|
||||
@@ -112,9 +89,7 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
# Standard mode: Document fetching worker
|
||||
[program:celery_worker_docfetching]
|
||||
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
|
||||
--loglevel=INFO
|
||||
@@ -126,7 +101,6 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
[program:celery_worker_monitoring]
|
||||
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
|
||||
@@ -139,7 +113,6 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
|
||||
# Job scheduler for periodic tasks
|
||||
@@ -197,7 +170,6 @@ command=tail -qF
|
||||
/var/log/celery_beat.log
|
||||
/var/log/celery_worker_primary.log
|
||||
/var/log/celery_worker_light.log
|
||||
/var/log/celery_worker_background.log
|
||||
/var/log/celery_worker_heavy.log
|
||||
/var/log/celery_worker_docprocessing.log
|
||||
/var/log/celery_worker_monitoring.log
|
||||
|
||||
@@ -11,7 +11,6 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
@@ -20,6 +19,8 @@ from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.server.manage.llm.api import (
|
||||
@@ -122,16 +123,16 @@ class TestLLMConfigurationEndpoint:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
def test_failed_llm_test_raises_http_exception(
|
||||
def test_failed_llm_test_raises_onyx_error(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str, # noqa: ARG002
|
||||
) -> None:
|
||||
"""
|
||||
Test that a failed LLM test raises an HTTPException with status 400.
|
||||
Test that a failed LLM test raises an OnyxError with VALIDATION_ERROR.
|
||||
|
||||
When test_llm returns an error message, the endpoint should raise
|
||||
an HTTPException with the error details.
|
||||
an OnyxError with the error details.
|
||||
"""
|
||||
error_message = "Invalid API key: Authentication failed"
|
||||
|
||||
@@ -143,7 +144,7 @@ class TestLLMConfigurationEndpoint:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -156,9 +157,8 @@ class TestLLMConfigurationEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Verify the exception details
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == error_message
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -536,11 +536,11 @@ class TestDefaultProviderEndpoint:
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
# Now run_test_default_provider should fail
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "No LLM Provider setup" in exc_info.value.detail
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "No LLM Provider setup" in exc_info.value.message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -581,11 +581,11 @@ class TestDefaultProviderEndpoint:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == error_message
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -16,13 +16,14 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.api import _mask_string
|
||||
from onyx.server.manage.llm.api import put_llm_provider
|
||||
@@ -100,7 +101,7 @@ class TestLLMProviderChanges:
|
||||
api_base="https://attacker.example.com",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -108,9 +109,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -236,7 +237,7 @@ class TestLLMProviderChanges:
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -244,9 +245,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -339,7 +340,7 @@ class TestLLMProviderChanges:
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -347,9 +348,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -375,7 +376,7 @@ class TestLLMProviderChanges:
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -383,9 +384,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.detail
|
||||
exc_info.value.message
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -1027,6 +1027,13 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self._capture("GET", b"")
|
||||
if self.path == "/health":
|
||||
self._respond_json(200, {"status": "ok"})
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_DELETE(self) -> None:
|
||||
self._capture("DELETE", b"")
|
||||
self.send_response(200)
|
||||
@@ -1107,6 +1114,14 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _attach_python_tool_to_default_persona(db_session: Session) -> None:
|
||||
"""Ensure the default persona (id=0) has the PythonTool attached."""
|
||||
|
||||
@@ -427,7 +427,7 @@ def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["message"]
|
||||
|
||||
# Verify provider still exists
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
@@ -673,8 +673,8 @@ def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
headers=admin_user.headers,
|
||||
json=base_payload,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "already exists" in response.json()["detail"]
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["message"]
|
||||
|
||||
|
||||
def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
@@ -711,7 +711,7 @@ def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "not currently supported" in response.json()["detail"]
|
||||
assert "not currently supported" in response.json()["message"]
|
||||
|
||||
# Verify no duplicate was created — only the original provider should exist
|
||||
provider = _get_provider_by_id(admin_user, provider_id)
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_unauthorized_persona_access_returns_403(
|
||||
|
||||
# Should return 403 Forbidden
|
||||
assert response.status_code == 403
|
||||
assert "don't have access to this assistant" in response.json()["detail"]
|
||||
assert "don't have access to this assistant" in response.json()["message"]
|
||||
|
||||
|
||||
def test_authorized_persona_access_returns_filtered_providers(
|
||||
@@ -245,4 +245,4 @@ def test_nonexistent_persona_returns_404(
|
||||
|
||||
# Should return 404
|
||||
assert response.status_code == 404
|
||||
assert "Persona not found" in response.json()["detail"]
|
||||
assert "Persona not found" in response.json()["message"]
|
||||
|
||||
@@ -42,6 +42,78 @@ class NightlyProviderConfig(BaseModel):
|
||||
strict: bool
|
||||
|
||||
|
||||
def _stringify_custom_config_value(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _looks_like_vertex_credentials_payload(
|
||||
raw_custom_config: dict[object, object],
|
||||
) -> bool:
|
||||
normalized_keys = {str(key).strip().lower() for key in raw_custom_config}
|
||||
provider_specific_keys = {
|
||||
"vertex_credentials",
|
||||
"credentials_file",
|
||||
"vertex_credentials_file",
|
||||
"google_application_credentials",
|
||||
"vertex_location",
|
||||
"location",
|
||||
"vertex_region",
|
||||
"region",
|
||||
}
|
||||
if normalized_keys & provider_specific_keys:
|
||||
return False
|
||||
|
||||
normalized_type = str(raw_custom_config.get("type", "")).strip().lower()
|
||||
if normalized_type not in {"service_account", "external_account"}:
|
||||
return False
|
||||
|
||||
# Service account JSON usually includes private_key/client_email, while external
|
||||
# account JSON includes credential_source. Either shape should be accepted.
|
||||
has_service_account_markers = any(
|
||||
key in normalized_keys for key in {"private_key", "client_email"}
|
||||
)
|
||||
has_external_account_markers = "credential_source" in normalized_keys
|
||||
return has_service_account_markers or has_external_account_markers
|
||||
|
||||
|
||||
def _normalize_custom_config(
|
||||
provider: str, raw_custom_config: dict[object, object]
|
||||
) -> dict[str, str]:
|
||||
if provider == "vertex_ai" and _looks_like_vertex_credentials_payload(
|
||||
raw_custom_config
|
||||
):
|
||||
return {"vertex_credentials": json.dumps(raw_custom_config)}
|
||||
|
||||
normalized: dict[str, str] = {}
|
||||
for raw_key, raw_value in raw_custom_config.items():
|
||||
key = str(raw_key).strip()
|
||||
key_lower = key.lower()
|
||||
|
||||
if provider == "vertex_ai":
|
||||
if key_lower in {
|
||||
"vertex_credentials",
|
||||
"credentials_file",
|
||||
"vertex_credentials_file",
|
||||
"google_application_credentials",
|
||||
}:
|
||||
key = "vertex_credentials"
|
||||
elif key_lower in {
|
||||
"vertex_location",
|
||||
"location",
|
||||
"vertex_region",
|
||||
"region",
|
||||
}:
|
||||
key = "vertex_location"
|
||||
|
||||
normalized[key] = _stringify_custom_config_value(raw_value)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
value = os.environ.get(env_var)
|
||||
if value is None:
|
||||
@@ -80,7 +152,9 @@ def _load_provider_config() -> NightlyProviderConfig:
|
||||
parsed = json.loads(custom_config_json)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
|
||||
custom_config = {str(key): str(value) for key, value in parsed.items()}
|
||||
custom_config = _normalize_custom_config(
|
||||
provider=provider, raw_custom_config=parsed
|
||||
)
|
||||
|
||||
if provider == "ollama_chat" and api_key and not custom_config:
|
||||
custom_config = {"OLLAMA_API_KEY": api_key}
|
||||
@@ -148,6 +222,23 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
),
|
||||
)
|
||||
|
||||
if config.provider == "vertex_ai":
|
||||
has_vertex_credentials = bool(
|
||||
config.custom_config and config.custom_config.get("vertex_credentials")
|
||||
)
|
||||
if not has_vertex_credentials:
|
||||
configured_keys = (
|
||||
sorted(config.custom_config.keys()) if config.custom_config else []
|
||||
)
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_CUSTOM_CONFIG_JSON} must include 'vertex_credentials' "
|
||||
f"for provider '{config.provider}'. "
|
||||
f"Found keys: {configured_keys}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
@@ -193,6 +284,7 @@ def _create_provider_payload(
|
||||
return {
|
||||
"name": provider_name,
|
||||
"provider": provider,
|
||||
"model": model_name,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
@@ -208,24 +300,23 @@ def _create_provider_payload(
|
||||
}
|
||||
|
||||
|
||||
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
|
||||
def _ensure_provider_is_default(
|
||||
provider_id: int, model_name: str, admin_user: DATestUser
|
||||
) -> None:
|
||||
list_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
list_response.raise_for_status()
|
||||
providers = list_response.json()
|
||||
|
||||
current_default = next(
|
||||
(provider for provider in providers if provider.get("is_default_provider")),
|
||||
None,
|
||||
default_text = list_response.json().get("default_text")
|
||||
assert default_text is not None, "Expected a default provider after setting default"
|
||||
assert default_text.get("provider_id") == provider_id, (
|
||||
f"Expected provider {provider_id} to be default, "
|
||||
f"found {default_text.get('provider_id')}"
|
||||
)
|
||||
assert (
|
||||
current_default is not None
|
||||
), "Expected a default provider after setting provider as default"
|
||||
assert (
|
||||
current_default["id"] == provider_id
|
||||
), f"Expected provider {provider_id} to be default, found {current_default['id']}"
|
||||
default_text.get("model_name") == model_name
|
||||
), f"Expected default model {model_name}, found {default_text.get('model_name')}"
|
||||
|
||||
|
||||
def _run_chat_assertions(
|
||||
@@ -326,8 +417,9 @@ def _create_and_test_provider_for_model(
|
||||
|
||||
try:
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={"provider_id": provider_id, "model_name": model_name},
|
||||
)
|
||||
assert set_default_response.status_code == 200, (
|
||||
f"Setting default provider failed for provider={config.provider} "
|
||||
@@ -335,7 +427,9 @@ def _create_and_test_provider_for_model(
|
||||
f"{set_default_response.text}"
|
||||
)
|
||||
|
||||
_ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
|
||||
_ensure_provider_is_default(
|
||||
provider_id=provider_id, model_name=model_name, admin_user=admin_user
|
||||
)
|
||||
_run_chat_assertions(
|
||||
admin_user=admin_user,
|
||||
search_tool_id=search_tool_id,
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
def _upload_connector_file(
|
||||
*,
|
||||
user_performing_action: DATestUser,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
) -> tuple[str, str]:
|
||||
headers = user_performing_action.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
|
||||
files=[("files", (file_name, io.BytesIO(content), "text/plain"))],
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return payload["file_paths"][0], payload["file_names"][0]
|
||||
|
||||
|
||||
def _update_connector_files(
|
||||
*,
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
file_ids_to_remove: list[str],
|
||||
new_file_name: str,
|
||||
new_file_content: bytes,
|
||||
) -> requests.Response:
|
||||
headers = user_performing_action.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files/update",
|
||||
data={"file_ids_to_remove": json.dumps(file_ids_to_remove)},
|
||||
files=[("files", (new_file_name, io.BytesIO(new_file_content), "text/plain"))],
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
def _list_connector_files(
|
||||
*,
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
@pytest.mark.usefixtures("reset")
|
||||
def test_only_global_curator_can_update_public_file_connector_files() -> None:
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
global_curator_creator = UserManager.create(name="global_curator_creator")
|
||||
global_curator_creator = UserManager.set_role(
|
||||
user_to_set=global_curator_creator,
|
||||
target_role=UserRole.GLOBAL_CURATOR,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
global_curator_editor = UserManager.create(name="global_curator_editor")
|
||||
global_curator_editor = UserManager.set_role(
|
||||
user_to_set=global_curator_editor,
|
||||
target_role=UserRole.GLOBAL_CURATOR,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
curator_user = UserManager.create(name="curator_user")
|
||||
curator_group = UserGroupManager.create(
|
||||
name="curator_group",
|
||||
user_ids=[curator_user.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[curator_group],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.set_curator_status(
|
||||
test_user_group=curator_group,
|
||||
user_to_set_as_curator=curator_user,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
initial_file_id, initial_file_name = _upload_connector_file(
|
||||
user_performing_action=global_curator_creator,
|
||||
file_name="initial-file.txt",
|
||||
content=b"initial file content",
|
||||
)
|
||||
|
||||
connector = ConnectorManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
name="public_file_connector",
|
||||
source=DocumentSource.FILE,
|
||||
connector_specific_config={
|
||||
"file_locations": [initial_file_id],
|
||||
"file_names": [initial_file_name],
|
||||
"zip_metadata_file_id": None,
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name="public_file_connector_credential",
|
||||
)
|
||||
CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
user_performing_action=global_curator_creator,
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
name="public_file_connector_cc_pair",
|
||||
)
|
||||
|
||||
curator_list_response = _list_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=curator_user,
|
||||
)
|
||||
curator_list_response.raise_for_status()
|
||||
curator_list_payload = curator_list_response.json()
|
||||
assert any(f["file_id"] == initial_file_id for f in curator_list_payload["files"])
|
||||
|
||||
global_curator_list_response = _list_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
)
|
||||
global_curator_list_response.raise_for_status()
|
||||
global_curator_list_payload = global_curator_list_response.json()
|
||||
assert any(
|
||||
f["file_id"] == initial_file_id for f in global_curator_list_payload["files"]
|
||||
)
|
||||
|
||||
denied_response = _update_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=curator_user,
|
||||
file_ids_to_remove=[initial_file_id],
|
||||
new_file_name="curator-file.txt",
|
||||
new_file_content=b"curator updated file",
|
||||
)
|
||||
assert denied_response.status_code == 403
|
||||
|
||||
allowed_response = _update_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
file_ids_to_remove=[initial_file_id],
|
||||
new_file_name="global-curator-file.txt",
|
||||
new_file_content=b"global curator updated file",
|
||||
)
|
||||
allowed_response.raise_for_status()
|
||||
|
||||
payload = allowed_response.json()
|
||||
assert initial_file_id not in payload["file_paths"]
|
||||
assert "global-curator-file.txt" in payload["file_names"]
|
||||
|
||||
creator_group = UserGroupManager.create(
|
||||
name="creator_group",
|
||||
user_ids=[global_curator_creator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[creator_group],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
private_file_id, private_file_name = _upload_connector_file(
|
||||
user_performing_action=global_curator_creator,
|
||||
file_name="private-initial-file.txt",
|
||||
content=b"private initial file content",
|
||||
)
|
||||
|
||||
private_connector = ConnectorManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
name="private_file_connector",
|
||||
source=DocumentSource.FILE,
|
||||
connector_specific_config={
|
||||
"file_locations": [private_file_id],
|
||||
"file_names": [private_file_name],
|
||||
"zip_metadata_file_id": None,
|
||||
},
|
||||
access_type=AccessType.PRIVATE,
|
||||
groups=[creator_group.id],
|
||||
)
|
||||
private_credential = CredentialManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=False,
|
||||
groups=[creator_group.id],
|
||||
name="private_file_connector_credential",
|
||||
)
|
||||
CCPairManager.create(
|
||||
connector_id=private_connector.id,
|
||||
credential_id=private_credential.id,
|
||||
user_performing_action=global_curator_creator,
|
||||
access_type=AccessType.PRIVATE,
|
||||
groups=[creator_group.id],
|
||||
name="private_file_connector_cc_pair",
|
||||
)
|
||||
|
||||
private_denied_response = _update_connector_files(
|
||||
connector_id=private_connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
file_ids_to_remove=[private_file_id],
|
||||
new_file_name="global-curator-private-file.txt",
|
||||
new_file_content=b"global curator private update",
|
||||
)
|
||||
assert private_denied_response.status_code == 403
|
||||
@@ -11,7 +11,8 @@ from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
class TestCreateCheckoutSession:
|
||||
@@ -88,22 +89,25 @@ class TestCreateCheckoutSession:
|
||||
mock_get_tenant: MagicMock,
|
||||
mock_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Should raise HTTPException when service fails."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
"""Should propagate OnyxError when service fails."""
|
||||
from ee.onyx.server.billing.api import create_checkout_session
|
||||
|
||||
mock_get_license.return_value = None
|
||||
mock_get_tenant.return_value = "tenant_123"
|
||||
mock_service.side_effect = BillingServiceError("Stripe error", 502)
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Stripe error",
|
||||
status_code_override=502,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await create_checkout_session(
|
||||
request=None, _=MagicMock(), db_session=MagicMock()
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert "Stripe error" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Stripe error"
|
||||
|
||||
|
||||
class TestCreateCustomerPortalSession:
|
||||
@@ -121,20 +125,19 @@ class TestCreateCustomerPortalSession:
|
||||
mock_service: AsyncMock, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Should reject self-hosted without license."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import create_customer_portal_session
|
||||
|
||||
mock_get_license.return_value = None
|
||||
mock_get_tenant.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await create_customer_portal_session(
|
||||
request=None, _=MagicMock(), db_session=MagicMock()
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "No license found" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.create_portal_service")
|
||||
@@ -227,8 +230,6 @@ class TestUpdateSeats:
|
||||
mock_get_tenant: MagicMock,
|
||||
) -> None:
|
||||
"""Should reject self-hosted without license."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import update_seats
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
|
||||
@@ -237,11 +238,12 @@ class TestUpdateSeats:
|
||||
|
||||
request = SeatUpdateRequest(new_seat_count=10)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "No license found" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.get_used_seats")
|
||||
@@ -295,26 +297,27 @@ class TestUpdateSeats:
|
||||
mock_service: AsyncMock,
|
||||
mock_get_used_seats: MagicMock,
|
||||
) -> None:
|
||||
"""Should convert BillingServiceError to HTTPException."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
"""Should propagate OnyxError from service layer."""
|
||||
from ee.onyx.server.billing.api import update_seats
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_get_used_seats.return_value = 0
|
||||
mock_service.side_effect = BillingServiceError(
|
||||
"Cannot reduce below 10 seats", 400
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Cannot reduce below 10 seats",
|
||||
status_code_override=400,
|
||||
)
|
||||
|
||||
request = SeatUpdateRequest(new_seat_count=5)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Cannot reduce below 10 seats" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Cannot reduce below 10 seats"
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@@ -332,19 +335,18 @@ class TestCircuitBreaker:
|
||||
mock_circuit_open: MagicMock,
|
||||
) -> None:
|
||||
"""Should return 503 when circuit breaker is open."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open.return_value = True
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "Connect to Stripe" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.SERVICE_UNAVAILABLE
|
||||
assert "Connect to Stripe" in exc_info.value.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.MULTI_TENANT", False)
|
||||
@@ -362,16 +364,18 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should open circuit breaker on 502 error."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = BillingServiceError("Connection failed", 502)
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Connection failed",
|
||||
status_code_override=502,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
@@ -393,16 +397,18 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should open circuit breaker on 503 error."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = BillingServiceError("Service unavailable", 503)
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Service unavailable",
|
||||
status_code_override=503,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
@@ -424,16 +430,18 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should open circuit breaker on 504 error."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = BillingServiceError("Gateway timeout", 504)
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Gateway timeout",
|
||||
status_code_override=504,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 504
|
||||
@@ -455,16 +463,18 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should NOT open circuit breaker on 400 error (client error)."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = BillingServiceError("Bad request", 400)
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Bad request",
|
||||
status_code_override=400,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
@@ -14,7 +14,8 @@ from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
class TestMakeBillingRequest:
|
||||
@@ -78,7 +79,7 @@ class TestMakeBillingRequest:
|
||||
mock_base_url: MagicMock,
|
||||
mock_headers: MagicMock,
|
||||
) -> None:
|
||||
"""Should raise BillingServiceError on HTTP error."""
|
||||
"""Should raise OnyxError on HTTP error."""
|
||||
from ee.onyx.server.billing.service import _make_billing_request
|
||||
|
||||
mock_base_url.return_value = "https://api.example.com"
|
||||
@@ -91,7 +92,7 @@ class TestMakeBillingRequest:
|
||||
mock_client = make_mock_http_client("post", side_effect=error)
|
||||
|
||||
with patch("httpx.AsyncClient", mock_client):
|
||||
with pytest.raises(BillingServiceError) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await _make_billing_request(
|
||||
method="POST",
|
||||
path="/test",
|
||||
@@ -99,6 +100,7 @@ class TestMakeBillingRequest:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Bad request" in exc_info.value.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -136,7 +138,7 @@ class TestMakeBillingRequest:
|
||||
mock_base_url: MagicMock,
|
||||
mock_headers: MagicMock,
|
||||
) -> None:
|
||||
"""Should raise BillingServiceError on connection error."""
|
||||
"""Should raise OnyxError on connection error."""
|
||||
from ee.onyx.server.billing.service import _make_billing_request
|
||||
|
||||
mock_base_url.return_value = "https://api.example.com"
|
||||
@@ -145,10 +147,11 @@ class TestMakeBillingRequest:
|
||||
mock_client = make_mock_http_client("post", side_effect=error)
|
||||
|
||||
with patch("httpx.AsyncClient", mock_client):
|
||||
with pytest.raises(BillingServiceError) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await _make_billing_request(method="POST", path="/test")
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Failed to connect" in exc_info.value.message
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ from unittest.mock import patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
class TestGetStripePublishableKey:
|
||||
"""Tests for get_stripe_publishable_key endpoint."""
|
||||
@@ -62,15 +65,14 @@ class TestGetStripePublishableKey:
|
||||
)
|
||||
async def test_rejects_invalid_env_var_key_format(self) -> None:
|
||||
"""Should reject keys that don't start with pk_."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -80,8 +82,6 @@ class TestGetStripePublishableKey:
|
||||
)
|
||||
async def test_rejects_invalid_s3_key_format(self) -> None:
|
||||
"""Should reject keys from S3 that don't start with pk_."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
mock_response = MagicMock()
|
||||
@@ -92,11 +92,12 @@ class TestGetStripePublishableKey:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -106,34 +107,32 @@ class TestGetStripePublishableKey:
|
||||
)
|
||||
async def test_handles_s3_fetch_error(self) -> None:
|
||||
"""Should return error when S3 fetch fails."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
side_effect=httpx.HTTPError("Connection failed")
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Failed to fetch Stripe publishable key" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Failed to fetch Stripe publishable key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL", None)
|
||||
async def test_error_when_no_config(self) -> None:
|
||||
"""Should return error when neither env var nor S3 URL is configured."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "not configured" in exc_info.value.detail
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert "not configured" in exc_info.value.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
|
||||
178
backend/tests/unit/onyx/chat/test_save_chat_files.py
Normal file
178
backend/tests/unit/onyx/chat/test_save_chat_files.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Tests for _extract_referenced_file_descriptors in save_chat.py.
|
||||
|
||||
Verifies that only code interpreter generated files actually referenced
|
||||
in the assistant's message text are extracted as FileDescriptors for
|
||||
cross-turn persistence.
|
||||
"""
|
||||
|
||||
from onyx.chat.save_chat import _extract_referenced_file_descriptors
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
|
||||
|
||||
def _make_tool_call_info(
|
||||
generated_files: list[PythonExecutionFile] | None = None,
|
||||
tool_name: str = "python",
|
||||
) -> ToolCallInfo:
|
||||
return ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=0,
|
||||
tab_index=0,
|
||||
tool_name=tool_name,
|
||||
tool_call_id="tc_1",
|
||||
tool_id=1,
|
||||
reasoning_tokens=None,
|
||||
tool_call_arguments={"code": "print('hi')"},
|
||||
tool_call_response="{}",
|
||||
generated_files=generated_files,
|
||||
)
|
||||
|
||||
|
||||
def test_returns_empty_when_no_generated_files() -> None:
|
||||
tool_call = _make_tool_call_info(generated_files=None)
|
||||
result = _extract_referenced_file_descriptors([tool_call], "some message")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_returns_empty_when_file_not_referenced() -> None:
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="chart.png",
|
||||
file_link="http://localhost/api/chat/file/abc-123",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
result = _extract_referenced_file_descriptors([tool_call], "Here is your answer.")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extracts_referenced_file() -> None:
|
||||
file_id = "abc-123-def"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="chart.png",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = (
|
||||
f"Here is the chart: [chart.png](http://localhost/api/chat/file/{file_id})"
|
||||
)
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == file_id
|
||||
assert result[0]["type"] == ChatFileType.IMAGE
|
||||
assert result[0]["name"] == "chart.png"
|
||||
|
||||
|
||||
def test_filters_unreferenced_files() -> None:
|
||||
referenced_id = "ref-111"
|
||||
unreferenced_id = "unref-222"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="chart.png",
|
||||
file_link=f"http://localhost/api/chat/file/{referenced_id}",
|
||||
),
|
||||
PythonExecutionFile(
|
||||
filename="data.csv",
|
||||
file_link=f"http://localhost/api/chat/file/{unreferenced_id}",
|
||||
),
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = f"Here is the chart: [chart.png](http://localhost/api/chat/file/{referenced_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == referenced_id
|
||||
assert result[0]["name"] == "chart.png"
|
||||
|
||||
|
||||
def test_extracts_from_multiple_tool_calls() -> None:
|
||||
id_1 = "file-aaa"
|
||||
id_2 = "file-bbb"
|
||||
tc1 = _make_tool_call_info(
|
||||
generated_files=[
|
||||
PythonExecutionFile(
|
||||
filename="plot.png",
|
||||
file_link=f"http://localhost/api/chat/file/{id_1}",
|
||||
)
|
||||
]
|
||||
)
|
||||
tc2 = _make_tool_call_info(
|
||||
generated_files=[
|
||||
PythonExecutionFile(
|
||||
filename="report.csv",
|
||||
file_link=f"http://localhost/api/chat/file/{id_2}",
|
||||
)
|
||||
]
|
||||
)
|
||||
message = (
|
||||
f"[plot.png](http://localhost/api/chat/file/{id_1}) "
|
||||
f"and [report.csv](http://localhost/api/chat/file/{id_2})"
|
||||
)
|
||||
|
||||
result = _extract_referenced_file_descriptors([tc1, tc2], message)
|
||||
|
||||
assert len(result) == 2
|
||||
ids = {d["id"] for d in result}
|
||||
assert ids == {id_1, id_2}
|
||||
|
||||
|
||||
def test_csv_file_type() -> None:
|
||||
file_id = "csv-123"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="data.csv",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = f"[data.csv](http://localhost/api/chat/file/{file_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == ChatFileType.CSV
|
||||
|
||||
|
||||
def test_unknown_extension_defaults_to_plain_text() -> None:
|
||||
file_id = "bin-456"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="output.xyz",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = f"[output.xyz](http://localhost/api/chat/file/{file_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == ChatFileType.PLAIN_TEXT
|
||||
|
||||
|
||||
def test_skips_tool_calls_without_generated_files() -> None:
|
||||
file_id = "img-789"
|
||||
tc_no_files = _make_tool_call_info(generated_files=None)
|
||||
tc_empty = _make_tool_call_info(generated_files=[])
|
||||
tc_with_files = _make_tool_call_info(
|
||||
generated_files=[
|
||||
PythonExecutionFile(
|
||||
filename="result.png",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
)
|
||||
message = f"[result.png](http://localhost/api/chat/file/{file_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors(
|
||||
[tc_no_files, tc_empty, tc_with_files], message
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == file_id
|
||||
0
backend/tests/unit/onyx/error_handling/__init__.py
Normal file
0
backend/tests/unit/onyx/error_handling/__init__.py
Normal file
90
backend/tests/unit/onyx/error_handling/test_exceptions.py
Normal file
90
backend/tests/unit/onyx/error_handling/test_exceptions.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Tests for OnyxError and the global exception handler."""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
|
||||
|
||||
class TestOnyxError:
|
||||
"""Unit tests for OnyxError construction and properties."""
|
||||
|
||||
def test_basic_construction(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
assert err.error_code is OnyxErrorCode.NOT_FOUND
|
||||
assert err.message == "Session not found"
|
||||
assert err.status_code == 404
|
||||
|
||||
def test_message_defaults_to_code(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
assert err.message == "UNAUTHENTICATED"
|
||||
assert str(err) == "UNAUTHENTICATED"
|
||||
|
||||
def test_status_code_override(self) -> None:
|
||||
err = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"upstream failed",
|
||||
status_code_override=503,
|
||||
)
|
||||
assert err.status_code == 503
|
||||
# error_code still reports its own default
|
||||
assert err.error_code.status_code == 502
|
||||
|
||||
def test_no_override_uses_error_code_status(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.RATE_LIMITED, "slow down")
|
||||
assert err.status_code == 429
|
||||
|
||||
def test_is_exception(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.INTERNAL_ERROR)
|
||||
assert isinstance(err, Exception)
|
||||
|
||||
|
||||
class TestExceptionHandler:
|
||||
"""Integration test: OnyxError → JSON response via FastAPI TestClient."""
|
||||
|
||||
@pytest.fixture()
|
||||
def client(self) -> TestClient:
|
||||
app = FastAPI()
|
||||
register_onyx_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
def _boom() -> None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Thing not found")
|
||||
|
||||
@app.get("/boom-override")
|
||||
def _boom_override() -> None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"upstream 503",
|
||||
status_code_override=503,
|
||||
)
|
||||
|
||||
@app.get("/boom-default-msg")
|
||||
def _boom_default() -> None:
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def test_returns_correct_status_and_body(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom")
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "NOT_FOUND"
|
||||
assert body["message"] == "Thing not found"
|
||||
|
||||
def test_status_code_override_in_response(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-override")
|
||||
assert resp.status_code == 503
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "BAD_GATEWAY"
|
||||
assert body["message"] == "upstream 503"
|
||||
|
||||
def test_default_message(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-default-msg")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "UNAUTHENTICATED"
|
||||
assert body["message"] == "UNAUTHENTICATED"
|
||||
@@ -104,3 +104,102 @@ def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
|
||||
|
||||
# -- Table rendering tests --
|
||||
|
||||
|
||||
def test_table_renders_as_vertical_cards() -> None:
|
||||
message = (
|
||||
"| Feature | Status | Owner |\n"
|
||||
"|---------|--------|-------|\n"
|
||||
"| Auth | Done | Alice |\n"
|
||||
"| Search | In Progress | Bob |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
|
||||
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
|
||||
# Cards separated by blank line
|
||||
assert "Owner: Alice\n\n*Search*" in formatted
|
||||
# No raw pipe-and-dash table syntax
|
||||
assert "---|" not in formatted
|
||||
|
||||
|
||||
def test_table_single_column() -> None:
|
||||
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Alice*" in formatted
|
||||
assert "*Bob*" in formatted
|
||||
|
||||
|
||||
def test_table_embedded_in_text() -> None:
|
||||
message = (
|
||||
"Here are the results:\n\n"
|
||||
"| Item | Count |\n"
|
||||
"|------|-------|\n"
|
||||
"| Apples | 5 |\n"
|
||||
"\n"
|
||||
"That's all."
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Here are the results:" in formatted
|
||||
assert "*Apples*\n • Count: 5" in formatted
|
||||
assert "That's all." in formatted
|
||||
|
||||
|
||||
def test_table_with_formatted_cells() -> None:
|
||||
message = (
|
||||
"| Name | Link |\n"
|
||||
"|------|------|\n"
|
||||
"| **Alice** | [profile](https://example.com) |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Bold cell should not double-wrap: *Alice* not **Alice**
|
||||
assert "*Alice*" in formatted
|
||||
assert "**Alice**" not in formatted
|
||||
assert "<https://example.com|profile>" in formatted
|
||||
|
||||
|
||||
def test_table_with_alignment_specifiers() -> None:
|
||||
message = (
|
||||
"| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*a*\n • Center: b\n • Right: c" in formatted
|
||||
|
||||
|
||||
def test_two_tables_in_same_message_use_independent_headers() -> None:
|
||||
message = (
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"| X | Y | Z |\n"
|
||||
"|---|---|---|\n"
|
||||
"| p | q | r |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*1*\n • B: 2" in formatted
|
||||
assert "*p*\n • Y: q\n • Z: r" in formatted
|
||||
|
||||
|
||||
def test_table_empty_first_column_no_bare_asterisks() -> None:
|
||||
message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Empty title should not produce "**" (bare asterisks)
|
||||
assert "**" not in formatted
|
||||
assert " • Status: Done" in formatted
|
||||
|
||||
@@ -1,25 +1,37 @@
|
||||
"""Tests for PythonTool availability based on server_enabled flag.
|
||||
"""Tests for PythonTool availability based on server_enabled flag and health check.
|
||||
|
||||
Verifies that PythonTool reports itself as unavailable when either:
|
||||
- CODE_INTERPRETER_BASE_URL is not set, or
|
||||
- CodeInterpreterServer.server_enabled is False in the database.
|
||||
- CodeInterpreterServer.server_enabled is False in the database, or
|
||||
- The Code Interpreter service health check fails.
|
||||
|
||||
Also verifies that the health check result is cached with a TTL.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
|
||||
CLIENT_MODULE = "onyx.tools.tool_implementations.python.code_interpreter_client"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
None,
|
||||
)
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", None)
|
||||
def test_python_tool_unavailable_without_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -27,10 +39,7 @@ def test_python_tool_unavailable_without_base_url() -> None:
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"",
|
||||
)
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "")
|
||||
def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -43,13 +52,8 @@ def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
@@ -64,18 +68,15 @@ def test_python_tool_unavailable_when_server_disabled(
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Available when both conditions are met
|
||||
# Health check determines availability when URL + server are OK
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_available_when_health_check_passes(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
@@ -84,5 +85,122 @@ def test_python_tool_available_when_server_enabled(
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = True
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_unavailable_when_health_check_fails(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = False
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check is NOT reached when preconditions fail
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_health_check_not_called_when_server_disabled(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = False
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client_cls.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check caching (tested at the client level)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_health_check_cached_on_second_call() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health(use_cache=True) is True
|
||||
assert client.health(use_cache=True) is True
|
||||
# Only one HTTP call — the second used the cache
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
@patch(f"{CLIENT_MODULE}.time")
|
||||
def test_health_check_refreshed_after_ttl_expires(mock_time: MagicMock) -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
_HEALTH_CACHE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
# First call at t=0 — cache miss
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Second call within TTL — cache hit
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS - 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Third call after TTL — cache miss, fresh request
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS + 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
def test_health_check_no_cache_by_default() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health() is True
|
||||
assert client.health() is True
|
||||
# Both calls hit the network when use_cache=False (default)
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
_normalize_queries_input,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
|
||||
|
||||
def _make_result(
|
||||
title: str = "Title", link: str = "https://example.com"
|
||||
) -> WebSearchResult:
|
||||
return WebSearchResult(title=title, link=link, snippet="snippet")
|
||||
|
||||
|
||||
def _make_tool(mock_provider: Any) -> WebSearchTool:
|
||||
"""Instantiate WebSearchTool with all DB/provider deps mocked out."""
|
||||
provider_model = MagicMock()
|
||||
provider_model.provider_type = "brave"
|
||||
provider_model.api_key = MagicMock()
|
||||
provider_model.api_key.get_value.return_value = "fake-key"
|
||||
provider_model.config = {}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.get_session_with_current_tenant"
|
||||
) as mock_session_ctx,
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.fetch_active_web_search_provider",
|
||||
return_value=provider_model,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.build_search_provider_from_config",
|
||||
return_value=mock_provider,
|
||||
),
|
||||
):
|
||||
mock_session_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
tool = WebSearchTool(tool_id=1, emitter=MagicMock())
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
def _run(tool: WebSearchTool, queries: Any) -> list[str]:
|
||||
"""Call tool.run() and return the list of query strings passed to provider.search."""
|
||||
placement = Placement(turn_index=0, tab_index=0)
|
||||
override_kwargs = WebSearchToolOverrideKwargs(starting_citation_num=1)
|
||||
tool.run(placement=placement, override_kwargs=override_kwargs, queries=queries)
|
||||
search_mock = cast(MagicMock, tool._provider.search) # noqa: SLF001
|
||||
return [call.args[0] for call in search_mock.call_args_list]
|
||||
|
||||
|
||||
class TestNormalizeQueriesInput:
|
||||
"""Unit tests for _normalize_queries_input (coercion + sanitization)."""
|
||||
|
||||
def test_bare_string_returns_single_element_list(self) -> None:
|
||||
assert _normalize_queries_input("hello") == ["hello"]
|
||||
|
||||
def test_bare_string_stripped_and_sanitized(self) -> None:
|
||||
assert _normalize_queries_input(" hello ") == ["hello"]
|
||||
# Control chars (e.g. null) removed; no space inserted
|
||||
assert _normalize_queries_input("hello\x00world") == ["helloworld"]
|
||||
|
||||
def test_empty_string_returns_empty_list(self) -> None:
|
||||
assert _normalize_queries_input("") == []
|
||||
assert _normalize_queries_input(" ") == []
|
||||
|
||||
def test_list_of_strings_returned_sanitized(self) -> None:
|
||||
assert _normalize_queries_input(["a", "b"]) == ["a", "b"]
|
||||
# Leading/trailing space stripped; control chars (e.g. tab) removed
|
||||
assert _normalize_queries_input([" a ", "b\tb"]) == ["a", "bb"]
|
||||
|
||||
def test_list_none_skipped(self) -> None:
|
||||
assert _normalize_queries_input(["a", None, "b"]) == ["a", "b"]
|
||||
|
||||
def test_list_non_string_coerced(self) -> None:
|
||||
assert _normalize_queries_input([1, "two"]) == ["1", "two"]
|
||||
|
||||
def test_list_whitespace_only_dropped(self) -> None:
|
||||
assert _normalize_queries_input(["a", "", " ", "b"]) == ["a", "b"]
|
||||
|
||||
def test_non_list_non_string_returns_empty_list(self) -> None:
|
||||
assert _normalize_queries_input(42) == []
|
||||
assert _normalize_queries_input({}) == []
|
||||
|
||||
|
||||
class TestWebSearchToolRunQueryCoercion:
|
||||
def test_list_of_strings_dispatches_each_query(self) -> None:
|
||||
"""Normal case: list of queries → one search call per query."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, ["python decorators", "python generators"])
|
||||
|
||||
# run_functions_tuples_in_parallel uses a thread pool; call_args_list order is non-deterministic.
|
||||
assert sorted(dispatched) == ["python decorators", "python generators"]
|
||||
|
||||
def test_bare_string_dispatches_as_single_query(self) -> None:
|
||||
"""LLM returns a bare string instead of an array — must NOT be split char-by-char."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, "what is the capital of France")
|
||||
|
||||
assert len(dispatched) == 1
|
||||
assert dispatched[0] == "what is the capital of France"
|
||||
|
||||
def test_bare_string_does_not_search_individual_characters(self) -> None:
|
||||
"""Regression: single-char searches must not occur."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, "hi")
|
||||
for query_arg in dispatched:
|
||||
assert (
|
||||
len(query_arg) > 1
|
||||
), f"Single-character query dispatched: {query_arg!r}"
|
||||
|
||||
def test_control_characters_sanitized_before_dispatch(self) -> None:
|
||||
"""Queries with control chars have those chars removed before dispatch."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, ["foo\x00bar", "baz\tbaz"])
|
||||
|
||||
# run_functions_tuples_in_parallel uses a thread pool; call_args_list is in
|
||||
# execution order, not submission order, so compare in sorted order.
|
||||
assert sorted(dispatched) == ["bazbaz", "foobar"]
|
||||
|
||||
def test_all_empty_or_whitespace_raises_tool_call_exception(self) -> None:
|
||||
"""When normalization yields no valid queries, run() raises ToolCallException."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
placement = Placement(turn_index=0, tab_index=0)
|
||||
override_kwargs = WebSearchToolOverrideKwargs(starting_citation_num=1)
|
||||
|
||||
with pytest.raises(ToolCallException) as exc_info:
|
||||
tool.run(
|
||||
placement=placement,
|
||||
override_kwargs=override_kwargs,
|
||||
queries=" ",
|
||||
)
|
||||
|
||||
assert "No valid" in str(exc_info.value)
|
||||
cast(MagicMock, mock_provider.search).assert_not_called()
|
||||
@@ -138,7 +138,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
- MULTI_TENANT=true
|
||||
- LOG_LEVEL=DEBUG
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
# =============================================================================
|
||||
# ONYX NO-VECTOR-DB OVERLAY
|
||||
# ONYX LITE — MINIMAL DEPLOYMENT OVERLAY
|
||||
# =============================================================================
|
||||
# Overlay to run Onyx without a vector database (Vespa), model servers, or
|
||||
# code interpreter. In this mode, connectors and RAG search are disabled, but
|
||||
# the core chat experience (LLM conversations, tools, user file uploads,
|
||||
# Projects, Agent knowledge) still works.
|
||||
# Overlay to run Onyx in a minimal configuration: no vector database (Vespa),
|
||||
# no Redis, no model servers, and no background workers. Only PostgreSQL is
|
||||
# required. In this mode, connectors and RAG search are disabled, but the core
|
||||
# chat experience (LLM conversations, tools, user file uploads, Projects,
|
||||
# Agent knowledge, code interpreter) still works.
|
||||
#
|
||||
# Usage:
|
||||
# docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml up -d
|
||||
# docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml up -d
|
||||
#
|
||||
# With dev ports:
|
||||
# docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml \
|
||||
# docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml \
|
||||
# -f docker-compose.dev.yml up -d --wait
|
||||
#
|
||||
# This overlay:
|
||||
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
|
||||
# so they do not start by default
|
||||
# - Moves the background worker to the "background" profile (the API server
|
||||
# handles all background work via FastAPI BackgroundTasks)
|
||||
# - Makes the depends_on references to removed services optional
|
||||
# - Moves Vespa (index), both model servers, code-interpreter, Redis (cache),
|
||||
# and the background worker to profiles so they do not start by default
|
||||
# - Makes depends_on references to removed services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on the api_server
|
||||
# - Uses PostgreSQL for caching and auth instead of Redis
|
||||
# - Uses PostgreSQL for file storage instead of S3/MinIO
|
||||
#
|
||||
# To selectively bring services back:
|
||||
# --profile vectordb Vespa + indexing model server
|
||||
# --profile inference Inference model server
|
||||
# --profile background Background worker (Celery)
|
||||
# --profile background Background worker (Celery) — also needs redis
|
||||
# --profile redis Redis cache
|
||||
# --profile code-interpreter Code interpreter
|
||||
# =============================================================================
|
||||
|
||||
@@ -36,6 +38,9 @@ services:
|
||||
index:
|
||||
condition: service_started
|
||||
required: false
|
||||
cache:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
@@ -45,9 +50,11 @@ services:
|
||||
environment:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
- CACHE_BACKEND=postgres
|
||||
- AUTH_BACKEND=postgres
|
||||
|
||||
# Move the background worker to a profile so it does not start by default.
|
||||
# The API server handles all background work in NO_VECTOR_DB mode.
|
||||
# The API server handles all background work in lite mode.
|
||||
background:
|
||||
profiles: ["background"]
|
||||
depends_on:
|
||||
@@ -61,6 +68,11 @@ services:
|
||||
condition: service_started
|
||||
required: false
|
||||
|
||||
# Move Redis to a profile so it does not start by default.
|
||||
# The Postgres cache backend replaces Redis in lite mode.
|
||||
cache:
|
||||
profiles: ["redis"]
|
||||
|
||||
# Move Vespa and indexing model server to a profile so they do not start.
|
||||
index:
|
||||
profiles: ["vectordb"]
|
||||
@@ -52,7 +52,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-oidc}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -65,7 +65,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-oidc}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -70,7 +70,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-oidc}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -58,7 +58,6 @@ services:
|
||||
env_file:
|
||||
- .env_eval
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=disabled
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -146,7 +146,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- FILE_STORE_BACKEND=${FILE_STORE_BACKEND:-s3}
|
||||
- POSTGRES_HOST=${POSTGRES_HOST:-relational_db}
|
||||
- VESPA_HOST=${VESPA_HOST:-index}
|
||||
|
||||
31
deployment/helm/charts/onyx/values-lite.yaml
Normal file
31
deployment/helm/charts/onyx/values-lite.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
# =============================================================================
|
||||
# ONYX LITE — MINIMAL DEPLOYMENT VALUES
|
||||
# =============================================================================
|
||||
# Minimal Onyx deployment: no vector database, no Redis, no model servers.
|
||||
# Only PostgreSQL is required. Connectors and RAG search are disabled, but the
|
||||
# core chat experience (LLM conversations, tools, user file uploads, Projects,
|
||||
# Agent knowledge) still works.
|
||||
#
|
||||
# Usage:
|
||||
# helm install onyx ./deployment/helm/charts/onyx \
|
||||
# -f ./deployment/helm/charts/onyx/values-lite.yaml
|
||||
#
|
||||
# Or merged with your own overrides:
|
||||
# helm install onyx ./deployment/helm/charts/onyx \
|
||||
# -f ./deployment/helm/charts/onyx/values-lite.yaml \
|
||||
# -f my-overrides.yaml
|
||||
# =============================================================================
|
||||
|
||||
vectorDB:
|
||||
enabled: false
|
||||
|
||||
vespa:
|
||||
enabled: false
|
||||
|
||||
redis:
|
||||
enabled: false
|
||||
|
||||
configMap:
|
||||
CACHE_BACKEND: "postgres"
|
||||
AUTH_BACKEND: "postgres"
|
||||
FILE_STORE_BACKEND: "postgres"
|
||||
@@ -14,30 +14,32 @@ Built with [Tauri](https://tauri.app) for minimal bundle size (~10MB vs Electron
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
| Shortcut | Action |
|
||||
|----------|--------|
|
||||
| `⌘ N` | New Chat |
|
||||
| `⌘ ⇧ N` | New Window |
|
||||
| `⌘ R` | Reload |
|
||||
| `⌘ [` | Go Back |
|
||||
| `⌘ ]` | Go Forward |
|
||||
| `⌘ ,` | Open Config File |
|
||||
| `⌘ W` | Close Window |
|
||||
| `⌘ Q` | Quit |
|
||||
| Shortcut | Action |
|
||||
| -------- | ---------------- |
|
||||
| `⌘ N` | New Chat |
|
||||
| `⌘ ⇧ N` | New Window |
|
||||
| `⌘ R` | Reload |
|
||||
| `⌘ [` | Go Back |
|
||||
| `⌘ ]` | Go Forward |
|
||||
| `⌘ ,` | Open Config File |
|
||||
| `⌘ W` | Close Window |
|
||||
| `⌘ Q` | Quit |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Rust** (latest stable)
|
||||
|
||||
```bash
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
source $HOME/.cargo/env
|
||||
```
|
||||
|
||||
2. **Node.js** (18+)
|
||||
|
||||
```bash
|
||||
# Using homebrew
|
||||
brew install node
|
||||
|
||||
|
||||
# Or using nvm
|
||||
nvm install 18
|
||||
```
|
||||
@@ -55,16 +57,21 @@ npm install
|
||||
|
||||
# Run in development mode
|
||||
npm run dev
|
||||
|
||||
# Run in debug mode
|
||||
npm run debug
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Build for current architecture
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
### Build Universal Binary (Intel + Apple Silicon)
|
||||
|
||||
```bash
|
||||
# First, add the targets
|
||||
rustup target add x86_64-apple-darwin
|
||||
@@ -103,6 +110,7 @@ Before building, add your app icons to `src-tauri/icons/`:
|
||||
- `icon.ico` (Windows, optional)
|
||||
|
||||
You can generate these from a 1024x1024 source image using:
|
||||
|
||||
```bash
|
||||
# Using tauri's icon generator
|
||||
npm run tauri icon path/to/your-icon.png
|
||||
@@ -115,6 +123,7 @@ npm run tauri icon path/to/your-icon.png
|
||||
The app defaults to `https://cloud.onyx.app` but supports any Onyx instance.
|
||||
|
||||
**Config file location:**
|
||||
|
||||
- macOS: `~/Library/Application Support/app.onyx.desktop/config.json`
|
||||
- Linux: `~/.config/app.onyx.desktop/config.json`
|
||||
- Windows: `%APPDATA%/app.onyx.desktop/config.json`
|
||||
@@ -135,6 +144,7 @@ The app defaults to `https://cloud.onyx.app` but supports any Onyx instance.
|
||||
4. Restart the app
|
||||
|
||||
**Quick edit via terminal:**
|
||||
|
||||
```bash
|
||||
# macOS
|
||||
open -t ~/Library/Application\ Support/app.onyx.desktop/config.json
|
||||
@@ -146,6 +156,7 @@ code ~/Library/Application\ Support/app.onyx.desktop/config.json
|
||||
### Change the default URL in build
|
||||
|
||||
Edit `src-tauri/tauri.conf.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"app": {
|
||||
@@ -165,6 +176,7 @@ Edit `src-tauri/src/main.rs` in the `setup_shortcuts` function.
|
||||
### Window appearance
|
||||
|
||||
Modify the window configuration in `src-tauri/tauri.conf.json`:
|
||||
|
||||
- `titleBarStyle`: `"Overlay"` (macOS native) or `"Visible"`
|
||||
- `decorations`: Window chrome
|
||||
- `transparent`: For custom backgrounds
|
||||
@@ -172,16 +184,20 @@ Modify the window configuration in `src-tauri/tauri.conf.json`:
|
||||
## Troubleshooting
|
||||
|
||||
### "Unable to resolve host"
|
||||
|
||||
Make sure you have an internet connection. The app loads content from `cloud.onyx.app`.
|
||||
|
||||
### Build fails on M1/M2 Mac
|
||||
|
||||
```bash
|
||||
# Ensure you have the right target
|
||||
rustup target add aarch64-apple-darwin
|
||||
```
|
||||
|
||||
### Code signing for distribution
|
||||
|
||||
For distributing outside the App Store, you'll need to:
|
||||
|
||||
1. Get an Apple Developer certificate
|
||||
2. Sign the app: `codesign --deep --force --sign "Developer ID" target/release/bundle/macos/Onyx.app`
|
||||
3. Notarize with Apple
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"description": "Lightweight desktop app for Onyx Cloud",
|
||||
"scripts": {
|
||||
"dev": "tauri dev",
|
||||
"debug": "tauri dev -- -- --debug",
|
||||
"build": "tauri build",
|
||||
"build:dmg": "tauri build --target universal-apple-darwin",
|
||||
"build:linux": "tauri build --bundles deb,rpm"
|
||||
|
||||
@@ -23,3 +23,4 @@ url = "2.5"
|
||||
[features]
|
||||
default = ["custom-protocol"]
|
||||
custom-protocol = ["tauri/custom-protocol"]
|
||||
devtools = ["tauri/devtools"]
|
||||
|
||||
@@ -6,7 +6,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::{Mutex, RwLock};
|
||||
use std::io::Write as IoWrite;
|
||||
use std::time::SystemTime;
|
||||
#[cfg(target_os = "macos")]
|
||||
use std::time::Duration;
|
||||
use tauri::image::Image;
|
||||
@@ -230,6 +232,63 @@ const MENU_KEY_HANDLER_SCRIPT: &str = r#"
|
||||
})();
|
||||
"#;
|
||||
|
||||
const CONSOLE_CAPTURE_SCRIPT: &str = r#"
|
||||
(() => {
|
||||
if (window.__ONYX_CONSOLE_CAPTURE__) return;
|
||||
window.__ONYX_CONSOLE_CAPTURE__ = true;
|
||||
|
||||
const levels = ['log', 'warn', 'error', 'info', 'debug'];
|
||||
const originals = {};
|
||||
|
||||
levels.forEach(level => {
|
||||
originals[level] = console[level];
|
||||
console[level] = function(...args) {
|
||||
originals[level].apply(console, args);
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke === 'function') {
|
||||
const message = args.map(a => {
|
||||
try { return typeof a === 'string' ? a : JSON.stringify(a); }
|
||||
catch { return String(a); }
|
||||
}).join(' ');
|
||||
invoke('log_from_frontend', { level, message });
|
||||
}
|
||||
} catch {}
|
||||
};
|
||||
});
|
||||
|
||||
window.addEventListener('error', (event) => {
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke === 'function') {
|
||||
invoke('log_from_frontend', {
|
||||
level: 'error',
|
||||
message: `[uncaught] ${event.message} at ${event.filename}:${event.lineno}:${event.colno}`
|
||||
});
|
||||
}
|
||||
} catch {}
|
||||
});
|
||||
|
||||
window.addEventListener('unhandledrejection', (event) => {
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke === 'function') {
|
||||
invoke('log_from_frontend', {
|
||||
level: 'error',
|
||||
message: `[unhandled rejection] ${event.reason}`
|
||||
});
|
||||
}
|
||||
} catch {}
|
||||
});
|
||||
})();
|
||||
"#;
|
||||
|
||||
const MENU_TOGGLE_DEVTOOLS_ID: &str = "toggle_devtools";
|
||||
const MENU_OPEN_DEBUG_LOG_ID: &str = "open_debug_log";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
pub server_url: String,
|
||||
@@ -311,12 +370,87 @@ fn save_config(config: &AppConfig) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Debug Mode
|
||||
// ============================================================================
|
||||
|
||||
fn is_debug_mode() -> bool {
|
||||
std::env::args().any(|arg| arg == "--debug") || std::env::var("ONYX_DEBUG").is_ok()
|
||||
}
|
||||
|
||||
fn get_debug_log_path() -> Option<PathBuf> {
|
||||
get_config_dir().map(|dir| dir.join("frontend_debug.log"))
|
||||
}
|
||||
|
||||
fn init_debug_log_file() -> Option<fs::File> {
|
||||
let log_path = get_debug_log_path()?;
|
||||
if let Some(parent) = log_path.parent() {
|
||||
let _ = fs::create_dir_all(parent);
|
||||
}
|
||||
fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&log_path)
|
||||
.ok()
|
||||
}
|
||||
|
||||
fn format_utc_timestamp() -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
let total_secs = now.as_secs();
|
||||
let millis = now.subsec_millis();
|
||||
|
||||
let days = total_secs / 86400;
|
||||
let secs_of_day = total_secs % 86400;
|
||||
let hours = secs_of_day / 3600;
|
||||
let mins = (secs_of_day % 3600) / 60;
|
||||
let secs = secs_of_day % 60;
|
||||
|
||||
// Days since Unix epoch -> Y/M/D via civil calendar arithmetic
|
||||
let z = days as i64 + 719468;
|
||||
let era = z / 146097;
|
||||
let doe = z - era * 146097;
|
||||
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
|
||||
let y = yoe + era * 400;
|
||||
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
|
||||
let mp = (5 * doy + 2) / 153;
|
||||
let d = doy - (153 * mp + 2) / 5 + 1;
|
||||
let m = if mp < 10 { mp + 3 } else { mp - 9 };
|
||||
let y = if m <= 2 { y + 1 } else { y };
|
||||
|
||||
format!(
|
||||
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
|
||||
y, m, d, hours, mins, secs, millis
|
||||
)
|
||||
}
|
||||
|
||||
fn inject_console_capture(webview: &Webview) {
|
||||
let _ = webview.eval(CONSOLE_CAPTURE_SCRIPT);
|
||||
}
|
||||
|
||||
fn maybe_open_devtools(app: &AppHandle, window: &tauri::WebviewWindow) {
|
||||
#[cfg(any(debug_assertions, feature = "devtools"))]
|
||||
{
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.debug_mode {
|
||||
window.open_devtools();
|
||||
}
|
||||
}
|
||||
#[cfg(not(any(debug_assertions, feature = "devtools")))]
|
||||
{
|
||||
let _ = (app, window);
|
||||
}
|
||||
}
|
||||
|
||||
// Global config state
|
||||
struct ConfigState {
|
||||
config: RwLock<AppConfig>,
|
||||
config_initialized: RwLock<bool>,
|
||||
app_base_url: RwLock<Option<Url>>,
|
||||
menu_temporarily_visible: RwLock<bool>,
|
||||
debug_mode: bool,
|
||||
debug_log_file: Mutex<Option<fs::File>>,
|
||||
}
|
||||
|
||||
fn focus_main_window(app: &AppHandle) {
|
||||
@@ -372,6 +506,7 @@ fn trigger_new_window(app: &AppHandle) {
|
||||
}
|
||||
|
||||
apply_settings_to_window(&handle, &window);
|
||||
maybe_open_devtools(&handle, &window);
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
});
|
||||
@@ -467,10 +602,65 @@ fn inject_chat_link_intercept(webview: &Webview) {
|
||||
let _ = webview.eval(CHAT_LINK_INTERCEPT_SCRIPT);
|
||||
}
|
||||
|
||||
fn handle_toggle_devtools(app: &AppHandle) {
|
||||
#[cfg(any(debug_assertions, feature = "devtools"))]
|
||||
{
|
||||
let windows: Vec<_> = app.webview_windows().into_values().collect();
|
||||
let any_open = windows.iter().any(|w| w.is_devtools_open());
|
||||
for window in &windows {
|
||||
if any_open {
|
||||
window.close_devtools();
|
||||
} else {
|
||||
window.open_devtools();
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(any(debug_assertions, feature = "devtools")))]
|
||||
{
|
||||
let _ = app;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_open_debug_log() {
|
||||
let log_path = match get_debug_log_path() {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
if !log_path.exists() {
|
||||
eprintln!("[ONYX DEBUG] Log file does not exist yet: {:?}", log_path);
|
||||
return;
|
||||
}
|
||||
|
||||
let url_path = log_path.to_string_lossy().replace('\\', "/");
|
||||
let _ = open_in_default_browser(&format!(
|
||||
"file:///{}",
|
||||
url_path.trim_start_matches('/')
|
||||
));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tauri Commands
|
||||
// ============================================================================
|
||||
|
||||
#[tauri::command]
|
||||
fn log_from_frontend(level: String, message: String, state: tauri::State<ConfigState>) {
|
||||
if !state.debug_mode {
|
||||
return;
|
||||
}
|
||||
let timestamp = format_utc_timestamp();
|
||||
let log_line = format!("[{}] [{}] {}", timestamp, level.to_uppercase(), message);
|
||||
|
||||
eprintln!("{}", log_line);
|
||||
|
||||
if let Ok(mut guard) = state.debug_log_file.lock() {
|
||||
if let Some(ref mut file) = *guard {
|
||||
let _ = writeln!(file, "{}", log_line);
|
||||
let _ = file.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current server URL
|
||||
#[tauri::command]
|
||||
fn get_server_url(state: tauri::State<ConfigState>) -> String {
|
||||
@@ -657,6 +847,7 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
|
||||
}
|
||||
|
||||
apply_settings_to_window(&app, &window);
|
||||
maybe_open_devtools(&app, &window);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -936,6 +1127,30 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
menu.append(&help_menu)?;
|
||||
}
|
||||
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.debug_mode {
|
||||
let toggle_devtools_item = MenuItem::with_id(
|
||||
app,
|
||||
MENU_TOGGLE_DEVTOOLS_ID,
|
||||
"Toggle DevTools",
|
||||
true,
|
||||
Some("F12"),
|
||||
)?;
|
||||
let open_log_item = MenuItem::with_id(
|
||||
app,
|
||||
MENU_OPEN_DEBUG_LOG_ID,
|
||||
"Open Debug Log",
|
||||
true,
|
||||
None::<&str>,
|
||||
)?;
|
||||
|
||||
let debug_menu = SubmenuBuilder::new(app, "Debug")
|
||||
.item(&toggle_devtools_item)
|
||||
.item(&open_log_item)
|
||||
.build()?;
|
||||
menu.append(&debug_menu)?;
|
||||
}
|
||||
|
||||
app.set_menu(menu)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1027,8 +1242,20 @@ fn setup_tray_icon(app: &AppHandle) -> tauri::Result<()> {
|
||||
// ============================================================================
|
||||
|
||||
fn main() {
|
||||
// Load config at startup
|
||||
let (config, config_initialized) = load_config();
|
||||
let debug_mode = is_debug_mode();
|
||||
|
||||
let debug_log_file = if debug_mode {
|
||||
eprintln!("[ONYX DEBUG] Debug mode enabled");
|
||||
if let Some(path) = get_debug_log_path() {
|
||||
eprintln!("[ONYX DEBUG] Frontend logs: {}", path.display());
|
||||
}
|
||||
eprintln!("[ONYX DEBUG] DevTools will open automatically");
|
||||
eprintln!("[ONYX DEBUG] Capturing console.log/warn/error/info/debug from webview");
|
||||
init_debug_log_file()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
@@ -1059,6 +1286,8 @@ fn main() {
|
||||
config_initialized: RwLock::new(config_initialized),
|
||||
app_base_url: RwLock::new(None),
|
||||
menu_temporarily_visible: RwLock::new(false),
|
||||
debug_mode,
|
||||
debug_log_file: Mutex::new(debug_log_file),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
get_server_url,
|
||||
@@ -1077,7 +1306,8 @@ fn main() {
|
||||
start_drag_window,
|
||||
toggle_menu_bar,
|
||||
show_menu_bar_temporarily,
|
||||
hide_menu_bar_temporary
|
||||
hide_menu_bar_temporary,
|
||||
log_from_frontend
|
||||
])
|
||||
.on_menu_event(|app, event| match event.id().as_ref() {
|
||||
"open_docs" => open_docs(),
|
||||
@@ -1086,6 +1316,8 @@ fn main() {
|
||||
"open_settings" => open_settings(app),
|
||||
"show_menu_bar" => handle_menu_bar_toggle(app),
|
||||
"hide_window_decorations" => handle_decorations_toggle(app),
|
||||
MENU_TOGGLE_DEVTOOLS_ID => handle_toggle_devtools(app),
|
||||
MENU_OPEN_DEBUG_LOG_ID => handle_open_debug_log(),
|
||||
_ => {}
|
||||
})
|
||||
.setup(move |app| {
|
||||
@@ -1119,6 +1351,7 @@ fn main() {
|
||||
inject_titlebar(window.clone());
|
||||
|
||||
apply_settings_to_window(&app_handle, &window);
|
||||
maybe_open_devtools(&app_handle, &window);
|
||||
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
@@ -1128,6 +1361,14 @@ fn main() {
|
||||
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
|
||||
inject_chat_link_intercept(webview);
|
||||
|
||||
{
|
||||
let app = webview.app_handle();
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.debug_mode {
|
||||
inject_console_capture(webview);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT);
|
||||
|
||||
93
greptile.json
Normal file
93
greptile.json
Normal file
@@ -0,0 +1,93 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 2,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "greptile.json\n",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"customContext": {
|
||||
"other": [
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of TODO(name): ... or TODO(1234): ..."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/STANDARDS.md file."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Remove temporary debugging code before merging to production, especially tenant-specific debugging logs."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
6
uv.lock
generated
6
uv.lock
generated
@@ -4106,7 +4106,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "nltk"
|
||||
version = "3.9.1"
|
||||
version = "3.9.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
@@ -4114,9 +4114,9 @@ dependencies = [
|
||||
{ name = "regex" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442, upload-time = "2024-08-18T19:48:21.909Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -143,6 +143,7 @@ module.exports = {
|
||||
"**/src/app/**/utils/*.test.ts",
|
||||
"**/src/app/**/hooks/*.test.ts", // Pure packet processor tests
|
||||
"**/src/refresh-components/**/*.test.ts",
|
||||
"**/src/sections/**/*.test.ts",
|
||||
// Add more patterns here as you add more unit tests
|
||||
],
|
||||
},
|
||||
@@ -156,6 +157,8 @@ module.exports = {
|
||||
"**/src/components/**/*.test.tsx",
|
||||
"**/src/lib/**/*.test.tsx",
|
||||
"**/src/refresh-components/**/*.test.tsx",
|
||||
"**/src/hooks/**/*.test.tsx",
|
||||
"**/src/sections/**/*.test.tsx",
|
||||
// Add more patterns here as you add more integration tests
|
||||
],
|
||||
},
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
"types": "./src/icons/index.ts",
|
||||
"default": "./src/icons/index.ts"
|
||||
},
|
||||
"./illustrations": {
|
||||
"types": "./src/illustrations/index.ts",
|
||||
"default": "./src/illustrations/index.ts"
|
||||
},
|
||||
"./types": {
|
||||
"types": "./src/types.ts",
|
||||
"default": "./src/types.ts"
|
||||
|
||||
99
web/lib/opal/scripts/README.md
Normal file
99
web/lib/opal/scripts/README.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# SVG-to-TSX Conversion Scripts
|
||||
|
||||
## Overview
|
||||
|
||||
Integrating `@svgr/webpack` into the TypeScript compiler was not working via the recommended route (Next.js webpack configuration).
|
||||
The automatic SVG-to-React component conversion was causing compilation issues and import resolution problems.
|
||||
Therefore, we manually convert each SVG into a TSX file using SVGR CLI with a custom template.
|
||||
|
||||
All scripts in this directory should be run from the **opal package root** (`web/lib/opal/`).
|
||||
|
||||
## Directory Layout
|
||||
|
||||
```
|
||||
web/lib/opal/
|
||||
├── scripts/ # SVG conversion tooling (this directory)
|
||||
│ ├── convert-svg.sh # Converts SVGs into React components
|
||||
│ └── icon-template.js # Shared SVGR template (used for both icons and illustrations)
|
||||
├── src/
|
||||
│ ├── icons/ # Small, single-colour icons (stroke = currentColor)
|
||||
│ └── illustrations/ # Larger, multi-colour illustrations (colours preserved)
|
||||
└── package.json
|
||||
```
|
||||
|
||||
## Icons vs Illustrations
|
||||
|
||||
| | Icons | Illustrations |
|
||||
|---|---|---|
|
||||
| **Import path** | `@opal/icons` | `@opal/illustrations` |
|
||||
| **Location** | `src/icons/` | `src/illustrations/` |
|
||||
| **Colour** | Overridable via `currentColor` | Fixed — original SVG colours preserved |
|
||||
| **Script flag** | (none) | `--illustration` |
|
||||
|
||||
## Files in This Directory
|
||||
|
||||
### `icon-template.js`
|
||||
|
||||
A custom SVGR template that generates components with the following features:
|
||||
- Imports `IconProps` from `@opal/types` for consistent typing
|
||||
- Supports the `size` prop for controlling icon dimensions
|
||||
- Includes `width` and `height` attributes bound to the `size` prop
|
||||
- Maintains all standard SVG props (className, color, title, etc.)
|
||||
|
||||
### `convert-svg.sh`
|
||||
|
||||
Converts an SVG into a React component. Behaviour depends on the mode:
|
||||
|
||||
**Icon mode** (default):
|
||||
- Strips `stroke`, `stroke-opacity`, `width`, and `height` attributes
|
||||
- Adds `width={size}`, `height={size}`, and `stroke="currentColor"`
|
||||
- Result is colour-overridable via CSS `color` property
|
||||
|
||||
**Illustration mode** (`--illustration`):
|
||||
- Strips only `width` and `height` attributes (all colours preserved)
|
||||
- Adds `width={size}` and `height={size}`
|
||||
- Does **not** add `stroke="currentColor"` — illustrations keep their original colours
|
||||
|
||||
Both modes automatically delete the source SVG file after successful conversion.
|
||||
|
||||
## Adding New SVGs
|
||||
|
||||
### Icons
|
||||
|
||||
```sh
|
||||
# From web/lib/opal/
|
||||
./scripts/convert-svg.sh src/icons/my-icon.svg
|
||||
```
|
||||
|
||||
Then add the export to `src/icons/index.ts`:
|
||||
```ts
|
||||
export { default as SvgMyIcon } from "@opal/icons/my-icon";
|
||||
```
|
||||
|
||||
### Illustrations
|
||||
|
||||
```sh
|
||||
# From web/lib/opal/
|
||||
./scripts/convert-svg.sh --illustration src/illustrations/my-illustration.svg
|
||||
```
|
||||
|
||||
Then add the export to `src/illustrations/index.ts`:
|
||||
```ts
|
||||
export { default as SvgMyIllustration } from "@opal/illustrations/my-illustration";
|
||||
```
|
||||
|
||||
## Manual Conversion
|
||||
|
||||
If you prefer to run the SVGR command directly:
|
||||
|
||||
**For icons** (strips colours):
|
||||
```sh
|
||||
bunx @svgr/cli <file>.svg --typescript --svgo-config '{"plugins":[{"name":"removeAttrs","params":{"attrs":["stroke","stroke-opacity","width","height"]}}]}' --template scripts/icon-template.js > <file>.tsx
|
||||
```
|
||||
|
||||
**For illustrations** (preserves colours):
|
||||
```sh
|
||||
bunx @svgr/cli <file>.svg --typescript --svgo-config '{"plugins":[{"name":"removeAttrs","params":{"attrs":["width","height"]}}]}' --template scripts/icon-template.js > <file>.tsx
|
||||
```
|
||||
|
||||
After running either manual command, remember to delete the original SVG file.
|
||||
123
web/lib/opal/scripts/convert-svg.sh
Executable file
123
web/lib/opal/scripts/convert-svg.sh
Executable file
@@ -0,0 +1,123 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Convert an SVG file to a TypeScript React component.
|
||||
#
|
||||
# By default, converts to a colour-overridable icon (stroke colours stripped, replaced with currentColor).
|
||||
# With --illustration, converts to a fixed-colour illustration (all original colours preserved).
|
||||
#
|
||||
# Usage (from the opal package root — web/lib/opal/):
|
||||
# ./scripts/convert-svg.sh src/icons/<filename.svg>
|
||||
# ./scripts/convert-svg.sh --illustration src/illustrations/<filename.svg>
|
||||
|
||||
ILLUSTRATION=false
|
||||
|
||||
# Parse flags
|
||||
while [[ "$1" == --* ]]; do
|
||||
case "$1" in
|
||||
--illustration)
|
||||
ILLUSTRATION=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown flag: $1" >&2
|
||||
echo "Usage: ./scripts/convert-svg.sh [--illustration] <filename.svg>" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: ./scripts/convert-svg.sh [--illustration] <filename.svg>" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SVG_FILE="$1"
|
||||
|
||||
# Check if file exists
|
||||
if [ ! -f "$SVG_FILE" ]; then
|
||||
echo "Error: File '$SVG_FILE' not found" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if it's an SVG file
|
||||
if [[ ! "$SVG_FILE" == *.svg ]]; then
|
||||
echo "Error: File must have .svg extension" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the base name without extension
|
||||
BASE_NAME="${SVG_FILE%.svg}"
|
||||
|
||||
# Build the SVGO config based on mode
|
||||
if [ "$ILLUSTRATION" = true ]; then
|
||||
# Illustrations: only strip width and height (preserve all colours)
|
||||
SVGO_CONFIG='{"plugins":[{"name":"removeAttrs","params":{"attrs":["width","height"]}}]}'
|
||||
else
|
||||
# Icons: strip stroke, stroke-opacity, width, and height
|
||||
SVGO_CONFIG='{"plugins":[{"name":"removeAttrs","params":{"attrs":["stroke","stroke-opacity","width","height"]}}]}'
|
||||
fi
|
||||
|
||||
# Resolve the template path relative to this script (not the caller's CWD)
|
||||
SCRIPT_DIR="$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
# Run the conversion into a temp file so a failed run doesn't destroy an existing .tsx
|
||||
TMPFILE="${BASE_NAME}.tsx.tmp"
|
||||
bunx @svgr/cli "$SVG_FILE" --typescript --svgo-config "$SVGO_CONFIG" --template "${SCRIPT_DIR}/icon-template.js" > "$TMPFILE"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
# Verify the temp file has content before replacing the destination
|
||||
if [ ! -s "$TMPFILE" ]; then
|
||||
rm -f "$TMPFILE"
|
||||
echo "Error: Output file was not created or is empty" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mv "$TMPFILE" "${BASE_NAME}.tsx" || { echo "Error: Failed to move temp file" >&2; exit 1; }
|
||||
|
||||
# Post-process the file to add width and height attributes bound to the size prop
|
||||
# Using perl for cross-platform compatibility (works on macOS, Linux, Windows with WSL)
|
||||
# Note: perl -i returns 0 even on some failures, so we validate the output
|
||||
|
||||
perl -i -pe 's/<svg/<svg width={size} height={size}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to add width/height attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Icons additionally get stroke="currentColor"
|
||||
if [ "$ILLUSTRATION" = false ]; then
|
||||
perl -i -pe 's/\{\.\.\.props\}/stroke="currentColor" {...props}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to add stroke attribute" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Verify the file still exists and has content after post-processing
|
||||
if [ ! -s "${BASE_NAME}.tsx" ]; then
|
||||
echo "Error: Output file corrupted during post-processing" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify required attributes are present in the output
|
||||
if ! grep -q 'width={size}' "${BASE_NAME}.tsx" || ! grep -q 'height={size}' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Post-processing did not add required attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# For icons, also verify stroke="currentColor" was added
|
||||
if [ "$ILLUSTRATION" = false ]; then
|
||||
if ! grep -q 'stroke="currentColor"' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Post-processing did not add stroke=\"currentColor\"" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Created ${BASE_NAME}.tsx"
|
||||
rm "$SVG_FILE"
|
||||
echo "Deleted $SVG_FILE"
|
||||
else
|
||||
rm -f "$TMPFILE"
|
||||
echo "Error: Conversion failed" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -11,6 +11,8 @@ export {
|
||||
Interactive,
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveBaseSidebarVariantProps,
|
||||
type InteractiveBaseSidebarProminenceTypes,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
} from "@opal/core/interactive/components";
|
||||
|
||||
@@ -104,6 +104,44 @@ The foundational layer for all clickable surfaces in the design system. Defines
|
||||
| **Active** | `action-link-05` | `action-link-05` |
|
||||
| **Disabled** | `action-link-03` | `action-link-03` |
|
||||
|
||||
### Sidebar (unselected)
|
||||
|
||||
> No CSS `:active` state — only hover/transient and selected.
|
||||
|
||||
**Background**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **Rest** | `transparent` |
|
||||
| **Hover / Transient** | `background-tint-03` |
|
||||
| **Disabled** | `transparent` |
|
||||
|
||||
**Foreground**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **Rest** | `text-03` |
|
||||
| **Hover / Transient** | `text-04` |
|
||||
| **Disabled** | `text-01` |
|
||||
|
||||
### Sidebar (selected)
|
||||
|
||||
> Completely static — hover and transient have no effect.
|
||||
|
||||
**Background**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **All states** | `background-tint-00` |
|
||||
| **Disabled** | `transparent` |
|
||||
|
||||
**Foreground**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **All states** | `text-03` (icon: `text-02`) |
|
||||
| **Disabled** | `text-01` |
|
||||
|
||||
## Sub-components
|
||||
|
||||
| Sub-component | Role |
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import Link from "next/link";
|
||||
import type { Route } from "next";
|
||||
import "@opal/core/interactive/styles.css";
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
@@ -26,18 +28,28 @@ type InteractiveBaseSelectVariantProps = {
|
||||
selected?: boolean;
|
||||
};
|
||||
|
||||
type InteractiveBaseSidebarProminenceTypes = "light";
|
||||
type InteractiveBaseSidebarVariantProps = {
|
||||
variant: "sidebar";
|
||||
prominence?: InteractiveBaseSidebarProminenceTypes;
|
||||
selected?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Discriminated union tying `variant` to `prominence`.
|
||||
*
|
||||
* - `"none"` accepts no prominence (`prominence` must not be provided)
|
||||
* - `"select"` accepts an optional prominence (defaults to `"light"`) and
|
||||
* an optional `selected` boolean that switches foreground to action-link colours
|
||||
* - `"sidebar"` accepts an optional prominence (defaults to `"light"`) and
|
||||
* an optional `selected` boolean for the focused/active-item state
|
||||
* - `"default"`, `"action"`, and `"danger"` accept an optional prominence
|
||||
* (defaults to `"primary"`)
|
||||
*/
|
||||
type InteractiveBaseVariantProps =
|
||||
| { variant?: "none"; prominence?: never; selected?: never }
|
||||
| InteractiveBaseSelectVariantProps
|
||||
| InteractiveBaseSidebarVariantProps
|
||||
| {
|
||||
variant?: InteractiveBaseVariantTypes;
|
||||
prominence?: InteractiveBaseProminenceTypes;
|
||||
@@ -218,7 +230,8 @@ function InteractiveBase({
|
||||
...props
|
||||
}: InteractiveBaseProps) {
|
||||
const effectiveProminence =
|
||||
prominence ?? (variant === "select" ? "light" : "primary");
|
||||
prominence ??
|
||||
(variant === "select" || variant === "sidebar" ? "light" : "primary");
|
||||
const classes = cn(
|
||||
"interactive",
|
||||
!props.onClick && !href && "!cursor-default !select-auto",
|
||||
@@ -417,9 +430,9 @@ function InteractiveContainer({
|
||||
// so all styling (backgrounds, rounding, overflow) lives on one element.
|
||||
if (href) {
|
||||
return (
|
||||
<a
|
||||
<Link
|
||||
ref={ref as React.Ref<HTMLAnchorElement>}
|
||||
href={href}
|
||||
href={href as Route}
|
||||
target={target}
|
||||
rel={rel}
|
||||
{...(sharedProps as React.HTMLAttributes<HTMLAnchorElement>)}
|
||||
@@ -482,6 +495,8 @@ export {
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveBaseSelectVariantProps,
|
||||
type InteractiveBaseSidebarVariantProps,
|
||||
type InteractiveBaseSidebarProminenceTypes,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
};
|
||||
|
||||
@@ -419,3 +419,23 @@
|
||||
) {
|
||||
@apply bg-background-tint-00;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Sidebar + Light
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"] {
|
||||
@apply bg-transparent;
|
||||
}
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"]:hover:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"][data-transient="true"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-tint-03;
|
||||
}
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"][data-selected="true"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-tint-00;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user