Compare commits

..

1 Commits

Author SHA1 Message Date
Dane Urban
c9155a8767 Hide deprecated columns 2026-03-02 18:16:18 -08:00
517 changed files with 4680 additions and 14857 deletions

View File

@@ -15,7 +15,6 @@ permissions:
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
secrets: inherit
permissions:
contents: read
id-token: write

View File

@@ -335,6 +335,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
@@ -470,13 +471,13 @@ jobs:
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
onyx-lite-tests:
no-vectordb-tests:
needs: [build-backend-image, build-integration-image]
runs-on:
[
runs-on,
runner=4cpu-linux-arm64,
"run-id=${{ github.run_id }}-onyx-lite-tests",
"run-id=${{ github.run_id }}-no-vectordb-tests",
"extras=ecr-cache",
]
timeout-minutes: 45
@@ -494,12 +495,13 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create .env file for Onyx Lite Docker Compose
- name: Create .env file for no-vectordb Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
@@ -507,23 +509,28 @@ jobs:
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
DISABLE_VECTOR_DB=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
EOF
# Start only the services needed for Onyx Lite (Postgres + API server)
- name: Start Docker containers (onyx-lite)
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
- name: Start Docker containers (no-vectordb)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
relational_db \
cache \
minio \
api_server \
background \
-d
id: start_docker_onyx_lite
id: start_docker_no_vectordb
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script (onyx-lite)..."
echo "Starting wait-for-service script (no-vectordb)..."
start_time=$(date +%s)
timeout=300
while true; do
@@ -545,14 +552,14 @@ jobs:
sleep 5
done
- name: Run Onyx Lite Integration Tests
- name: Run No-VectorDB Integration Tests
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running onyx-lite integration tests..."
echo "Running no-vectordb integration tests..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -563,38 +570,39 @@ jobs:
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e TEST_WEB_HOSTNAME=test-runner \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/tests/no_vectordb
- name: Dump API server logs (onyx-lite)
- name: Dump API server logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
- name: Dump all-container logs (onyx-lite)
- name: Dump all-container logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
- name: Upload logs (onyx-lite)
- name: Upload logs (no-vectordb)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-onyx-lite
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
name: docker-all-logs-no-vectordb
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
- name: Stop Docker containers (onyx-lite)
- name: Stop Docker containers (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
multitenant-tests:
needs:
@@ -736,7 +744,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests, onyx-lite-tests, multitenant-tests]
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
if: ${{ always() }}
steps:
- name: Check job status

View File

@@ -268,11 +268,10 @@ jobs:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
@@ -280,7 +279,6 @@ jobs:
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
@@ -592,108 +590,6 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-tests-lite:
needs: [build-web-image, build-backend-image]
name: Playwright Tests (lite)
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-playwright-tests-lite"
- "extras=ecr-cache"
timeout-minutes: 30
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-npm-
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
MOCK_LLM_RESPONSE=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
# Needed for pulling external images; otherwise we hit Docker Hub's "Unauthenticated users" pull limit.
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers (lite)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
id: start_docker
- name: Run Playwright tests (lite)
working-directory: ./web
run: npx playwright test --project lite
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-test-results-lite-${{ github.run_id }}
path: ./web/output/playwright/
retention-days: 30
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-lite-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
@@ -790,7 +686,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests, playwright-tests-lite]
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status

58
.vscode/launch.json vendored
View File

@@ -40,7 +40,19 @@
}
},
{
"name": "Celery",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
@@ -241,6 +253,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -485,6 +526,21 @@
"group": "3"
}
},
{
"name": "Clear and Restart OpenSearch Container",
// "type" is a required field; "node" is used only as a generic debugger type and has no bearing on the bash script.
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",

View File

@@ -86,6 +86,37 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
@@ -586,45 +617,6 @@ Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Error Handling
**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
handling consistent across the entire backend.
```python
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ✅ Good
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
# ✅ Good — no extra message needed
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
# ✅ Good — upstream service with dynamic status code
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
# ❌ Bad — using HTTPException directly
raise HTTPException(status_code=404, detail="Session not found")
# ❌ Bad — starlette constant
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
```
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
category is needed, add it there first — do not invent ad-hoc codes.
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
status code is dynamic (comes from the upstream response), use `status_code_override`:
```python
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
```
## Best Practices
In addition to the other content in this file, best practices for contributing

View File

@@ -1,37 +0,0 @@
"""add cache_store table
Revision ID: 2664261bfaab
Revises: 4a1e4b1c89d2
Create Date: 2026-02-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the parent
# migration it revises (both mirror the module docstring above).
revision = "2664261bfaab"
down_revision = "4a1e4b1c89d2"
# This migration declares no branch labels and no extra dependencies.
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
    """Create the ``cache_store`` table and a partial index on ``expires_at``.

    The table is keyed by ``key`` (primary key) and stores an optional binary
    ``value`` together with an optional timezone-aware ``expires_at``.
    """
    op.create_table(
        "cache_store",
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("value", sa.LargeBinary(), nullable=True),
        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("key"),
    )

    # Partial index: only rows that actually carry an expiry are indexed.
    op.create_index(
        "ix_cache_store_expires",
        "cache_store",
        ["expires_at"],
        postgresql_where=sa.text("expires_at IS NOT NULL"),
    )
def downgrade() -> None:
    """Drop the ``ix_cache_store_expires`` index, then the ``cache_store`` table."""
    op.drop_index("ix_cache_store_expires", table_name="cache_store")
    op.drop_table("cache_store")

View File

@@ -1,34 +0,0 @@
"""make scim_user_mapping.external_id nullable
Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02
"""
from alembic import op
# Revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the parent
# migration it revises (both mirror the module docstring above).
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
# This migration declares no branch labels and no extra dependencies.
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Make ``scim_user_mapping.external_id`` nullable."""
    op.alter_column(
        "scim_user_mapping",
        "external_id",
        nullable=True,
    )
def downgrade() -> None:
    """Re-apply NOT NULL to ``scim_user_mapping.external_id``.

    Destructive: every mapping row whose ``external_id`` is NULL is deleted
    first, so the column can be made non-nullable again.
    """
    # Delete any rows where external_id is NULL before re-applying NOT NULL
    op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
    op.alter_column(
        "scim_user_mapping",
        "external_id",
        nullable=False,
    )

View File

@@ -0,0 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
# Register the enterprise-edition (`ee.onyx...`) task packages on the
# `background` Celery app imported above. The list is first passed through
# `app_base.filter_task_modules`.
# NOTE(review): presumably that helper drops modules that do not apply to the
# current deployment — confirm in `app_base`.
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -11,10 +11,11 @@ from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -141,7 +142,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
"""
Get license metadata from cache.
Get license metadata from Redis cache.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
@@ -149,34 +150,38 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
Returns:
LicenseMetadata if cached, None otherwise
"""
cache = get_cache_backend(tenant_id=tenant_id)
cached = cache.get(LICENSE_METADATA_KEY)
if not cached:
return None
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_replica_client(tenant_id=tenant)
try:
cached_str = (
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
cached = redis_client.get(LICENSE_METADATA_KEY)
if cached:
try:
cached_str: str
if isinstance(cached, bytes):
cached_str = cached.decode("utf-8")
else:
cached_str = str(cached)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
return None
def invalidate_license_cache(tenant_id: str | None = None) -> None:
"""
Invalidate the license metadata cache (not the license itself).
Deletes the cached LicenseMetadata. The actual license in the database
is not affected. Delete is idempotent if the key doesn't exist, this
is a no-op.
This deletes the cached LicenseMetadata from Redis. The actual license
in the database is not affected. Redis delete is idempotent - if the
key doesn't exist, this is a no-op.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
"""
cache = get_cache_backend(tenant_id=tenant_id)
cache.delete(LICENSE_METADATA_KEY)
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
redis_client.delete(LICENSE_METADATA_KEY)
logger.info("License cache invalidated")
@@ -187,7 +192,7 @@ def update_license_cache(
tenant_id: str | None = None,
) -> LicenseMetadata:
"""
Update the cache with license metadata.
Update the Redis cache with license metadata.
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
1. Frontend needs status to show appropriate UI/banners
@@ -206,7 +211,7 @@ def update_license_cache(
from ee.onyx.utils.license import get_license_status
tenant = tenant_id or get_current_tenant_id()
cache = get_cache_backend(tenant_id=tenant_id)
redis_client = get_redis_client(tenant_id=tenant)
used_seats = get_used_seats(tenant)
status = get_license_status(payload, grace_period_end)
@@ -225,7 +230,7 @@ def update_license_cache(
stripe_subscription_id=payload.stripe_subscription_id,
)
cache.set(
redis_client.set(
LICENSE_METADATA_KEY,
metadata.model_dump_json(),
ex=LICENSE_CACHE_TTL_SECONDS,

View File

@@ -126,16 +126,12 @@ class ScimDAL(DAL):
def create_user_mapping(
self,
external_id: str | None,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a SCIM mapping for a user.
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
allows this). The mapping still marks the user as SCIM-managed.
"""
"""Create a mapping between a SCIM externalId and an Onyx user."""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
@@ -274,13 +270,8 @@ class ScimDAL(DAL):
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
# unless they were explicitly linked via SCIM provisioning.
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
@@ -330,37 +321,34 @@ class ScimDAL(DAL):
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Sync the SCIM mapping for a user.
If a mapping already exists, its fields are updated (including
setting ``external_id`` to ``None`` when the IdP omits it).
If no mapping exists and ``new_external_id`` is provided, a new
mapping is created. A mapping is never deleted here — SCIM-managed
users must retain their mapping to remain visible in ``GET /Users``.
"""Create, update, or delete the external ID mapping for a user.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
elif new_external_id:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_mappings_batch(
self, user_ids: list[UUID]

View File

@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -472,9 +471,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name,
time_last_modified_by_user=func.now(),
is_up_to_date=DISABLE_VECTOR_DB,
name=user_group.name, time_last_modified_by_user=func.now()
)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -777,7 +774,8 @@ def update_user_group(
cc_pair_ids=user_group_update.cc_pair_ids,
)
if cc_pairs_updated and not DISABLE_VECTOR_DB:
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -152,9 +153,12 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -26,6 +26,7 @@ import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -41,6 +42,7 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import StripePublishableKeyResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.billing.service import BillingServiceError
from ee.onyx.server.billing.service import (
create_checkout_session as create_checkout_service,
)
@@ -56,8 +58,6 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -169,23 +169,26 @@ async def create_checkout_session(
if seats is not None:
used_seats = get_used_seats(tenant_id)
if seats < used_seats:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Cannot subscribe with fewer seats than current usage. "
raise HTTPException(
status_code=400,
detail=f"Cannot subscribe with fewer seats than current usage. "
f"You have {used_seats} active users/integrations but requested {seats} seats.",
)
# Build redirect URL for after checkout completion
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
return await create_checkout_service(
billing_period=billing_period,
seats=seats,
email=email,
license_data=license_data,
redirect_url=redirect_url,
tenant_id=tenant_id,
)
try:
return await create_checkout_service(
billing_period=billing_period,
seats=seats,
email=email,
license_data=license_data,
redirect_url=redirect_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/create-customer-portal-session")
@@ -203,15 +206,18 @@ async def create_customer_portal_session(
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
raise HTTPException(status_code=400, detail="No license found")
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
return await create_portal_service(
license_data=license_data,
return_url=return_url,
tenant_id=tenant_id,
)
try:
return await create_portal_service(
license_data=license_data,
return_url=return_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/billing-information")
@@ -234,9 +240,9 @@ async def get_billing_information(
# Check circuit breaker (self-hosted only)
if _is_billing_circuit_open():
raise OnyxError(
OnyxErrorCode.SERVICE_UNAVAILABLE,
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
raise HTTPException(
status_code=503,
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
)
try:
@@ -244,11 +250,11 @@ async def get_billing_information(
license_data=license_data,
tenant_id=tenant_id,
)
except OnyxError as e:
except BillingServiceError as e:
# Open circuit breaker on connection failures (self-hosted only)
if e.status_code in (502, 503, 504):
_open_billing_circuit()
raise
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/seats/update")
@@ -268,25 +274,31 @@ async def update_seats(
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
raise HTTPException(status_code=400, detail="No license found")
# Validate that new seat count is not less than current used seats
used_seats = get_used_seats(tenant_id)
if request.new_seat_count < used_seats:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Cannot reduce seats below current usage. "
raise HTTPException(
status_code=400,
detail=f"Cannot reduce seats below current usage. "
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
)
# Note: Don't store license here - the control plane may still be processing
# the subscription update. The frontend should call /license/claim after a
# short delay to get the freshly generated license.
return await update_seat_service(
new_seat_count=request.new_seat_count,
license_data=license_data,
tenant_id=tenant_id,
)
try:
result = await update_seat_service(
new_seat_count=request.new_seat_count,
license_data=license_data,
tenant_id=tenant_id,
)
# Note: Don't store license here - the control plane may still be processing
# the subscription update. The frontend should call /license/claim after a
# short delay to get the freshly generated license.
return result
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/stripe-publishable-key")
@@ -317,18 +329,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Stripe publishable key is not configured",
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
@@ -339,17 +351,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Validate key format
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to fetch Stripe publishable key",
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -22,8 +22,6 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -33,6 +31,15 @@ logger = setup_logger()
_REQUEST_TIMEOUT = 30.0
class BillingServiceError(Exception):
    """Exception raised for billing service errors.

    Carries a human-readable message together with the HTTP status code
    associated with the failure.

    Attributes:
        message: Description of what went wrong.
        status_code: HTTP status code associated with the failure; defaults
            to 500 when the caller does not supply one.
    """

    def __init__(self, message: str, status_code: int = 500):
        self.message = message
        self.status_code = status_code
        # Pass the message to Exception so str(exc) and exc.args expose it.
        super().__init__(self.message)
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
"""Build headers for proxy requests (self-hosted).
@@ -94,7 +101,7 @@ async def _make_billing_request(
Response JSON as dict
Raises:
OnyxError: If request fails
BillingServiceError: If request fails
"""
base_url = _get_base_url()
@@ -121,17 +128,11 @@ async def _make_billing_request(
except Exception:
pass
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail,
status_code_override=e.response.status_code,
)
raise BillingServiceError(detail, e.response.status_code)
except httpx.RequestError:
logger.exception("Failed to connect to billing service")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
)
raise BillingServiceError("Failed to connect to billing service", 502)
async def create_checkout_session(

View File

@@ -223,15 +223,6 @@ def get_active_scim_token(
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
# Derive the IdP domain from the first synced user as a heuristic.
idp_domain: str | None = None
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
if mappings:
user = dal.get_user(mappings[0].user_id)
if user and "@" in user.email:
idp_domain = user.email.rsplit("@", 1)[1]
return ScimTokenResponse(
id=token.id,
name=token.name,
@@ -239,7 +230,6 @@ def get_active_scim_token(
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
idp_domain=idp_domain,
)

View File

@@ -14,6 +14,7 @@ import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
@@ -34,8 +35,6 @@ from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -128,9 +127,9 @@ async def claim_license(
2. Without session_id: Re-claim using existing license for auth
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"License claiming is only available for self-hosted deployments",
raise HTTPException(
status_code=400,
detail="License claiming is only available for self-hosted deployments",
)
try:
@@ -147,16 +146,15 @@ async def claim_license(
# Re-claim using existing license for auth
metadata = get_license_metadata(db_session)
if not metadata or not metadata.tenant_id:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No license found. Provide session_id after checkout.",
raise HTTPException(
status_code=400,
detail="No license found. Provide session_id after checkout.",
)
license_row = get_license(db_session)
if not license_row or not license_row.license_data:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No license found in database",
raise HTTPException(
status_code=400, detail="No license found in database"
)
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
@@ -175,7 +173,7 @@ async def claim_license(
license_data = data.get("license")
if not license_data:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")
raise HTTPException(status_code=404, detail="No license in response")
# Verify signature before persisting
payload = verify_license_signature(license_data)
@@ -201,14 +199,12 @@ async def claim_license(
detail = error_data.get("detail", detail)
except Exception:
pass
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
)
raise HTTPException(status_code=status_code, detail=detail)
except ValueError as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
raise HTTPException(
status_code=502, detail="Failed to connect to license server"
)
@@ -225,9 +221,9 @@ async def upload_license(
The license file must be cryptographically signed by Onyx.
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"License upload is only available for self-hosted deployments",
raise HTTPException(
status_code=400,
detail="License upload is only available for self-hosted deployments",
)
try:
@@ -238,14 +234,14 @@ async def upload_license(
# Remove any stray whitespace/newlines from user input
license_data = license_data.strip()
except UnicodeDecodeError:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")
raise HTTPException(status_code=400, detail="Invalid license file format")
# Verify cryptographic signature - this is the only validation needed
# The license's tenant_id identifies the customer in control plane, not locally
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
# Persist to DB and update cache
upsert_license(db_session, license_data)
@@ -301,9 +297,9 @@ async def delete_license(
Admin only - removes license from database and invalidates cache.
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"License deletion is only available for self-hosted deployments",
raise HTTPException(
status_code=400,
detail="License deletion is only available for self-hosted deployments",
)
try:

View File

@@ -46,6 +46,7 @@ from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
@@ -55,7 +56,6 @@ from ee.onyx.configs.license_enforcement_config import (
)
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from shared_configs.contextvars import get_current_tenant_id
@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
"[license_enforcement] No license, allowing community features"
)
is_gated = False
except CACHE_TRANSIENT_ERRORS as e:
except RedisError as e:
logger.warning(f"Failed to check license metadata: {e}")
# Fail open - don't block users due to cache connectivity issues
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
if is_gated:

View File

@@ -423,63 +423,15 @@ def create_user(
email = user_resource.userName.strip()
# Check for existing user — if they exist but aren't SCIM-managed yet,
# link them to the IdP rather than rejecting with 409.
external_id: str | None = user_resource.externalId
scim_username: str = user_resource.userName.strip()
fields: ScimMappingFields = _fields_from_resource(user_resource)
existing_user = dal.get_user_by_email(email)
if existing_user:
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
if existing_mapping:
return _scim_error_response(409, f"User with email {email} already exists")
# Adopt pre-existing user into SCIM management.
# Reactivating a deactivated user consumes a seat, so enforce the
# seat limit the same way replace_user does.
if user_resource.active and not existing_user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
existing_user,
is_active=user_resource.active,
**({"personal_name": personal_name} if personal_name else {}),
)
try:
dal.create_user_mapping(
external_id=external_id,
user_id=existing_user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
return _scim_resource_response(
provider.build_user_resource(
existing_user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
# Only enforce seat limit for net-new users — adopting a pre-existing
# user doesn't consume a new seat.
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
@@ -497,21 +449,21 @@ def create_user(
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Always create a SCIM mapping so that the user is marked as
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
try:
# Create SCIM mapping when externalId is provided — this is how the IdP
# correlates this user on subsequent requests. Per RFC 7643, externalId
# is optional and assigned by the provisioning client.
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
if external_id:
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(

View File

@@ -365,7 +365,6 @@ class ScimTokenResponse(BaseModel):
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
idp_domain: str | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):

View File

@@ -170,10 +170,7 @@ class ScimProvider(ABC):
formatted=user.personal_name or "",
)
if not user.personal_name:
# Derive a reasonable name from the email so that SCIM spec tests
# see non-empty givenName / familyName for every user resource.
local = user.email.split("@")[0] if user.email else ""
return ScimName(givenName=local, familyName="", formatted=local)
return ScimName(givenName="", familyName="", formatted="")
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],

View File

@@ -6,7 +6,6 @@ from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
@@ -126,7 +125,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# syncing) means indexed data may need protection.
settings.application_status = _BLOCKING_STATUS
settings.ee_features_enabled = False
except CACHE_TRANSIENT_ERRORS as e:
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")
# Fail closed - disable EE features if we can't verify license
settings.ee_features_enabled = False

View File

@@ -2,13 +2,12 @@ from datetime import datetime
from datetime import timedelta
import jwt
from fastapi import HTTPException
from fastapi import Request
from onyx.configs.app_configs import DATA_PLANE_SECRET
from onyx.configs.app_configs import EXPECTED_API_KEY
from onyx.configs.app_configs import JWT_ALGORITHM
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -33,24 +32,22 @@ async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=[JWT_ALGORITHM])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS, "Insufficient permissions"
)
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise OnyxError(OnyxErrorCode.TOKEN_EXPIRED, "Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise OnyxError(OnyxErrorCode.INVALID_TOKEN, "Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi_users import exceptions
@@ -11,8 +12,6 @@ from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -31,7 +30,7 @@ async def impersonate_user(
except exceptions.UserNotExists:
detail = f"User has no tenant mapping: {impersonate_request.email=}"
logger.warning(detail)
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, detail)
raise HTTPException(status_code=422, detail=detail)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
@@ -42,7 +41,7 @@ async def impersonate_user(
f"User not found in tenant: {impersonate_request.email=} {tenant_id=}"
)
logger.warning(detail)
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, detail)
raise HTTPException(status_code=422, detail=detail)
token = await get_redis_strategy().write_token(user_to_impersonate)

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
@@ -17,8 +18,6 @@ from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -34,7 +33,7 @@ async def get_anonymous_user_path_api(
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Tenant not found")
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
@@ -51,21 +50,21 @@ async def set_anonymous_user_path_api(
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise OnyxError(
OnyxErrorCode.CONFLICT,
"The anonymous user path is already in use. Please choose a different path.",
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"An unexpected error occurred while modifying the anonymous user path",
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@@ -78,10 +77,10 @@ async def login_as_anonymous_user(
anonymous_user_path, db_session
)
if not tenant_id:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Tenant not found")
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise OnyxError(OnyxErrorCode.UNAUTHORIZED, "Anonymous user is not enabled")
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)

View File

@@ -21,6 +21,7 @@ import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.server.tenants.access import control_plane_dep
@@ -42,8 +43,6 @@ from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -117,14 +116,9 @@ async def create_customer_portal_session(
try:
portal_url = fetch_customer_portal_session(tenant_id, return_url)
return {"stripe_customer_portal_url": portal_url}
except OnyxError:
raise
except Exception:
except Exception as e:
logger.exception("Failed to create customer portal session")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to create customer portal session",
)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-checkout-session")
@@ -140,14 +134,9 @@ async def create_checkout_session(
try:
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
return {"stripe_checkout_url": checkout_url}
except OnyxError:
raise
except Exception:
except Exception as e:
logger.exception("Failed to create checkout session")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to create checkout session",
)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
@@ -158,20 +147,15 @@ async def create_subscription_session(
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")
raise HTTPException(status_code=400, detail="Tenant ID not found")
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
return SubscriptionSessionResponse(sessionId=session_id)
except OnyxError:
raise
except Exception:
except Exception as e:
logger.exception("Failed to create subscription session")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to create subscription session",
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stripe-publishable-key")
@@ -202,18 +186,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Stripe publishable key is not configured",
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
@@ -224,15 +208,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Validate key format
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to fetch Stripe publishable key",
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -4,6 +4,7 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -40,8 +41,6 @@ from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
@@ -117,9 +116,9 @@ async def get_or_provision_tenant(
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to provision tenant. Please try again later.",
raise HTTPException(
status_code=500,
detail="Failed to provision tenant. Please try again later.",
)
@@ -145,18 +144,18 @@ async def create_tenant(
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to provision tenant.")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
async def provision_tenant(tenant_id: str, email: str) -> None:
if not MULTI_TENANT:
raise OnyxError(OnyxErrorCode.UNAUTHORIZED, "Multi-tenancy is not enabled")
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
if user_owns_a_tenant(email):
raise OnyxError(
OnyxErrorCode.CONFLICT, "User already belongs to an organization"
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
@@ -176,8 +175,8 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR, f"Failed to create tenant: {str(e)}"
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)

View File

@@ -25,6 +25,7 @@ import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Header
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
@@ -35,8 +36,6 @@ from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import is_license_valid
from ee.onyx.utils.license import verify_license_signature
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -47,9 +46,9 @@ router = APIRouter(prefix="/proxy")
def _check_license_enforcement_enabled() -> None:
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
if not LICENSE_ENFORCEMENT_ENABLED:
raise OnyxError(
OnyxErrorCode.NOT_IMPLEMENTED,
"Proxy endpoints are only available on cloud data plane",
raise HTTPException(
status_code=501,
detail="Proxy endpoints are only available on cloud data plane",
)
@@ -82,9 +81,8 @@ def _extract_license_from_header(
"""
if not authorization or not authorization.startswith("Bearer "):
if required:
raise OnyxError(
OnyxErrorCode.UNAUTHENTICATED,
"Missing or invalid authorization header",
raise HTTPException(
status_code=401, detail="Missing or invalid authorization header"
)
return None
@@ -112,10 +110,10 @@ def verify_license_auth(
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, f"Invalid license: {e}")
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
if not allow_expired and not is_license_valid(payload):
raise OnyxError(OnyxErrorCode.TOKEN_EXPIRED, "License has expired")
raise HTTPException(status_code=401, detail="License has expired")
return payload
@@ -199,12 +197,12 @@ async def forward_to_control_plane(
except Exception:
pass
logger.error(f"Control plane returned {status_code}: {detail}")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
)
raise HTTPException(status_code=status_code, detail=detail)
except httpx.RequestError:
logger.exception("Failed to connect to control plane")
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, "Failed to connect to control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
)
# -----------------------------------------------------------------------------
@@ -296,9 +294,9 @@ async def proxy_claim_license(
if not tenant_id or not license_data:
logger.error(f"Control plane returned incomplete claim response: {result}")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
"Control plane returned incomplete license data",
raise HTTPException(
status_code=502,
detail="Control plane returned incomplete license data",
)
return ClaimLicenseResponse(
@@ -328,7 +326,7 @@ async def proxy_create_customer_portal_session(
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "License missing tenant_id")
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
@@ -369,7 +367,7 @@ async def proxy_billing_information(
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "License missing tenant_id")
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
@@ -400,12 +398,12 @@ async def proxy_license_fetch(
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "License missing tenant_id")
raise HTTPException(status_code=401, detail="License missing tenant_id")
if tenant_id != license_payload.tenant_id:
raise OnyxError(
OnyxErrorCode.UNAUTHORIZED,
"Cannot fetch license for a different tenant",
raise HTTPException(
status_code=403,
detail="Cannot fetch license for a different tenant",
)
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
@@ -413,9 +411,9 @@ async def proxy_license_fetch(
license_data = result.get("license")
if not license_data:
logger.error(f"Control plane returned incomplete license response: {result}")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
"Control plane returned incomplete license data",
raise HTTPException(
status_code=502,
detail="Control plane returned incomplete license data",
)
# Return license to caller - self-hosted instance stores it via /api/license/claim
@@ -434,7 +432,7 @@ async def proxy_seat_update(
Returns the regenerated license in the response for the caller to store.
"""
if not license_payload.tenant_id:
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "License missing tenant_id")
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
@@ -11,8 +12,6 @@ from onyx.db.auth import get_user_count
from onyx.db.engine.sql_engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -31,14 +30,13 @@ async def leave_organization(
tenant_id = get_current_tenant_id()
if current_user.email != user_email.user_email:
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You can only leave the organization as yourself",
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise OnyxError(OnyxErrorCode.USER_NOT_FOUND, "User not found")
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
@@ -55,9 +53,9 @@ async def leave_organization(
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Failed to remove user from control plane: {str(e)}",
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
@@ -12,8 +13,6 @@ from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -33,7 +32,7 @@ async def request_invite(
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(e))
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
@@ -65,7 +64,7 @@ async def accept_invite(
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to accept invitation")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
@@ -80,4 +79,4 @@ async def deny_invite(
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Failed to deny invitation")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -5,8 +5,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -22,7 +20,6 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -156,8 +153,3 @@ def delete_user_group(
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
if DISABLE_VECTOR_DB:
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)

View File

@@ -120,6 +120,7 @@ from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
@@ -200,14 +201,13 @@ def user_needs_to_be_verified() -> bool:
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
from onyx.cache.factory import get_cache_backend
cache = get_cache_backend(tenant_id=tenant_id)
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
redis_client = get_redis_client(tenant_id=tenant_id)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is None:
return False
assert isinstance(value, bytes)
return int(value.decode("utf-8")) == 1

View File

@@ -0,0 +1,142 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Receiver for Celery's ``task_prerun`` signal.

    Every argument delivered with the signal is handed, unchanged, to
    ``app_base.on_task_prerun``; this function adds no logic of its own.
    """
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Receiver for Celery's ``task_postrun`` signal.

    Passes the task identity, its call arguments, the return value and the
    final state straight through to ``app_base.on_task_postrun``.
    """
    app_base.on_task_postrun(
        sender, task_id, task, args, kwargs, retval, state, **kwds
    )
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    """Receiver for ``celeryd_init``; defers entirely to ``app_base.on_celeryd_init``."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    """Prepare the consolidated background worker when ``worker_init`` fires.

    Steps, in order:
      1. Name and initialise the SQL engine, sizing the pool from the
         worker's concurrency plus a small overflow allowance.
      2. Initialise the Vespa httpx pool (with client cert/key when
         ``MANAGED_VESPA`` is set).
      3. Run the ``app_base`` waits for Redis, the DB and Vespa.
      4. Outside multi-tenant mode, run the secondary-worker init.
    """
    extra_connections = 8  # small extra fudge factor for connection limits

    logger.info("worker_init signal received for consolidated background worker.")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=extra_connections)

    # Initialize Vespa httpx pool (needed for light worker tasks)
    vespa_pool_size = pool_size + extra_connections
    if not MANAGED_VESPA:
        httpx_init_vespa_pool(vespa_pool_size)
    else:
        httpx_init_vespa_pool(
            vespa_pool_size,
            ssl_cert=VESPA_CLOUD_CERT_PATH,
            ssl_key=VESPA_CLOUD_KEY_PATH,
        )

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Fewer startup checks in the multi-tenant case
    if not MULTI_TENANT:
        app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Receiver for ``worker_ready``; defers entirely to ``app_base.on_worker_ready``."""
    app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Receiver for ``worker_shutdown``; defers entirely to ``app_base.on_worker_shutdown``."""
    app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:  # noqa: ARG001
    """Reset the shared ``SqlEngine`` when ``worker_process_init`` fires.

    The signal's keyword arguments are accepted but not used.
    NOTE(review): presumably this keeps a new worker process from reusing
    engine state created before it started -- confirm against SqlEngine.
    """
    SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any,
    logfile: Any,
    format: Any,
    colorize: Any,
    **kwargs: Any,
) -> None:
    """Receiver for Celery's ``setup_logging`` signal.

    Logging configuration is delegated to ``app_base.on_setup_logging`` with
    the signal's arguments passed through as received.
    """
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# Add every bootstep returned by app_base.get_bootsteps() to this app's
# "worker" step set.
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
# Task modules for this app to autodiscover. The list is passed through
# app_base.filter_task_modules() before being handed to Celery; the groups
# below are labelled by the worker each set of tasks came from.
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -39,13 +39,9 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector.
"""Result of extracting document IDs and hierarchy nodes from a connector."""
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
"""
raw_id_to_parent: dict[str, str | None]
doc_ids: set[str]
hierarchy_nodes: list[HierarchyNode]
@@ -97,37 +93,30 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
return None
class BatchResult(BaseModel):
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> BatchResult:
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID dict so that failed-to-retrieve documents are not accidentally pruned.
ID set so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: dict[str, str | None] = {}
ids: set[str] = set()
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
if item.raw_node_id not in ids:
ids[item.raw_node_id] = None
ids.add(item.raw_node_id)
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids[failed_id] = None
ids.add(failed_id)
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
ids[item.id] = parent_raw
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
ids.add(item.id)
return ids, hierarchy_nodes
def extract_ids_from_runnable_connector(
@@ -143,7 +132,7 @@ def extract_ids_from_runnable_connector(
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_raw_id_to_parent: dict[str, str | None] = {}
all_connector_doc_ids: set[str] = set()
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -188,20 +177,15 @@ def extract_ids_from_runnable_connector(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
for k, v in batch_ids.items():
if v is not None or k not in all_raw_id_to_parent:
all_raw_id_to_parent[k] = v
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
raw_id_to_parent=all_raw_id_to_parent,
doc_ids=all_connector_doc_ids,
hierarchy_nodes=all_hierarchy_nodes,
)

View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
# Broker, Redis and result-backend settings are copied unchanged from the
# shared base config module.
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
# Task defaults are likewise inherited from the shared base config.
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Worker settings specific to this config: a thread pool whose size comes from
# CELERY_WORKER_BACKGROUND_CONCURRENCY.
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4

View File

@@ -30,7 +30,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
transform_vespa_chunks_to_opensearch_chunks,
)
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -48,7 +47,6 @@ from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
@@ -148,12 +146,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
task_logger.error(err_str)
return False
with (
get_session_with_current_tenant() as db_session,
get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as vespa_client,
):
with get_session_with_current_tenant() as db_session:
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
@@ -168,7 +161,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
index_name=search_settings.index_name,
tenant_state=tenant_state,
large_chunks_enabled=False,
httpx_client=vespa_client,
)
sanitized_doc_start_time = time.monotonic()

View File

@@ -29,7 +29,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -48,8 +47,6 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
@@ -60,8 +57,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -118,38 +113,6 @@ class PruneCallback(IndexingCallbackBase):
super().progress(tag, amount)
def _resolve_and_update_document_parents(
db_session: Session,
redis_client: Redis,
source: DocumentSource,
raw_id_to_parent: dict[str, str | None],
) -> None:
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
each document and bulk-update the DB. Mirrors the resolution logic in
run_docfetching.py."""
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent_id in raw_id_to_parent.items():
if raw_parent_id is None:
continue
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
resolved[doc_id] = node_id if found else source_node_id
if not resolved:
return
update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
task_logger.info(
f"Pruning: resolved and updated parent hierarchy for "
f"{len(resolved)} documents (source={source.value})"
)
"""Jobs / utils for kicking off pruning tasks."""
@@ -572,22 +535,22 @@ def connector_pruning_generator_task(
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
all_connector_doc_ids = extraction_result.doc_ids
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
ensure_source_node_exists(redis_client, db_session, source)
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(
redis_client, db_session, cc_pair.connector.source
)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=source,
source=cc_pair.connector.source,
commit=True,
is_connector_public=is_connector_public,
)
@@ -598,7 +561,7 @@ def connector_pruning_generator_task(
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=source,
source=cc_pair.connector.source,
entries=cache_entries,
)
@@ -607,26 +570,6 @@ def connector_pruning_generator_task(
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
ensure_source_node_exists(redis_client, db_session, source)
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
# and bulk-update documents, mirroring the docfetching resolution
_resolve_and_update_document_parents(
db_session=db_session,
redis_client=redis_client,
source=source,
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
@@ -638,9 +581,7 @@ def connector_pruning_generator_task(
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
"Pruning set collected: "

View File

@@ -520,7 +520,6 @@ def process_user_file_impl(
task_logger.exception(
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
raise
finally:
if file_lock is not None and file_lock.owned():
file_lock.release()
@@ -676,7 +675,6 @@ def delete_user_file_impl(
task_logger.exception(
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
raise
finally:
if file_lock is not None and file_lock.owned():
file_lock.release()
@@ -851,7 +849,6 @@ def project_sync_user_file_impl(
task_logger.exception(
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
raise
finally:
if file_lock is not None and file_lock.owned():
file_lock.release()

View File

@@ -0,0 +1,10 @@
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
# Runs before the lookup below.
# NOTE(review): presumably this sets the EE flag that
# fetch_versioned_implementation consults when choosing a module -- confirm.
set_is_ee_based_on_env_variable()
# Resolve the `celery_app` attribute of onyx.background.celery.apps.background
# through the versioned-implementation lookup and expose it as `app`.
app: Celery = fetch_versioned_implementation(
"onyx.background.celery.apps.background",
"celery_app",
)

View File

@@ -59,12 +59,6 @@ def _run_auto_llm_update() -> None:
sync_llm_models_from_github(db_session)
def _run_cache_cleanup() -> None:
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
cleanup_expired_cache_entries()
def _run_scheduled_eval() -> None:
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
@@ -106,26 +100,12 @@ def _run_scheduled_eval() -> None:
)
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
tasks: list[_PeriodicTaskDef] = []
if CACHE_BACKEND == CacheBackendType.POSTGRES:
tasks.append(
_PeriodicTaskDef(
name="cache-cleanup",
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
run_fn=_run_cache_cleanup,
)
)
if AUTO_LLM_CONFIG_URL:
tasks.append(
_PeriodicTaskDef(

View File

@@ -75,41 +75,31 @@ def _claim_next_processing_file(db_session: Session) -> UUID | None:
return file_id
def _claim_next_deleting_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
"""Claim the next DELETING file.
No status transition needed — the impl deletes the row on success.
The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
stmt = (
file_id = db_session.execute(
select(UserFile.id)
.where(UserFile.status == UserFileStatus.DELETING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
).scalar_one_or_none()
# Commit to release the row lock promptly.
db_session.commit()
return file_id
def _claim_next_sync_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
def _claim_next_sync_file(db_session: Session) -> UUID | None:
"""Claim the next file needing project/persona sync.
No status transition needed — the impl clears the sync flags on
success. The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
stmt = (
file_id = db_session.execute(
select(UserFile.id)
.where(
sa.and_(
@@ -123,10 +113,7 @@ def _claim_next_sync_file(
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
).scalar_one_or_none()
db_session.commit()
return file_id
@@ -148,14 +135,11 @@ def drain_processing_loop(tenant_id: str) -> None:
file_id = _claim_next_processing_file(session)
if file_id is None:
break
try:
process_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to process user file {file_id}")
process_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
def drain_delete_loop(tenant_id: str) -> None:
@@ -165,21 +149,16 @@ def drain_delete_loop(tenant_id: str) -> None:
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
failed: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
file_id = _claim_next_deleting_file(session)
if file_id is None:
break
try:
delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to delete user file {file_id}")
failed.add(file_id)
delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
def drain_project_sync_loop(tenant_id: str) -> None:
@@ -189,18 +168,13 @@ def drain_project_sync_loop(tenant_id: str) -> None:
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
failed: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_sync_file(session, exclude_ids=failed)
file_id = _claim_next_sync_file(session)
if file_id is None:
break
try:
project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to sync user file {file_id}")
failed.add(file_id)
project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)

View File

@@ -12,15 +12,9 @@ def _build_redis_backend(tenant_id: str) -> CacheBackend:
return RedisCacheBackend(redis_pool.get_client(tenant_id))
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
from onyx.cache.postgres_backend import PostgresCacheBackend
return PostgresCacheBackend(tenant_id)
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
CacheBackendType.REDIS: _build_redis_backend,
CacheBackendType.POSTGRES: _build_postgres_backend,
# CacheBackendType.POSTGRES will be added in a follow-up PR.
}

View File

@@ -1,20 +1,6 @@
import abc
from enum import Enum
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
TTL_KEY_NOT_FOUND = -2
TTL_NO_EXPIRY = -1
CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)
"""Exception types that represent transient cache connectivity / operational
failures. Callers that want to fail-open (or fail-closed) on cache errors
should catch this tuple instead of bare ``Exception``.
When adding a new ``CacheBackend`` implementation, add its transient error
base class(es) here so all call-sites pick it up automatically."""
class CacheBackendType(str, Enum):
REDIS = "redis"
@@ -40,14 +26,6 @@ class CacheLock(abc.ABC):
def owned(self) -> bool:
raise NotImplementedError
def __enter__(self) -> "CacheLock":
if not self.acquire():
raise RuntimeError("Failed to acquire lock")
return self
def __exit__(self, *args: object) -> None:
self.release()
class CacheBackend(abc.ABC):
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
@@ -87,11 +65,7 @@ class CacheBackend(abc.ABC):
@abc.abstractmethod
def ttl(self, key: str) -> int:
"""Return remaining TTL in seconds.
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
"""
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
raise NotImplementedError
# -- distributed lock --------------------------------------------------

View File

@@ -1,323 +0,0 @@
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
for distributed locking, and a polling loop for the BLPOP pattern.
"""
import hashlib
import struct
import time
import uuid
from contextlib import AbstractContextManager
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.db.models import CacheStore
_LIST_KEY_PREFIX = "_q:"
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
# lists whose names share a prefix (e.g. _q:mylist2:...).
_LIST_KEY_RANGE_TERMINATOR = ";"
_LIST_ITEM_TTL_SECONDS = 3600
_LOCK_POLL_INTERVAL = 0.1
_BLPOP_POLL_INTERVAL = 0.25
def _list_item_key(key: str) -> str:
    """Build a unique storage key for one item of the list named *key*.

    The key has the form ``<prefix><key>:<time_ns>:<uuid hex>``: the
    nanosecond timestamp gives FIFO ordering, and the random UUID keeps two
    pushes made in the same nanosecond from colliding.
    """
    pushed_at_ns = time.time_ns()
    unique_suffix = uuid.uuid4().hex
    return f"{_LIST_KEY_PREFIX}{key}:{pushed_at_ns}:{unique_suffix}"
def _to_bytes(value: str | bytes | int | float) -> bytes:
if isinstance(value, bytes):
return value
return str(value).encode()
# ------------------------------------------------------------------
# Lock
# ------------------------------------------------------------------
class PostgresCacheLock(CacheLock):
    """Advisory-lock-based distributed lock.

    Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
    to the session's connection; releasing or closing the session frees it.

    NOTE: Unlike Redis locks, advisory locks do not auto-expire after
    ``timeout`` seconds. They are released when ``release()`` is
    called or when the session is closed.

    NOTE(review): ``acquire()`` opens a new session unconditionally, so calling
    it again while the lock is already held replaces the stored session without
    closing the previous one. Callers should not re-acquire a held lock.
    """

    def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
        """Store the lock parameters; no database work happens here.

        Args:
            lock_id: Integer key passed to the PostgreSQL advisory lock functions.
            timeout: Fallback wait (seconds) for blocking ``acquire()`` calls
                that pass no ``blocking_timeout``. ``None`` waits indefinitely.
            tenant_id: Tenant whose session is used to hold the lock.
        """
        self._lock_id = lock_id
        self._timeout = timeout
        self._tenant_id = tenant_id
        self._session_cm: AbstractContextManager[Session] | None = None
        self._session: Session | None = None
        self._acquired = False

    def acquire(
        self,
        blocking: bool = True,
        blocking_timeout: float | None = None,
    ) -> bool:
        """Try to take the advisory lock.

        Args:
            blocking: When False, make a single attempt and return immediately.
            blocking_timeout: Maximum seconds to keep polling when blocking.
                ``None`` falls back to the lock's ``timeout``; if that is also
                ``None`` the call polls until the lock is obtained. ``0`` means
                a single attempt.

        Returns:
            True if the lock was obtained, False otherwise. When the lock is
            not obtained the session opened for the attempt is closed.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
        self._session = self._session_cm.__enter__()
        try:
            if not blocking:
                return self._try_lock()

            # Explicit ``is not None`` checks: a timeout of 0 is a valid
            # request ("do not wait") and must not be treated as "unset".
            effective_timeout = (
                blocking_timeout if blocking_timeout is not None else self._timeout
            )
            deadline = (
                (time.monotonic() + effective_timeout)
                if effective_timeout is not None
                else None
            )
            while True:
                if self._try_lock():
                    return True
                if deadline is not None and time.monotonic() >= deadline:
                    return False
                time.sleep(_LOCK_POLL_INTERVAL)
        finally:
            # Covers both the "gave up" path and any exception raised while
            # polling: never keep a session open without holding the lock.
            if not self._acquired:
                self._close_session()

    def release(self) -> None:
        """Unlock and close the session; a no-op if the lock is not held."""
        if not self._acquired or self._session is None:
            return
        try:
            self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
        finally:
            # Even if the unlock statement fails, drop the session so the
            # connection-scoped lock does not outlive this object.
            self._acquired = False
            self._close_session()

    def owned(self) -> bool:
        """Return whether this object currently records the lock as held."""
        return self._acquired

    def _close_session(self) -> None:
        """Exit the session context manager (if any) and clear both references."""
        if self._session_cm is not None:
            try:
                self._session_cm.__exit__(None, None, None)
            finally:
                self._session_cm = None
                self._session = None

    def _try_lock(self) -> bool:
        """Make one non-blocking ``pg_try_advisory_lock`` attempt."""
        assert self._session is not None
        result = self._session.execute(
            select(func.pg_try_advisory_lock(self._lock_id))
        ).scalar()
        if result:
            self._acquired = True
            return True
        return False
# ------------------------------------------------------------------
# Backend
# ------------------------------------------------------------------
class PostgresCacheBackend(CacheBackend):
    """``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.

    Each operation opens and closes its own database session so the backend
    is safe to share across threads.  Tenant isolation is handled by
    SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).

    A row counts as "live" when its ``expires_at`` is NULL or lies in the
    future (by the database clock); every read path, and ``expire``, ignores
    rows that are not live even if they have not been garbage-collected yet.
    """

    def __init__(self, tenant_id: str) -> None:
        """Bind the backend to *tenant_id*; no connection is opened here."""
        self._tenant_id = tenant_id

    # -- basic key/value ---------------------------------------------------

    def get(self, key: str) -> bytes | None:
        """Return the bytes stored under *key*, or ``None``.

        ``None`` is returned when there is no live row for *key* or when the
        stored value itself is NULL.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.value).where(
            CacheStore.key == key,
            or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            value = session.execute(stmt).scalar_one_or_none()
            if value is None:
                return None
            return bytes(value)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        """Insert or overwrite *key* with *value*.

        Args:
            key: Cache key (primary key of the row).
            value: Payload; converted to bytes via ``_to_bytes``.
            ex: Optional time-to-live in seconds.  ``None`` stores the row
                without an expiry.  An existing row has both its value and
                its expiry replaced.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        value_bytes = _to_bytes(value)
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=ex)
            if ex is not None
            else None
        )
        # Upsert: a conflicting key has its value and expiry overwritten.
        stmt = (
            pg_insert(CacheStore)
            .values(key=key, value=value_bytes, expires_at=expires_at)
            .on_conflict_do_update(
                index_elements=[CacheStore.key],
                set_={"value": value_bytes, "expires_at": expires_at},
            )
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def delete(self, key: str) -> None:
        """Remove the row for *key* (live or expired); a missing key is a no-op."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(delete(CacheStore).where(CacheStore.key == key))
            session.commit()

    def exists(self, key: str) -> bool:
        """Return ``True`` when a live row exists for *key*."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = (
            select(CacheStore.key)
            .where(
                CacheStore.key == key,
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .limit(1)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            return session.execute(stmt).first() is not None

    # -- TTL ---------------------------------------------------------------

    def expire(self, key: str, seconds: int) -> None:
        """Set the expiry of *key* to *seconds* from now.

        Only live rows are touched.  A row that has already expired but has
        not been garbage-collected yet is treated as absent by ``get`` and
        ``exists``; extending its expiry would make that stale value visible
        again, so such rows are deliberately left alone.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
        stmt = (
            update(CacheStore)
            .where(
                CacheStore.key == key,
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .values(expires_at=new_exp)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def ttl(self, key: str) -> int:
        """Return the remaining lifetime of *key* in whole seconds.

        Returns:
            ``TTL_KEY_NOT_FOUND`` when there is no row or it has already
            expired, ``TTL_NO_EXPIRY`` when the row has no expiry, otherwise
            the remaining seconds truncated toward zero.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            result = session.execute(stmt).first()
        if result is None:
            return TTL_KEY_NOT_FOUND
        expires_at: datetime | None = result[0]
        if expires_at is None:
            return TTL_NO_EXPIRY
        remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
        if remaining <= 0:
            return TTL_KEY_NOT_FOUND
        return int(remaining)

    # -- distributed lock --------------------------------------------------

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        """Return a ``PostgresCacheLock`` for *name*, scoped to this tenant."""
        return PostgresCacheLock(
            self._lock_id_for(name), timeout, tenant_id=self._tenant_id
        )

    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------

    def rpush(self, key: str, value: str | bytes) -> None:
        """Append *value* to the list *key*.

        Each item is stored as its own row under the key produced by
        ``_list_item_key`` and expires after ``_LIST_ITEM_TTL_SECONDS``.
        NOTE(review): ordering in ``blpop`` relies on ``_list_item_key``
        generating keys that sort in insertion order — confirm at its
        definition.
        """
        self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        """Pop the first available item from any of *keys*, polling until *timeout*.

        Args:
            keys: List names to check, in priority order.
            timeout: Maximum number of seconds to wait; must be > 0.

        Returns:
            ``(list_name_as_bytes, value)`` for the first item found, or
            ``None`` if nothing arrived before the deadline.

        Raises:
            ValueError: If *timeout* is not strictly positive.
        """
        if timeout <= 0:
            raise ValueError(
                "PostgresCacheBackend.blpop requires timeout > 0. "
                "timeout=0 would block the calling thread indefinitely "
                "with no way to interrupt short of process termination."
            )
        from onyx.db.engine.sql_engine import get_session_with_tenant

        deadline = time.monotonic() + timeout
        while True:
            for key in keys:
                # Half-open key range covering every item row of this list.
                lower = f"{_LIST_KEY_PREFIX}{key}:"
                upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
                stmt = (
                    select(CacheStore)
                    .where(
                        CacheStore.key >= lower,
                        CacheStore.key < upper,
                        or_(
                            CacheStore.expires_at.is_(None),
                            CacheStore.expires_at > func.now(),
                        ),
                    )
                    .order_by(CacheStore.key)
                    .limit(1)
                    # Rows locked by a concurrent popper are skipped rather
                    # than waited on.
                    .with_for_update(skip_locked=True)
                )
                with get_session_with_tenant(tenant_id=self._tenant_id) as session:
                    row = session.execute(stmt).scalars().first()
                    if row is not None:
                        value = bytes(row.value) if row.value else b""
                        session.delete(row)
                        session.commit()
                        return (key.encode(), value)
            # The deadline is only checked after a full pass over all keys,
            # so at least one pass always runs.
            if time.monotonic() >= deadline:
                return None
            time.sleep(_BLPOP_POLL_INTERVAL)

    # -- helpers -----------------------------------------------------------

    def _lock_id_for(self, name: str) -> int:
        """Map *name* to a 64-bit signed int for ``pg_advisory_lock``.

        The id is derived from the first 8 bytes of the MD5 of
        ``"<tenant_id>:<name>"``.  The bytes are decoded with an explicit
        little-endian format so that every process computes the same id for
        the same name regardless of host byte order.
        """
        # MD5 is used purely as a hash bucket, not for security.
        h = hashlib.md5(
            f"{self._tenant_id}:{name}".encode(), usedforsecurity=False
        ).digest()
        return struct.unpack("<q", h[:8])[0]
# ------------------------------------------------------------------
# Periodic cleanup
# ------------------------------------------------------------------
def cleanup_expired_cache_entries() -> None:
    """Purge ``cache_store`` rows whose ``expires_at`` has already passed.

    Rows without an expiry (``expires_at`` is NULL) are never removed.
    Runs against the current tenant's session and commits the deletion.

    Called by the periodic poller every 5 minutes.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    # Build the statement up front; expiry is compared to the database clock.
    purge_stmt = delete(CacheStore).where(
        CacheStore.expires_at.is_not(None),
        CacheStore.expires_at < func.now(),
    )
    with get_session_with_current_tenant() as session:
        session.execute(purge_stmt)
        session.commit()

View File

@@ -1,52 +1,57 @@
from uuid import UUID
from onyx.cache.interface import CacheBackend
from redis.client import Redis
# Redis key prefixes for chat message processing
PREFIX = "chatprocessing"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 30 * 60 # 30 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""Generate the cache key for a chat session processing fence.
"""
Generate the Redis key for a chat session processing a message.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
The fence key string (tenant_id is automatically added by the Redis client)
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_processing_status(
chat_session_id: UUID, cache: CacheBackend, value: bool
chat_session_id: UUID, redis_client: Redis, value: bool
) -> None:
"""Set or clear the fence for a chat session processing a message.
"""
Set or clear the fence for a chat session processing a message.
If the key exists, a message is being processed.
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: The Redis client to use
value: True to set the fence, False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if value:
cache.set(fence_key, 0, ex=FENCE_TTL)
redis_client.set(fence_key, 0, ex=FENCE_TTL)
else:
cache.delete(fence_key)
redis_client.delete(fence_key)
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session is processing a message.
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session is processing a message.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: The Redis client to use
Returns:
True if the chat session is processing a message, False otherwise
"""
return cache.exists(_get_fence_key(chat_session_id))
fence_key = _get_fence_key(chat_session_id)
return bool(redis_client.exists(fence_key))

View File

@@ -52,7 +52,6 @@ from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import MemoryToolResponseSnapshot
from onyx.tools.models import PythonToolRichResponse
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -967,13 +966,6 @@ def run_llm_loop(
):
generated_images = tool_response.rich_response.generated_images
# Extract generated_files if this is a code interpreter response
generated_files = None
if isinstance(tool_response.rich_response, PythonToolRichResponse):
generated_files = (
tool_response.rich_response.generated_files or None
)
# Persist memory if this is a memory tool response
memory_snapshot: MemoryToolResponseSnapshot | None = None
if isinstance(tool_response.rich_response, MemoryToolResponse):
@@ -1025,7 +1017,6 @@ def run_llm_loop(
tool_call_response=saved_response,
search_docs=displayed_docs or search_docs,
generated_images=generated_images,
generated_files=generated_files,
)
# Add to state container for partial save support
state_container.add_tool_call(tool_call_info)

View File

@@ -11,10 +11,9 @@ from contextvars import Token
from uuid import UUID
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.cache.interface import CacheBackend
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_loop_with_state_containers
@@ -80,6 +79,7 @@ from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
llm: LLM | None = None
chat_session: ChatSession | None = None
cache: CacheBackend | None = None
redis_client: Redis | None = None
user_id = user.id
if user.is_anonymous:
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
)
simple_chat_history.insert(0, summary_simple)
cache = get_cache_backend()
redis_client = get_redis_client()
reset_cancel_status(
chat_session.id,
cache,
redis_client,
)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, cache)
return check_stop_signal(chat_session.id, redis_client)
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
redis_client=redis_client,
value=True,
)
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
reset_llm_mock_response(mock_response_token)
try:
if cache is not None and chat_session is not None:
if redis_client is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
redis_client=redis_client,
value=False,
)
except Exception:

View File

@@ -1,5 +1,4 @@
import json
import mimetypes
from sqlalchemy.orm import Session
@@ -13,41 +12,14 @@ from onyx.db.chat import create_db_search_doc
from onyx.db.models import ChatMessage
from onyx.db.models import ToolCall
from onyx.db.tools import create_tool_call_no_commit
from onyx.file_store.models import FileDescriptor
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.tools.models import ToolCallInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _extract_referenced_file_descriptors(
    tool_calls: list[ToolCallInfo],
    message_text: str,
) -> list[FileDescriptor]:
    """Collect descriptors for generated files that the message text mentions.

    For every generated file on every tool call, the file id is taken to be
    the last ``/``-separated segment of its ``file_link``.  A descriptor is
    produced only when that id is non-empty and occurs as a substring of
    *message_text*.  The descriptor's type is derived from the MIME type
    guessed from the file name.
    """
    referenced: list[FileDescriptor] = []
    for call in tool_calls:
        # A tool call with no generated files contributes nothing.
        for produced in call.generated_files or []:
            link = produced.file_link
            if not link:
                continue
            file_id = link.rsplit("/", 1)[-1]
            if not file_id or file_id not in message_text:
                continue
            guessed_mime = mimetypes.guess_type(produced.filename)[0]
            referenced.append(
                FileDescriptor(
                    id=file_id,
                    type=mime_type_to_chat_file_type(guessed_mime),
                    name=produced.filename,
                )
            )
    return referenced
def _create_and_link_tool_calls(
tool_calls: list[ToolCallInfo],
assistant_message: ChatMessage,
@@ -325,14 +297,5 @@ def save_chat_turn(
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
)
# 8. Attach code interpreter generated files that the assistant actually
# referenced in its response, so they are available via load_all_chat_files
# on subsequent turns. Files not mentioned are intermediate artifacts.
if message_text:
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
if referenced:
existing_files = assistant_message.files or []
assistant_message.files = existing_files + referenced
# Finally save the messages, tool calls, and docs
db_session.commit()

View File

@@ -1,58 +1,65 @@
from uuid import UUID
from onyx.cache.interface import CacheBackend
from redis.client import Redis
# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 10 * 60 # 10 minutes
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
def _get_fence_key(chat_session_id: UUID) -> str:
"""Generate the cache key for a chat session stop signal fence.
"""
Generate the Redis key for a chat session stop signal fence.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
The fence key string (tenant_id is automatically added by the Redis client)
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
"""Set or clear the stop signal fence for a chat session.
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
"""
Set or clear the stop signal fence for a chat session.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
value: True to set the fence (stop signal), False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if not value:
cache.delete(fence_key)
redis_client.delete(fence_key)
return
cache.set(fence_key, 0, ex=FENCE_TTL)
redis_client.set(fence_key, 0, ex=FENCE_TTL)
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session should continue (not stopped).
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session should continue (not stopped).
Args:
chat_session_id: The UUID of the chat session to check
cache: Tenant-aware cache backend
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
Returns:
True if the session should continue, False if it should stop
"""
return not cache.exists(_get_fence_key(chat_session_id))
fence_key = _get_fence_key(chat_session_id)
return not bool(redis_client.exists(fence_key))
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
"""Clear the stop signal for a chat session.
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
"""
Clear the stop signal for a chat session.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
"""
cache.delete(_get_fence_key(chat_session_id))
fence_key = _get_fence_key(chat_session_id)
redis_client.delete(fence_key)

View File

@@ -495,7 +495,14 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
# Individual worker concurrency settings
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
)
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kubernetes deployments)
CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)
@@ -812,9 +819,7 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
# Tool Configs
#####
# Code Interpreter Service Configuration
CODE_INTERPRETER_BASE_URL = os.environ.get(
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
)
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
@@ -895,9 +900,6 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
)
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
)
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")

View File

@@ -84,6 +84,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (

View File

@@ -943,9 +943,6 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
page
),
)
)
@@ -995,7 +992,6 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=page_id,
)
)

View File

@@ -781,5 +781,4 @@ def build_slim_document(
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
)

View File

@@ -902,11 +902,6 @@ class JiraConnector(
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
current_offset += 1

View File

@@ -385,7 +385,6 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
id: str
external_access: ExternalAccess | None = None
parent_hierarchy_raw_node_id: str | None = None
class HierarchyNode(BaseModel):

View File

@@ -772,7 +772,6 @@ def _convert_driveitem_to_slim_document(
drive_name: str,
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -788,15 +787,11 @@ def _convert_driveitem_to_slim_document(
return SlimDocument(
id=driveitem.id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
def _convert_sitepage_to_slim_document(
site_page: dict[str, Any],
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -813,7 +808,6 @@ def _convert_sitepage_to_slim_document(
return SlimDocument(
id=id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
@@ -1600,22 +1594,12 @@ class SharepointConnector(
)
)
parent_hierarchy_url: str | None = None
if drive_web_url:
parent_hierarchy_url = self._get_parent_hierarchy_url(
site_url, drive_web_url, drive_name, driveitem
)
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem,
drive_name,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
driveitem, drive_name, ctx, self.graph_client
)
)
except Exception as e:
@@ -1635,10 +1619,7 @@ class SharepointConnector(
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
site_page, ctx, self.graph_client
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:

View File

@@ -565,7 +565,6 @@ def _get_all_doc_ids(
channel_id=channel_id, thread_ts=message["ts"]
),
external_access=external_access,
parent_hierarchy_raw_node_id=channel_id,
)
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
@@ -247,7 +246,6 @@ def insert_document_set(
description=document_set_creation_request.description,
user_id=user_id,
is_public=document_set_creation_request.is_public,
is_up_to_date=DISABLE_VECTOR_DB,
time_last_modified_by_user=func.now(),
)
db_session.add(new_document_set_row)
@@ -338,8 +336,7 @@ def update_document_set(
)
document_set_row.description = document_set_update_request.description
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False
document_set_row.is_up_to_date = False
document_set_row.is_public = document_set_update_request.is_public
document_set_row.time_last_modified_by_user = func.now()
versioned_private_doc_set_fn = fetch_versioned_implementation(

View File

@@ -1,7 +1,5 @@
"""CRUD operations for HierarchyNode."""
from collections import defaultdict
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -527,53 +525,6 @@ def get_document_parent_hierarchy_node_ids(
return {doc_id: parent_id for doc_id, parent_id in results}
def update_document_parent_hierarchy_nodes(
    db_session: Session,
    doc_parent_map: dict[str, int | None],
    commit: bool = True,
) -> int:
    """Bulk-set ``Document.parent_hierarchy_node_id`` for many documents.

    Documents whose stored parent already equals the requested one are left
    untouched, and so are document ids that the lookup of current parents
    does not return.  The remaining documents are grouped by target parent
    so each distinct parent needs a single UPDATE.

    Args:
        db_session: SQLAlchemy session
        doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
        commit: Commit when True; otherwise flush, and only if something changed

    Returns:
        Number of documents actually updated
    """
    if not doc_parent_map:
        return 0

    current_parents = get_document_parent_hierarchy_node_ids(
        db_session, list(doc_parent_map.keys())
    )

    # target parent id -> ids of the documents that must be re-pointed to it
    pending: dict[int | None, list[str]] = defaultdict(list)
    for doc_id, target_parent in doc_parent_map.items():
        if doc_id in current_parents and current_parents[doc_id] != target_parent:
            pending[target_parent].append(doc_id)

    num_updated = 0
    for target_parent, doc_ids in pending.items():
        db_session.query(Document).filter(Document.id.in_(doc_ids)).update(
            {Document.parent_hierarchy_node_id: target_parent},
            synchronize_session=False,
        )
        num_updated += len(doc_ids)

    if commit:
        db_session.commit()
    elif num_updated:
        db_session.flush()
    return num_updated
def update_hierarchy_node_permissions(
db_session: Session,
raw_node_id: str,

View File

@@ -532,7 +532,6 @@ def fetch_default_model(
) -> ModelConfiguration | None:
model_config = db_session.scalar(
select(ModelConfiguration)
.options(selectinload(ModelConfiguration.llm_provider))
.join(LLMModelFlow)
.where(
ModelConfiguration.is_visible == True, # noqa: E712

View File

@@ -2822,17 +2822,8 @@ class LLMProvider(Base):
postgresql.JSONB(), nullable=True
)
# Deprecated: use LLMModelFlow with CHAT flow type instead
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
# Deprecated: use LLMModelFlow with VISION flow type instead
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# Auto mode: models, visibility, and defaults are managed by GitHub config
@@ -4926,9 +4917,7 @@ class ScimUserMapping(Base):
__tablename__ = "scim_user_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str | None] = mapped_column(
String, unique=True, index=True, nullable=True
)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
@@ -4985,25 +4974,3 @@ class CodeInterpreterServer(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
class CacheStore(Base):
    """Key-value cache table used by ``PostgresCacheBackend``.

    Stands in for Redis (simple KV caching, locks, and list operations)
    when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).

    Kept deliberately apart from ``KVStore``:
    - Values are raw bytes (LargeBinary) rather than JSONB, matching Redis
      semantics.
    - ``expires_at`` carries the TTL; rows are periodically garbage-collected.
    - Contents are ephemeral (tokens, stop signals, lock state), not
      persistent application config, so cleanup can be aggressive.
    """

    __tablename__ = "cache_store"

    # Cache key; one row per key.
    key: Mapped[str] = mapped_column(String, primary_key=True)
    # Raw payload bytes; may be NULL.
    value: Mapped[bytes | None] = mapped_column(
        LargeBinary,
        nullable=True,
    )
    # Timezone-aware expiry timestamp; NULL means the row has no expiry.
    expires_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )

View File

@@ -52,7 +52,7 @@ def create_user_files(
) -> CategorizedFilesResult:
# Categorize the files
categorized_files = categorize_uploaded_files(files, db_session)
categorized_files = categorize_uploaded_files(files)
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)

View File

@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
latest_settings = result.scalars().first()
if not latest_settings:
raise RuntimeError("No search settings specified; DB is not in a valid state.")
raise RuntimeError("No search settings specified, DB is not in a valid state")
return latest_settings

View File

@@ -32,6 +32,9 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type

View File

@@ -26,10 +26,11 @@ def get_default_document_index(
To be used for retrieval only. Indexing should be done through both indices
until Vespa is deprecated.
Pre-existing docstring for this function, although secondary indices are not
currently supported:
Primary index is the index that is used for querying/updating etc. Secondary
index is for when both the currently used index and the upcoming index both
need to be updated. Updates are applied to both indices.
WARNING: In that case, get_all_document_indices should be used.
need to be updated, updates are applied to both indices.
"""
if DISABLE_VECTOR_DB:
return DisabledDocumentIndex(
@@ -50,26 +51,11 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
@@ -100,7 +86,8 @@ def get_all_document_indices(
Used for indexing only. Until Vespa is deprecated we will index into both
document indices. Retrieval is done through only one index however.
Large chunks are not currently supported so we hardcode appropriate values.
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
NOTE: Make sure the Vespa index object is returned first. In the rare event
that there is some conflict between indexing and the migration task, it is
@@ -136,36 +123,13 @@ def get_all_document_indices(
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)

View File

@@ -271,9 +271,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
secondary_embedding_dim: int | None,
secondary_embedding_precision: EmbeddingPrecision | None,
# NOTE: We do not support large chunks right now.
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
multitenant: bool = False,
@@ -289,25 +286,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
f"Expected {MULTI_TENANT}, got {multitenant}."
)
tenant_id = get_current_tenant_id()
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
self._real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
self._secondary_real_index: OpenSearchDocumentIndex | None = None
if self.secondary_index_name:
if secondary_embedding_dim is None or secondary_embedding_precision is None:
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
self._secondary_real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=self.secondary_index_name,
embedding_dim=secondary_embedding_dim,
embedding_precision=secondary_embedding_precision,
)
@staticmethod
def register_multitenant_indices(
@@ -323,38 +307,19 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
secondary_index_embedding_dim: int | None, # noqa: ARG002
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
) -> None:
self._real_index.verify_and_create_index_if_necessary(
# Only handle primary index for now, ignore secondary.
return self._real_index.verify_and_create_index_if_necessary(
primary_embedding_dim, primary_embedding_precision
)
if self.secondary_index_name:
if (
secondary_index_embedding_dim is None
or secondary_index_embedding_precision is None
):
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.verify_and_create_index_if_necessary(
secondary_index_embedding_dim, secondary_index_embedding_precision
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
# Convert IndexBatchParams to IndexingMetadata.
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -386,20 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
tenant_id: str, # noqa: ARG002
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
total_chunks_deleted += self._secondary_real_index.delete(
doc_id, chunk_count
)
return total_chunks_deleted
return self._real_index.delete(doc_id, chunk_count)
def update_single(
self,
@@ -410,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> None:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -445,11 +392,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
try:
self._real_index.update([update_request])
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.update([update_request])
except NotFoundError:
logger.exception(
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "

View File

@@ -1,6 +1,5 @@
import json
import string
import time
from collections.abc import Callable
from collections.abc import Mapping
from datetime import datetime
@@ -19,7 +18,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.document_index.interfaces import VespaChunkRequest
@@ -340,18 +338,12 @@ def get_all_chunks_paginated(
params["continuation"] = continuation_token
response: httpx.Response | None = None
start_time = time.monotonic()
try:
with get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as http_client:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = (
f"Failed to get chunks from Vespa slice {slice_id} with continuation token "
f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds."
)
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"

View File

@@ -465,12 +465,6 @@ class VespaIndex(DocumentIndex):
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
index_batch_params.doc_id_to_new_chunk_cnt
):
@@ -665,10 +659,6 @@ class VespaIndex(DocumentIndex):
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
@@ -689,6 +679,13 @@ class VespaIndex(DocumentIndex):
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
@@ -708,20 +705,7 @@ class VespaIndex(DocumentIndex):
persona_ids=persona_ids,
)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
vespa_document_index.update([update_request])
vespa_document_index.update([update_request])
def delete_single(
self,
@@ -730,11 +714,6 @@ class VespaIndex(DocumentIndex):
tenant_id: str,
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
tenant_state = TenantState(
tenant_id=get_current_tenant_id(),
multitenant=MULTI_TENANT,
@@ -747,25 +726,13 @@ class VespaIndex(DocumentIndex):
raise ValueError(
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
total_chunks_deleted = 0
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
total_chunks_deleted += vespa_document_index.delete(
document_id=doc_id, chunk_count=chunk_count
)
return total_chunks_deleted
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
def id_based_retrieval(
self,

View File

@@ -52,9 +52,7 @@ def replace_invalid_doc_id_characters(text: str) -> str:
return text.replace("'", "_")
def get_vespa_http_client(
no_timeout: bool = False, http2: bool = True, timeout: int | None = None
) -> httpx.Client:
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
"""
Configures and returns an HTTP client for communicating with Vespa,
including authentication if needed.
@@ -66,7 +64,7 @@ def get_vespa_http_client(
else None
),
verify=False if not MANAGED_VESPA else True,
timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=http2,
)

View File

@@ -1,101 +0,0 @@
"""
Standardized error codes for the Onyx backend.
Usage:
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired")
"""
from enum import Enum
class OnyxErrorCode(Enum):
"""
Each member is a tuple of (error_code_string, http_status_code).
The error_code_string is a stable, machine-readable identifier that
API consumers can match on. The http_status_code is the default HTTP
status to return.
"""
# ------------------------------------------------------------------
# Authentication (401)
# ------------------------------------------------------------------
UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
INVALID_TOKEN = ("INVALID_TOKEN", 401)
TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
CSRF_FAILURE = ("CSRF_FAILURE", 403)
# ------------------------------------------------------------------
# Authorization (403)
# ------------------------------------------------------------------
UNAUTHORIZED = ("UNAUTHORIZED", 403)
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
ADMIN_ONLY = ("ADMIN_ONLY", 403)
EE_REQUIRED = ("EE_REQUIRED", 403)
# ------------------------------------------------------------------
# Validation / Bad Request (400)
# ------------------------------------------------------------------
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
# ------------------------------------------------------------------
# Not Found (404)
# ------------------------------------------------------------------
NOT_FOUND = ("NOT_FOUND", 404)
CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
# ------------------------------------------------------------------
# Conflict (409)
# ------------------------------------------------------------------
CONFLICT = ("CONFLICT", 409)
DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
# ------------------------------------------------------------------
# Rate Limiting / Quotas (429 / 402)
# ------------------------------------------------------------------
RATE_LIMITED = ("RATE_LIMITED", 429)
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
# ------------------------------------------------------------------
# Connector / Credential Errors (400-range)
# ------------------------------------------------------------------
CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
# ------------------------------------------------------------------
# Server Errors (5xx)
# ------------------------------------------------------------------
INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:
self.code = code
self.status_code = status_code
def detail(self, message: str | None = None) -> dict[str, str]:
"""Build a structured error detail dict.
Returns a dict like:
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
If no message is supplied, the error code itself is used as the message.
"""
return {
"error_code": self.code,
"message": message or self.code,
}

View File

@@ -1,82 +0,0 @@
"""OnyxError — the single exception type for all Onyx business errors.
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
converts it into a JSON response with the standard
``{"error_code": "...", "message": "..."}`` shape.
Usage::
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
For upstream errors with a dynamic HTTP status (e.g. billing service),
use ``status_code_override``::
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail,
status_code_override=upstream_status,
)
"""
from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import JSONResponse
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxError(Exception):
    """Business error tied to one ``OnyxErrorCode`` member.

    Attributes:
        error_code: The ``OnyxErrorCode`` member describing the failure.
        message: Human-readable text; falls back to the error code string
            when no (or an empty) message is supplied.
        status_code: HTTP status to report, taken from the override when one
            is set and truthy, otherwise from the error code.
    """

    def __init__(
        self,
        error_code: OnyxErrorCode,
        message: str | None = None,
        *,
        status_code_override: int | None = None,
    ) -> None:
        resolved_message = message if message else error_code.code
        # The resolved text is also what ``str(exc)`` returns.
        super().__init__(resolved_message)
        self.error_code = error_code
        self.message = resolved_message
        self._status_code_override = status_code_override

    @property
    def status_code(self) -> int:
        # A missing or falsy override (None, 0) defers to the status that
        # belongs to the error code.
        if self._status_code_override:
            return self._status_code_override
        return self.error_code.status_code
def register_onyx_exception_handlers(app: FastAPI) -> None:
    """Install the application-wide handler that turns ``OnyxError`` into JSON.

    Call this once the app object exists and before it begins serving.
    The handler logs 5xx errors at ERROR and 4xx errors at WARNING; statuses
    below 400 are not logged.

    Args:
        app: The FastAPI application to register the handler on.
    """

    @app.exception_handler(OnyxError)
    async def _handle_onyx_error(
        request: Request,  # noqa: ARG001
        exc: OnyxError,
    ) -> JSONResponse:
        http_status = exc.status_code
        if http_status >= 400:
            # Server-side failures are logged louder than client errors.
            log = logger.error if http_status >= 500 else logger.warning
            log(f"OnyxError {exc.error_code.code}: {exc.message}")

        # Response body is the standard {"error_code", "message"} dict.
        payload = exc.error_code.detail(exc.message)
        return JSONResponse(status_code=http_status, content=payload)

View File

@@ -4,33 +4,39 @@ import base64
import json
import uuid
from typing import Any
from typing import cast
from typing import Dict
from typing import Optional
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Redis key prefix for OAuth state
OAUTH_STATE_PREFIX = "federated_oauth"
OAUTH_STATE_TTL = 300 # 5 minutes
# Default TTL for OAuth state (5 minutes)
OAUTH_STATE_TTL = 300
class OAuthSession:
"""Represents an OAuth session stored in the cache backend."""
"""Represents an OAuth session stored in Redis."""
def __init__(
self,
federated_connector_id: int,
user_id: str,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
):
self.federated_connector_id = federated_connector_id
self.user_id = user_id
self.redirect_uri = redirect_uri
self.additional_data = additional_data or {}
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for Redis storage."""
return {
"federated_connector_id": self.federated_connector_id,
"user_id": self.user_id,
@@ -39,7 +45,8 @@ class OAuthSession:
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
"""Create from dictionary retrieved from Redis."""
return cls(
federated_connector_id=data["federated_connector_id"],
user_id=data["user_id"],
@@ -51,27 +58,31 @@ class OAuthSession:
def generate_oauth_state(
federated_connector_id: int,
user_id: str,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
ttl: int = OAUTH_STATE_TTL,
) -> str:
"""
Generate a secure state parameter and store session data in the cache backend.
Generate a secure state parameter and store session data in Redis.
Args:
federated_connector_id: ID of the federated connector
user_id: ID of the user initiating OAuth
redirect_uri: Optional redirect URI after OAuth completion
additional_data: Any additional data to store with the session
ttl: Time-to-live in seconds for the cache key
ttl: Time-to-live in seconds for the Redis key
Returns:
Base64-encoded state parameter
"""
# Generate a random UUID for the state
state_uuid = uuid.uuid4()
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
# Convert UUID to base64 for URL-safe state parameter
state_bytes = state_uuid.bytes
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
# Create session object
session = OAuthSession(
federated_connector_id=federated_connector_id,
user_id=user_id,
@@ -79,9 +90,15 @@ def generate_oauth_state(
additional_data=additional_data,
)
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
# Store in Redis with TTL
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
redis_client.set(
redis_key,
json.dumps(session.to_dict()),
ex=ttl,
)
logger.info(
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
@@ -108,15 +125,18 @@ def verify_oauth_state(state: str) -> OAuthSession:
state_bytes = base64.urlsafe_b64decode(padded_state)
state_uuid = uuid.UUID(bytes=state_bytes)
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
# Look up in Redis
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
session_data = cache.get(cache_key)
session_data = cast(bytes, redis_client.get(redis_key))
if not session_data:
raise ValueError(f"OAuth state not found: {state}")
raise ValueError(f"OAuth state not found in Redis: {state}")
cache.delete(cache_key)
# Delete the key after retrieval (one-time use)
redis_client.delete(redis_key)
# Parse and return session
session_dict = json.loads(session_data)
return OAuthSession.from_dict(session_dict)

View File

@@ -1,11 +1,13 @@
import json
from typing import cast
from onyx.cache.interface import CacheBackend
from redis.client import Redis
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -18,27 +20,22 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self, cache: CacheBackend | None = None) -> None:
self._cache = cache
def _get_cache(self) -> CacheBackend:
if self._cache is None:
from onyx.cache.factory import get_cache_backend
self._cache = get_cache_backend()
return self._cache
def __init__(self, redis_client: Redis | None = None) -> None:
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
# Not encrypted in Redis, but encrypted in Postgres
try:
self._get_cache().set(
self.redis_client.set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Cache backend fails
logger.error(
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
)
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
@@ -56,12 +53,16 @@ class PgRedisKVStore(KeyValueStore):
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
if not refresh_cache:
try:
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
if cached is not None:
return json.loads(cached.decode("utf-8"))
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
if not isinstance(redis_value, bytes):
raise ValueError(
f"Redis value for key '{key}' is not a bytes object"
)
return json.loads(redis_value.decode("utf-8"))
except Exception as e:
logger.error(
f"Failed to get value from cache for key '{key}': {str(e)}"
f"Failed to get value from Redis for key '{key}': {str(e)}"
)
with get_session_with_current_tenant() as db_session:
@@ -78,21 +79,21 @@ class PgRedisKVStore(KeyValueStore):
value = None
try:
self._get_cache().set(
self.redis_client.set(
REDIS_KEY_PREFIX + key,
json.dumps(value),
ex=KV_REDIS_KEY_EXPIRATION,
)
except Exception as e:
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
return cast(JSON_ro, value)
def delete(self, key: str) -> None:
try:
self._get_cache().delete(REDIS_KEY_PREFIX + key)
self.redis_client.delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with get_session_with_current_tenant() as db_session:
result = db_session.query(KVStore).filter_by(key=key).delete()

View File

@@ -67,18 +67,6 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
usage as a dict with chat completion format instead of keeping it as
ResponseAPIUsage. Our patch creates a deep copy before modification.
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
to check for router calls, but when metadata is explicitly None (key exists with
value None), the default {} is not used
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
the real exception (e.g. AuthenticationError for wrong API key)
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
not iterable
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
against metadata being explicitly None. Triggered when Responses API bridge
passes **litellm_params containing metadata=None.
"""
import time
@@ -737,44 +725,6 @@ def _patch_logging_assembled_streaming_response() -> None:
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
def _patch_responses_metadata_none() -> None:
    """
    Wraps litellm.responses so that a missing or None ``metadata`` keyword
    argument is replaced with an empty dict before the real call.

    Why: LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
        _is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
    If metadata is present in kwargs with the value None, the ``{}`` default
    is not used, so the ``in`` check runs against None and raises:
        TypeError: argument of type 'NoneType' is not iterable

    That TypeError hides the real exception (e.g. AuthenticationError) and
    surfaces as:
        APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable

    It is triggered when the Responses API bridge calls litellm.responses()
    with **litellm_params that may contain metadata=None.

    STATUS: STILL NEEDED - litellm/utils.py wrapper function uses
    kwargs.get("metadata", {}), which does not guard against metadata being
    explicitly None. Same pattern exists on line 1407 for async path.
    """
    from functools import wraps

    import litellm as _litellm

    unpatched_responses = _litellm.responses

    # The marker attribute is set below, so a second call is a no-op and the
    # function is never wrapped twice.
    if getattr(unpatched_responses, "_metadata_patched", False):
        return

    @wraps(unpatched_responses)
    def _responses_with_metadata_default(*args: Any, **kwargs: Any) -> Any:
        # Covers both "key absent" and "key present with value None".
        if kwargs.get("metadata") is None:
            kwargs["metadata"] = {}
        return unpatched_responses(*args, **kwargs)

    _responses_with_metadata_default._metadata_patched = True  # type: ignore[attr-defined]
    _litellm.responses = _responses_with_metadata_default
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -786,7 +736,6 @@ def apply_monkey_patches() -> None:
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
"""
_patch_ollama_chunk_parser()
_patch_openai_responses_parallel_tool_calls()
@@ -794,4 +743,3 @@ def apply_monkey_patches() -> None:
_patch_azure_responses_should_fake_stream()
_patch_responses_api_usage_format()
_patch_logging_assembled_streaming_response()
_patch_responses_metadata_none()

View File

@@ -13,38 +13,44 @@ from datetime import datetime
import httpx
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import sync_auto_mode_models
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
# Redis key for caching the last updated timestamp (per-tenant)
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
def _get_cached_last_updated_at() -> datetime | None:
"""Get the cached last_updated_at timestamp from Redis."""
try:
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
if value is not None:
redis_client = get_redis_client()
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
if value and isinstance(value, bytes):
# Value is bytes, decode to string then parse as ISO format
return datetime.fromisoformat(value.decode("utf-8"))
except Exception as e:
logger.warning(f"Failed to get cached last_updated_at: {e}")
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
return None
def _set_cached_last_updated_at(updated_at: datetime) -> None:
"""Set the cached last_updated_at timestamp in Redis."""
try:
get_cache_backend().set(
_CACHE_KEY_LAST_UPDATED_AT,
redis_client = get_redis_client()
# Store as ISO format string, with 24 hour expiration
redis_client.set(
_REDIS_KEY_LAST_UPDATED_AT,
updated_at.isoformat(),
ex=_CACHE_TTL_SECONDS,
ex=60 * 60 * 24, # 24 hours
)
except Exception as e:
logger.warning(f"Failed to set cached last_updated_at: {e}")
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
def fetch_llm_recommendations_from_github(
@@ -142,8 +148,9 @@ def sync_llm_models_from_github(
def reset_cache() -> None:
"""Reset the cache timestamp. Useful for testing."""
"""Reset the cache timestamp in Redis. Useful for testing."""
try:
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
redis_client = get_redis_client()
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
except Exception as e:
logger.warning(f"Failed to reset cache: {e}")
logger.warning(f"Failed to reset cache in Redis: {e}")

View File

@@ -59,7 +59,6 @@ from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.connection_warmup import warm_up_connections
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.file_store.file_store import get_default_file_store
from onyx.server.api_key.api import router as api_key_router
from onyx.server.auth_check import check_router_auth
@@ -445,8 +444,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error
)
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
@@ -146,11 +146,6 @@ class SlackRenderer(HTMLRenderer):
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def __init__(self) -> None:
    """Initialise the base renderer and the per-table scratch state."""
    super().__init__()
    # Cells of the body row currently being rendered.
    self._current_row_cells: list[str] = []
    # Header cell texts of the table currently being rendered.
    self._table_headers: list[str] = []
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
@@ -223,48 +218,5 @@ class SlackRenderer(HTMLRenderer):
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
def table_cell(
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
) -> str:
if head:
self._table_headers.append(text.strip())
else:
self._current_row_cells.append(text.strip())
return ""
def table_head(self, text: str) -> str:  # noqa: ARG002
    """Finish the header row: emit nothing and start with an empty row buffer."""
    # Rebind (rather than clear in place) so the body rows begin from a
    # fresh list.
    self._current_row_cells = []
    return ""
def table_row(self, text: str) -> str:  # noqa: ARG002
    """Render the buffered body row as a vertical "card".

    The first cell becomes a bold title line (left untouched if it is already
    wrapped in ``*``); every further cell becomes a ``Header: value`` line, or
    just the value when no header exists for that column. The row buffer is
    emptied, and the card is terminated by a blank line.
    """
    # Take ownership of the buffered cells and reset the buffer in one step.
    row_cells, self._current_row_cells = self._current_row_cells, []

    card_lines: list[str] = []
    if row_cells:
        heading = row_cells[0]
        if heading:
            # Avoid double-wrapping a title that already carries bold markup.
            already_bold = (
                len(heading) > 1
                and heading.startswith("*")
                and heading.endswith("*")
            )
            card_lines.append(heading if already_bold else f"*{heading}*")

        header_count = len(self._table_headers)
        for column, value in enumerate(row_cells[1:], start=1):
            if column < header_count:
                card_lines.append(f"{self._table_headers[column]}: {value}")
            else:
                card_lines.append(f"{value}")

    return "\n".join(card_lines) + "\n\n"
def table_body(self, text: str) -> str:
    """Return the already-rendered body rows unchanged."""
    return text
def table(self, text: str) -> str:
    """Close the table: reset the table scratch state and end with a newline."""
    # Reset both buffers so the next table starts without stale headers/cells.
    self._current_row_cells = []
    self._table_headers = []
    return text + "\n"
def paragraph(self, text: str) -> str:
    """Render a paragraph as its text followed by one blank line."""
    return f"{text}\n\n"

View File

@@ -92,7 +92,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
@@ -573,43 +572,6 @@ def _normalize_file_names_for_backwards_compatibility(
return file_names + file_locations[len(file_names) :]
def _fetch_and_check_file_connector_cc_pair_permissions(
    connector_id: int,
    user: User,
    db_session: Session,
    require_editable: bool,
) -> ConnectorCredentialPair:
    """Load the connector's cc-pair and ensure ``user`` may work with its files.

    Args:
        connector_id: ID of the connector whose cc-pair is looked up.
        user: The user making the request.
        db_session: Database session used for the lookup and access check.
        require_editable: Passed to the access check as ``get_editable``; also
            gates the global-curator exception below.

    Returns:
        The connector-credential pair, once access has been established.

    Raises:
        HTTPException: 404 when the connector has no cc-pair; 403 when the
            user has no access to it.
    """
    pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
    if pair is None:
        raise HTTPException(
            status_code=404,
            detail="No Connector-Credential Pair found for this connector",
        )

    if verify_user_has_access_to_cc_pair(
        cc_pair_id=pair.id,
        db_session=db_session,
        user=user,
        get_editable=require_editable,
    ):
        return pair

    # Special case: global curators should be able to manage files
    # for public file connectors even when they are not the creator.
    global_curator_on_public_pair = (
        require_editable
        and user.role == UserRole.GLOBAL_CURATOR
        and pair.access_type == AccessType.PUBLIC
    )
    if not global_curator_on_public_pair:
        raise HTTPException(
            status_code=403,
            detail="Access denied. User cannot manage files for this connector.",
        )
    return pair
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
def upload_files_api(
files: list[UploadFile],
@@ -621,7 +583,7 @@ def upload_files_api(
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
def list_connector_files(
connector_id: int,
user: User = Depends(current_curator_or_admin_user),
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ConnectorFilesResponse:
"""List all files in a file connector."""
@@ -634,13 +596,6 @@ def list_connector_files(
status_code=400, detail="This endpoint only works with file connectors"
)
_ = _fetch_and_check_file_connector_cc_pair_permissions(
connector_id=connector_id,
user=user,
db_session=db_session,
require_editable=False,
)
file_locations = connector.connector_specific_config.get("file_locations", [])
file_names = connector.connector_specific_config.get("file_names", [])
@@ -690,7 +645,7 @@ def update_connector_files(
connector_id: int,
files: list[UploadFile] | None = File(None),
file_ids_to_remove: str = Form("[]"),
user: User = Depends(current_curator_or_admin_user),
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
"""
@@ -708,13 +663,12 @@ def update_connector_files(
)
# Get the connector-credential pair for indexing/pruning triggers
# and validate user permissions for file management.
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
connector_id=connector_id,
user=user,
db_session=db_session,
require_editable=True,
)
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if cc_pair is None:
raise HTTPException(
status_code=404,
detail="No Connector-Credential Pair found for this connector",
)
# Parse file IDs to remove
try:

View File

@@ -7424,9 +7424,9 @@
}
},
"node_modules/hono": {
"version": "4.12.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
"version": "4.11.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"

View File

@@ -1133,8 +1133,7 @@ done
# Already deleted
service_deleted = True
else:
logger.error(f"Error deleting Service {service_name}: {e}")
raise
logger.warning(f"Error deleting Service {service_name}: {e}")
pod_deleted = False
try:
@@ -1149,8 +1148,7 @@ done
# Already deleted
pod_deleted = True
else:
logger.error(f"Error deleting Pod {pod_name}: {e}")
raise
logger.warning(f"Error deleting Pod {pod_name}: {e}")
# Wait for resources to be fully deleted to prevent 409 conflicts
# on immediate re-provisioning

View File

@@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
# Prevent overlapping runs of this task
if not lock.acquire(blocking=False):
task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping")
task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping")
return
try:

View File

@@ -11,7 +11,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import delete_document_set as db_delete_document_set
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import insert_document_set
@@ -143,10 +142,7 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
if DISABLE_VECTOR_DB:
db_session.refresh(document_set)
db_delete_document_set(document_set, db_session)
else:
if not DISABLE_VECTOR_DB:
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},

View File

@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
pass
def categorize_uploaded_files(
files: list[UploadFile], db_session: Session
) -> CategorizedFiles:
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
"""
Categorize uploaded files based on text extractability and tokenized length.
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
"""
results = CategorizedFiles()
default_model = fetch_default_llm_model(db_session)
llm = get_default_llm()
model_name = default_model.name if default_model else None
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
tokenizer = get_tokenizer(
model_name=llm.config.model_name, provider_type=llm.config.model_provider
)
# Check if threshold checks should be skipped
skip_threshold = False

View File

@@ -8,10 +8,10 @@ import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.cache.factory import get_shared_cache_backend
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
@@ -113,46 +113,60 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
def get_cached_etag() -> str | None:
cache = get_shared_cache_backend()
"""Get the cached GitHub ETag from Redis."""
redis_client = get_shared_redis_client()
try:
etag = cache.get(REDIS_KEY_ETAG)
etag = redis_client.get(REDIS_KEY_ETAG)
if etag:
return etag.decode("utf-8")
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
return None
except Exception as e:
logger.error(f"Failed to get cached etag: {e}")
logger.error(f"Failed to get cached etag from Redis: {e}")
return None
def get_last_fetch_time() -> datetime | None:
cache = get_shared_cache_backend()
"""Get the last fetch timestamp from Redis."""
redis_client = get_shared_redis_client()
try:
raw = cache.get(REDIS_KEY_FETCHED_AT)
if not raw:
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
if not fetched_at_str:
return None
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
decoded = (
fetched_at_str.decode("utf-8")
if isinstance(fetched_at_str, bytes)
else str(fetched_at_str)
)
last_fetch = datetime.fromisoformat(decoded)
# Defensively ensure timezone awareness
# fromisoformat() returns naive datetime if input lacks timezone
if last_fetch.tzinfo is None:
# Assume UTC for naive datetimes
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
else:
# Convert to UTC if timezone-aware
last_fetch = last_fetch.astimezone(timezone.utc)
return last_fetch
except Exception as e:
logger.error(f"Failed to get last fetch time from cache: {e}")
logger.error(f"Failed to get last fetch time from Redis: {e}")
return None
def save_fetch_metadata(etag: str | None) -> None:
cache = get_shared_cache_backend()
"""Save ETag and fetch timestamp to Redis."""
redis_client = get_shared_redis_client()
now = datetime.now(timezone.utc)
try:
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
if etag:
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
except Exception as e:
logger.error(f"Failed to save fetch metadata to cache: {e}")
logger.error(f"Failed to save fetch metadata to Redis: {e}")
def is_cache_stale() -> bool:
@@ -182,10 +196,11 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
if not is_cache_stale():
return
cache = get_shared_cache_backend()
lock = cache.lock(
# Acquire lock to prevent concurrent fetches
redis_client = get_shared_redis_client()
lock = redis_client.lock(
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
timeout=90,
timeout=90, # 90 second timeout for the lock
)
# Non-blocking acquire - if we can't get the lock, another request is handling it

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -10,8 +11,6 @@ from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.models import User
from onyx.db.search_settings import get_all_search_settings
from onyx.db.search_settings import get_current_db_embedding_provider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.indexing.models import EmbeddingModelDetail
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
@@ -60,7 +59,7 @@ def test_embedding_configuration(
except Exception as e:
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.get("", response_model=list[EmbeddingModelDetail])
@@ -94,9 +93,8 @@ def delete_embedding_provider(
embedding_provider is not None
and provider_type == embedding_provider.provider_type
):
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"You can't delete a currently active model",
raise HTTPException(
status_code=400, detail="You can't delete a currently active model"
)
remove_embedding_provider(db_session, provider_type=provider_type)

View File

@@ -11,6 +11,7 @@ from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import ValidationError
from sqlalchemy.orm import Session
@@ -37,8 +38,6 @@ from onyx.db.llm import upsert_llm_provider
from onyx.db.llm import validate_persona_ids_exist
from onyx.db.models import User
from onyx.db.persona import user_can_access_persona
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
@@ -187,7 +186,7 @@ def _validate_llm_provider_change(
Only enforced in MULTI_TENANT mode.
Raises:
OnyxError: If api_base or custom_config changed without changing API key
HTTPException: If api_base or custom_config changed without changing API key
"""
if not MULTI_TENANT or api_key_changed:
return
@@ -201,9 +200,9 @@ def _validate_llm_provider_change(
)
if api_base_changed or custom_config_changed:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"API base and/or custom config cannot be changed without changing the API key",
raise HTTPException(
status_code=400,
detail="API base and/or custom config cannot be changed without changing the API key",
)
@@ -223,7 +222,7 @@ def fetch_llm_provider_options(
for well_known_llm in well_known_llms:
if well_known_llm.name == provider_name:
return well_known_llm
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found")
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
@admin_router.post("/test")
@@ -282,7 +281,7 @@ def test_llm_configuration(
error_msg = test_llm(llm)
if error_msg:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.post("/test/default")
@@ -293,11 +292,11 @@ def test_default_provider(
llm = get_default_llm()
except ValueError:
logger.exception("Failed to fetch default LLM Provider")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup")
raise HTTPException(status_code=400, detail="No LLM Provider setup")
error = test_llm(llm)
if error:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error))
raise HTTPException(status_code=400, detail=str(error))
@admin_router.get("/provider")
@@ -363,31 +362,35 @@ def put_llm_provider(
# Check name constraints
# TODO: Once port from name to id is complete, unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Renaming providers is not currently supported",
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"Provider with name={llm_provider_upsert_request.name} already exists",
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists",
raise HTTPException(
status_code=400,
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
)
elif not existing_provider and not is_creation:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist",
raise HTTPException(
status_code=400,
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -412,9 +415,9 @@ def put_llm_provider(
db_session, persona_ids
)
if missing_personas:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
raise HTTPException(
status_code=400,
detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
)
# Remove duplicates while preserving order
seen: set[int] = set()
@@ -470,29 +473,19 @@ def put_llm_provider(
return result
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
provider_id: int,
force: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if not force:
model = fetch_default_llm_model(db_session)
if model and model.llm_provider_id == provider_id:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Cannot delete the default LLM provider",
)
try:
remove_llm_provider(db_session, provider_id)
except ValueError as e:
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/default")
@@ -532,9 +525,9 @@ def get_auto_config(
"""
config = fetch_llm_recommendations_from_github()
if not config:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
"Failed to fetch configuration from GitHub",
raise HTTPException(
status_code=502,
detail="Failed to fetch configuration from GitHub",
)
return config.model_dump()
@@ -691,13 +684,13 @@ def list_llm_providers_for_persona(
persona = fetch_persona_with_groups(db_session, persona_id)
if not persona:
raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found")
raise HTTPException(status_code=404, detail="Persona not found")
# Verify user has access to this persona
if not user_can_access_persona(db_session, persona_id, user, get_editable=False):
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You don't have access to this assistant",
raise HTTPException(
status_code=403,
detail="You don't have access to this assistant",
)
is_admin = user.role == UserRole.ADMIN
@@ -851,9 +844,9 @@ def get_bedrock_available_models(
try:
bedrock = session.client("bedrock")
except Exception as e:
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
raise HTTPException(
status_code=400,
detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
)
# Build model info dict from foundation models (modelId -> metadata)
@@ -972,14 +965,14 @@ def get_bedrock_available_models(
return results
except (ClientError, NoCredentialsError, BotoCoreError) as e:
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
f"Failed to connect to AWS Bedrock: {e}",
raise HTTPException(
status_code=400,
detail=f"Failed to connect to AWS Bedrock: {e}",
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Unexpected error fetching Bedrock models: {e}",
raise HTTPException(
status_code=500,
detail=f"Unexpected error fetching Bedrock models: {e}",
)
@@ -991,9 +984,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]:
response.raise_for_status()
response_json = response.json()
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch Ollama models: {e}",
raise HTTPException(
status_code=400,
detail=f"Failed to fetch Ollama models: {e}",
)
models = response_json.get("models", [])
@@ -1010,9 +1003,9 @@ def get_ollama_available_models(
cleaned_api_base = request.api_base.strip().rstrip("/")
if not cleaned_api_base:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"API base URL is required to fetch Ollama models.",
raise HTTPException(
status_code=400,
detail="API base URL is required to fetch Ollama models.",
)
# NOTE: most people run Ollama locally, so we don't disallow internal URLs
@@ -1021,9 +1014,9 @@ def get_ollama_available_models(
# with the same response format
model_names = _get_ollama_available_model_names(cleaned_api_base)
if not model_names:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Ollama server",
raise HTTPException(
status_code=400,
detail="No models found from your Ollama server",
)
all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
@@ -1125,9 +1118,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
response.raise_for_status()
return response.json()
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch OpenRouter models: {e}",
raise HTTPException(
status_code=400,
detail=f"Failed to fetch OpenRouter models: {e}",
)
@@ -1148,9 +1141,9 @@ def get_openrouter_available_models(
data = response_json.get("data", [])
if not isinstance(data, list) or len(data) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your OpenRouter endpoint",
raise HTTPException(
status_code=400,
detail="No models found from your OpenRouter endpoint",
)
results: list[OpenRouterFinalModelResponse] = []
@@ -1185,9 +1178,8 @@ def get_openrouter_available_models(
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from OpenRouter",
raise HTTPException(
status_code=400, detail="No compatible models found from OpenRouter"
)
sorted_results = sorted(results, key=lambda m: m.name.lower())

View File

@@ -6,11 +6,8 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
@@ -18,25 +15,20 @@ from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import delete_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.factory import get_default_document_index
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import update_unstructured_api_key
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT
router = APIRouter(prefix="/search-settings")
@@ -49,99 +41,110 @@ def set_new_search_settings(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), # noqa: ARG001
) -> IdReturn:
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
Creates a new SearchSettings row and cancels the previous secondary indexing
if any exists.
"""
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Disallow contextual RAG for cloud deployments.
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
# Validate cloud provider exists or create new LiteLLM provider.
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=search_settings_new.provider_type
)
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
)
validate_contextual_rag_model(
provider_name=search_settings_new.contextual_rag_llm_provider,
model_name=search_settings_new.contextual_rag_llm_name,
db_session=db_session,
# TODO(andrei): Re-enable.
# NOTE Enable integration external dependency tests in test_search_settings.py
# when this is reenabled. They are currently skipped
logger.error("Setting new search settings is temporarily disabled.")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Setting new search settings is temporarily disabled.",
)
# if search_settings_new.index_name:
# logger.warning("Index name was specified by request, this is not suggested")
search_settings = get_current_search_settings(db_session)
# # Disallow contextual RAG for cloud deployments
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Contextual RAG disabled in Onyx Cloud",
# )
if search_settings_new.index_name is None:
# We define index name here.
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
search_settings_new.model_name == search_settings.model_name
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.model_dump()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(
**search_settings_new.model_dump()
)
# # Validate cloud provider exists or create new LiteLLM provider
# if search_settings_new.provider_type is not None:
# cloud_provider = get_embedding_provider_from_provider_type(
# db_session, provider_type=search_settings_new.provider_type
# )
secondary_search_settings = get_secondary_search_settings(db_session)
# if cloud_provider is None:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
# )
if secondary_search_settings:
# Cancel any background indexing jobs.
expire_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# validate_contextual_rag_model(
# provider_name=search_settings_new.contextual_rag_llm_provider,
# model_name=search_settings_new.contextual_rag_llm_name,
# db_session=db_session,
# )
# Mark previous model as a past model directly.
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
# search_settings = get_current_search_settings(db_session)
new_search_settings = create_search_settings(
search_settings=new_search_settings_request, db_session=db_session
)
# if search_settings_new.index_name is None:
# # We define index name here
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
# if (
# search_settings_new.model_name == search_settings.model_name
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
# ):
# index_name += ALT_INDEX_SUFFIX
# search_values = search_settings_new.model_dump()
# search_values["index_name"] = index_name
# new_search_settings_request = SavedSearchSettings(**search_values)
# else:
# new_search_settings_request = SavedSearchSettings(
# **search_settings_new.model_dump()
# )
# Ensure the document indices have the new index immediately.
document_indices = get_all_document_indices(search_settings, new_search_settings)
for document_index in document_indices:
document_index.ensure_indices_exist(
primary_embedding_dim=search_settings.final_embedding_dim,
primary_embedding_precision=search_settings.embedding_precision,
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
secondary_index_embedding_precision=new_search_settings.embedding_precision,
)
# secondary_search_settings = get_secondary_search_settings(db_session)
# Pause index attempts for the currently in-use index to preserve resources.
if DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
# if secondary_search_settings:
# # Cancel any background indexing jobs
# expire_index_attempts(
# search_settings_id=secondary_search_settings.id, db_session=db_session
# )
db_session.commit()
return IdReturn(id=new_search_settings.id)
# # Mark previous model as a past model directly
# update_search_settings_status(
# search_settings=secondary_search_settings,
# new_status=IndexModelStatus.PAST,
# db_session=db_session,
# )
# new_search_settings = create_search_settings(
# search_settings=new_search_settings_request, db_session=db_session
# )
# # Ensure Vespa has the new index immediately
# get_multipass_config(search_settings)
# get_multipass_config(new_search_settings)
# document_index = get_default_document_index(
# search_settings, new_search_settings, db_session
# )
# document_index.ensure_indices_exist(
# primary_embedding_dim=search_settings.final_embedding_dim,
# primary_embedding_precision=search_settings.embedding_precision,
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
# )
# # Pause index attempts for the currently in use index to preserve resources
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# expire_index_attempts(
# search_settings_id=search_settings.id, db_session=db_session
# )
# for cc_pair in get_connector_credential_pairs(db_session):
# resync_cc_pair(
# cc_pair=cc_pair,
# search_settings_id=new_search_settings.id,
# db_session=db_session,
# )
# db_session.commit()
# return IdReturn(id=new_search_settings.id)
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])

View File

@@ -1,5 +1,6 @@
import datetime
import json
import os
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
@@ -12,13 +13,13 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.cache.factory import get_cache_backend
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import convert_chat_history_basic
@@ -60,11 +61,13 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
from onyx.server.api_key_usage import check_api_key_usage
from onyx.server.query_and_chat.models import ChatFeedbackRequest
@@ -327,7 +330,7 @@ def get_chat_session(
]
try:
is_processing = is_chat_session_processing(session_id, get_cache_backend())
is_processing = is_chat_session_processing(session_id, get_redis_client())
# Edit the last message to indicate loading (Overriding default message value)
if is_processing and chat_message_details:
last_msg = chat_message_details[-1]
@@ -810,6 +813,18 @@ def fetch_chat_file(
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
original_file_name = file_record.display_name
if file_record.file_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
# Check if a converted text file exists for .docx files
txt_file_name = docx_to_txt_filename(original_file_name)
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
txt_file_record = file_store.read_file_record(txt_file_id)
if txt_file_record:
file_record = txt_file_record
file_id = txt_file_id
media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")
@@ -912,10 +927,11 @@ async def search_chats(
def stop_chat_session(
chat_session_id: UUID,
user: User = Depends(current_user), # noqa: ARG001
redis_client: Redis = Depends(get_redis_client),
) -> dict[str, str]:
"""
Stop a chat session by setting a stop signal.
Stop a chat session by setting a stop signal in Redis.
This endpoint is called by the frontend when the user clicks the stop button.
"""
set_fence(chat_session_id, get_cache_backend(), True)
set_fence(chat_session_id, redis_client, True)
return {"message": "Chat session stopped"}

View File

@@ -60,11 +60,9 @@ class Settings(BaseModel):
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False

View File

@@ -1,4 +1,3 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
@@ -7,8 +6,11 @@ from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -31,22 +33,30 @@ def load_settings() -> Settings:
logger.error(f"Error loading settings from KV store: {str(e)}")
settings = Settings()
cache = get_cache_backend()
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
try:
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is not None:
assert isinstance(value, bytes)
anonymous_user_enabled = int(value.decode("utf-8")) == 1
else:
# Default to False
anonymous_user_enabled = False
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
# Optionally store the default back to Redis
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
)
except Exception as e:
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
# Log the error and reset to default
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
anonymous_user_enabled = False
settings.anonymous_user_enabled = anonymous_user_enabled
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
# Override user knowledge setting if disabled via environment variable
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
@@ -56,10 +66,11 @@ def load_settings() -> Settings:
def store_settings(settings: Settings) -> None:
cache = get_cache_backend()
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
if settings.anonymous_user_enabled is not None:
cache.set(
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
"1" if settings.anonymous_user_enabled else "0",
ex=SETTINGS_TTL,

View File

@@ -93,8 +93,6 @@ class ToolResponse(BaseModel):
# | WebContentResponse
# This comes from custom tools, tool result needs to be saved
| CustomToolCallSummary
# This comes from code interpreter, carries generated files
| PythonToolRichResponse
# If the rich response is a string, this is what's saved to the tool call in the DB
| str
| None # If nothing needs to be persisted outside of the string value passed to the LLM
@@ -195,12 +193,6 @@ class ChatFile(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class PythonToolRichResponse(BaseModel):
"""Rich response from the Python tool carrying generated files."""
generated_files: list[PythonExecutionFile] = []
class PythonToolOverrideKwargs(BaseModel):
"""Override kwargs for the Python/Code Interpreter tool."""
@@ -253,7 +245,6 @@ class ToolCallInfo(BaseModel):
tool_call_response: str
search_docs: list[SearchDoc] | None = None
generated_images: list[GeneratedImage] | None = None
generated_files: list[PythonExecutionFile] | None = None
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"

View File

@@ -1,7 +1,4 @@
from __future__ import annotations
import json
import time
from collections.abc import Generator
from typing import Literal
from typing import TypedDict
@@ -15,9 +12,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_HEALTH_CACHE_TTL_SECONDS = 30
_health_cache: dict[str, tuple[float, bool]] = {}
class FileInput(TypedDict):
"""Input file to be staged in execution workspace"""
@@ -86,19 +80,6 @@ class CodeInterpreterClient:
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
self._closed = False
def __enter__(self) -> CodeInterpreterClient:
return self
def __exit__(self, *args: object) -> None:
self.close()
def close(self) -> None:
if self._closed:
return
self.session.close()
self._closed = True
def _build_payload(
self,
@@ -117,32 +98,16 @@ class CodeInterpreterClient:
payload["files"] = files
return payload
def health(self, use_cache: bool = False) -> bool:
"""Check if the Code Interpreter service is healthy
Args:
use_cache: When True, return a cached result if available and
within the TTL window. The cache is always populated
after a live request regardless of this flag.
"""
if use_cache:
cached = _health_cache.get(self.base_url)
if cached is not None:
cached_at, cached_result = cached
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
return cached_result
def health(self) -> bool:
"""Check if the Code Interpreter service is healthy"""
url = f"{self.base_url}/health"
try:
response = self.session.get(url, timeout=5)
response.raise_for_status()
result = response.json().get("status") == "ok"
return response.json().get("status") == "ok"
except Exception as e:
logger.warning(f"Exception caught when checking health, e={e}")
result = False
_health_cache[self.base_url] = (time.monotonic(), result)
return result
return False
def execute(
self,
@@ -192,11 +157,8 @@ class CodeInterpreterClient:
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
try:
response.raise_for_status()
yield from self._parse_sse(response)
finally:
response.close()
response.raise_for_status()
yield from self._parse_sse(response)
def _parse_sse(
self, response: requests.Response

View File

@@ -23,7 +23,6 @@ from onyx.tools.interface import Tool
from onyx.tools.models import LlmPythonExecutionResult
from onyx.tools.models import PythonExecutionFile
from onyx.tools.models import PythonToolOverrideKwargs
from onyx.tools.models import PythonToolRichResponse
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.python.code_interpreter_client import (
@@ -108,11 +107,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
if not CODE_INTERPRETER_BASE_URL:
return False
server = fetch_code_interpreter_server(db_session)
if not server.server_enabled:
return False
with CodeInterpreterClient() as client:
return client.health(use_cache=True)
return server.server_enabled
def tool_definition(self) -> dict:
return {
@@ -176,203 +171,194 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
)
)
# Create Code Interpreter client — context manager ensures
# session.close() is called on every exit path.
with CodeInterpreterClient() as client:
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
logger.info(f"Staged file for Python execution: {file_name}")
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
# Create Code Interpreter client
client = CodeInterpreterClient()
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
logger.debug(f"Executing code: {code}")
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=(
event.data if event.stream == "stdout" else ""
),
stderr=(
event.data if event.stream == "stderr" else ""
),
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
logger.info(f"Staged file for Python execution: {file_name}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
try:
logger.debug(f"Executing code: {code}")
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file "
f"{workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated "
f"file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged "
f"file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=(None if result_event.exit_code == 0 else truncated_stderr),
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
# Emit error delta
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file {workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
obj=PythonToolDelta(file_ids=generated_file_ids),
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
return ToolResponse(
rich_response=None, # No rich response needed for Python tool
llm_facing_response=llm_response,
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)

View File

@@ -1,5 +1,6 @@
import json
from typing import Any
from typing import cast
from sqlalchemy.orm import Session
from typing_extensions import override
@@ -56,30 +57,6 @@ def _sanitize_query(query: str) -> str:
return " ".join(sanitized.split())
def _normalize_queries_input(raw: Any) -> list[str]:
"""Coerce LLM output to a list of sanitized query strings.
Accepts a bare string or a list (possibly with non-string elements).
Sanitizes each query (strip control chars, normalize whitespace) and
drops empty or whitespace-only entries.
"""
if isinstance(raw, str):
raw = raw.strip()
if not raw:
return []
raw = [raw]
elif not isinstance(raw, list):
return []
result: list[str] = []
for q in raw:
if q is None:
continue
sanitized = _sanitize_query(str(q))
if sanitized:
result.append(sanitized)
return result
class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
NAME = "web_search"
DESCRIPTION = "Search the web for information."
@@ -212,7 +189,13 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
f'like: {{"queries": ["your search query here"]}}'
),
)
queries = _normalize_queries_input(llm_kwargs[QUERIES_FIELD])
raw_queries = cast(list[str], llm_kwargs[QUERIES_FIELD])
# Normalize queries:
# - remove control characters (null bytes, etc.) that LLMs sometimes produce
# - collapse whitespace and strip
# - drop empty/whitespace-only queries
queries = [sanitized for q in raw_queries if (sanitized := _sanitize_query(q))]
if not queries:
raise ToolCallException(
message=(

View File

@@ -596,7 +596,7 @@ mypy-extensions==1.0.0
# typing-inspect
nest-asyncio==1.6.0
# via onyx
nltk==3.9.3
nltk==3.9.1
# via unstructured
numpy==2.4.1
# via

View File

@@ -16,6 +16,10 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
def run_jobs() -> None:
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
use_lightweight = True
# command setup
cmd_worker_primary = [
"celery",
"-A",
@@ -70,48 +74,6 @@ def run_jobs() -> None:
"--queues=connector_doc_fetching",
]
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,user_file_delete",
]
cmd_beat = [
"celery",
"-A",
@@ -120,31 +82,144 @@ def run_jobs() -> None:
"--loglevel=INFO",
]
all_workers = [
("PRIMARY", cmd_worker_primary),
("LIGHT", cmd_worker_light),
("DOCPROCESSING", cmd_worker_docprocessing),
("DOCFETCHING", cmd_worker_docfetching),
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
("BEAT", cmd_beat),
]
# Prepare background worker commands based on mode
if use_lightweight:
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
cmd_worker_background = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration",
]
background_workers = [("BACKGROUND", cmd_worker_background)]
else:
print("Starting workers in STANDARD mode (separate background workers)")
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete",
]
background_workers = [
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
]
processes = []
for name, cmd in all_workers:
# spawn processes
worker_primary_process = subprocess.Popen(
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_light_process = subprocess.Popen(
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_docprocessing_process = subprocess.Popen(
cmd_worker_docprocessing,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
# Spawn background worker processes based on mode
background_processes = []
for name, cmd in background_workers:
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
processes.append((name, process))
background_processes.append((name, process))
threads = []
for name, process in processes:
# monitor threads
worker_primary_thread = threading.Thread(
target=monitor_process, args=("PRIMARY", worker_primary_process)
)
worker_light_thread = threading.Thread(
target=monitor_process, args=("LIGHT", worker_light_process)
)
worker_docprocessing_thread = threading.Thread(
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
# Create monitor threads for background workers
background_threads = []
for name, process in background_processes:
thread = threading.Thread(target=monitor_process, args=(name, process))
threads.append(thread)
background_threads.append(thread)
# Start all threads
worker_primary_thread.start()
worker_light_thread.start()
worker_docprocessing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()
for thread in background_threads:
thread.start()
for thread in threads:
# Wait for all threads
worker_primary_thread.join()
worker_light_thread.join()
worker_docprocessing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()
for thread in background_threads:
thread.join()

View File

@@ -1,20 +1,10 @@
#!/bin/bash
set -e
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"
# Stop and remove every dev container this script manages, including the
# compose-managed opensearch service. Each step silences errors and is
# followed by `|| true`, so a container that is not running does not abort
# the script under `set -e`.
stop_and_remove_containers() {
    local containers=(onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter)
    local compose=(docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled)

    docker stop "${containers[@]}" 2>/dev/null || true
    docker rm "${containers[@]}" 2>/dev/null || true
    "${compose[@]}" stop opensearch 2>/dev/null || true
    "${compose[@]}" rm -f opensearch 2>/dev/null || true
}
cleanup() {
echo "Error occurred. Cleaning up..."
stop_and_remove_containers
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -22,26 +12,16 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR
# Usage of the script with optional volume arguments
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
# [minio_volume] [--keep-opensearch-data]
KEEP_OPENSEARCH_DATA=false
POSITIONAL_ARGS=()
for arg in "$@"; do
if [[ "$arg" == "--keep-opensearch-data" ]]; then
KEEP_OPENSEARCH_DATA=true
else
POSITIONAL_ARGS+=("$arg")
fi
done
VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
VESPA_VOLUME=${1:-""} # Default is empty if not provided
POSTGRES_VOLUME=${2:-""} # Default is empty if not provided
REDIS_VOLUME=${3:-""} # Default is empty if not provided
MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
stop_and_remove_containers
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -59,29 +39,6 @@ else
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
fi
# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
# .vscode/.env so existing dev setups that stored it there aren't silently
# broken.
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
set -a
# shellcheck source=/dev/null
source "$VSCODE_ENV"
set +a
fi
# Start the OpenSearch container using the same service from docker-compose that
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
# restarts, else the volume is deleted so the container starts fresh.
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
echo "Deleting opensearch-data volume..."
docker volume rm onyx_opensearch-data 2>/dev/null || true
fi
echo "Starting OpenSearch container..."
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch
# Start the Redis container with optional volume
echo "Starting Redis container..."
if [[ -n "$REDIS_VOLUME" ]]; then
@@ -103,6 +60,7 @@ echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
cd "$PARENT_DIR"

View File

@@ -0,0 +1,10 @@
#!/bin/bash
# Recreate the OpenSearch dev container from docker-compose.opensearch.yml.
set -e

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
ENV_FILE="$SCRIPT_DIR/../../.vscode/.env"

# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
# `set -a` exports every variable the file assigns. With a plain `source`,
# assignments that are not prefixed with `export` stay local to this shell
# and are not visible to the `docker compose` child process.
if [[ -f "$ENV_FILE" ]]; then
    set -a
    # shellcheck source=/dev/null
    source "$ENV_FILE"
    set +a
else
    echo "Warning: $ENV_FILE not found; OPENSEARCH_ADMIN_PASSWORD may be unset." >&2
fi

# With `set -e`, a failed cd aborts here instead of running compose from
# whatever directory the caller happened to be in.
cd "$SCRIPT_DIR/../../deployment/docker_compose"

# Start OpenSearch.
echo "Forcefully starting fresh OpenSearch container..."
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch

View File

@@ -1,5 +1,23 @@
#!/bin/sh
# Entrypoint script for supervisord
# Entrypoint script for supervisord that sets environment variables
# for controlling which celery workers to start
# Default to lightweight mode if not set (unset and empty are treated alike)
: "${USE_LIGHTWEIGHT_BACKGROUND_WORKER:=true}"

# Normalize the flag to exactly "true" or "false" before deriving the
# complementary variable. supervisord parses autostart values
# case-insensitively and accepts several spellings (true/yes/on/1 and
# false/no/off/0), so without this step a value such as "True" or "1" would
# enable the lightweight worker AND fall through to the "separate workers"
# branch below, starting both sets of workers.
case "$(printf '%s' "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" | tr '[:upper:]' '[:lower:]')" in
    true|yes|on|1)
        USE_LIGHTWEIGHT_BACKGROUND_WORKER="true"
        ;;
    false|no|off|0)
        USE_LIGHTWEIGHT_BACKGROUND_WORKER="false"
        ;;
    *)
        # Any value other than "true" has always selected separate workers;
        # keep that choice, but hand supervisord a value it can parse.
        echo "Unrecognized USE_LIGHTWEIGHT_BACKGROUND_WORKER='$USE_LIGHTWEIGHT_BACKGROUND_WORKER', using 'false'" >&2
        USE_LIGHTWEIGHT_BACKGROUND_WORKER="false"
        ;;
esac
export USE_LIGHTWEIGHT_BACKGROUND_WORKER

# Set the complementary variable for supervisord
# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax
if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then
    export USE_SEPARATE_BACKGROUND_WORKERS="false"
else
    export USE_SEPARATE_BACKGROUND_WORKERS="true"
fi

echo "Worker mode configuration:"
echo " USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER"
echo " USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS"

# Launch supervisord with environment variables available
# (exec replaces this shell with the supervisord process)
exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf

View File

@@ -39,6 +39,7 @@ autorestart=true
startsecs=10
stopasgroup=true
# Standard mode: Light worker for fast operations
# NOTE: only allowing configuration here and not in the other celery workers,
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
# user group syncing, deletion, etc.)
@@ -53,7 +54,26 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Lightweight mode: single consolidated background worker
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default)
# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing
# NOTE(review): user_file_delete is included in -Q so that it is still consumed
# when the separate user_file_processing worker is not autostarted; this mirrors
# the lightweight-mode queue list used by run_jobs() - confirm against the
# background app's task routes.
[program:celery_worker_background]
command=celery -A onyx.background.celery.versioned_apps.background worker
    --loglevel=INFO
    --hostname=background@%%n
    -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration
stdout_logfile=/var/log/celery_worker_background.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s
# Standard mode: separate workers for different background tasks
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
[program:celery_worker_heavy]
command=celery -A onyx.background.celery.versioned_apps.heavy worker
--loglevel=INFO
@@ -65,7 +85,9 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Standard mode: Document processing worker
[program:celery_worker_docprocessing]
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
--loglevel=INFO
@@ -77,6 +99,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
[program:celery_worker_user_file_processing]
command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker
@@ -89,7 +112,9 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Standard mode: Document fetching worker
[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
--loglevel=INFO
@@ -101,6 +126,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
@@ -113,6 +139,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Job scheduler for periodic tasks
@@ -170,6 +197,7 @@ command=tail -qF
/var/log/celery_beat.log
/var/log/celery_worker_primary.log
/var/log/celery_worker_light.log
/var/log/celery_worker_background.log
/var/log/celery_worker_heavy.log
/var/log/celery_worker_docprocessing.log
/var/log/celery_worker_monitoring.log

Some files were not shown because too many files have changed in this diff Show More