mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-08 01:02:39 +00:00
Compare commits
1 Commits
nikg/std-e
...
hide_colum
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9155a8767 |
@@ -15,7 +15,6 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
51
.github/workflows/pr-integration-tests.yml
vendored
51
.github/workflows/pr-integration-tests.yml
vendored
@@ -471,13 +471,13 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
onyx-lite-tests:
|
||||
no-vectordb-tests:
|
||||
needs: [build-backend-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-onyx-lite-tests",
|
||||
"run-id=${{ github.run_id }}-no-vectordb-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
@@ -495,12 +495,13 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create .env file for Onyx Lite Docker Compose
|
||||
- name: Create .env file for no-vectordb Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
@@ -508,23 +509,28 @@ jobs:
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
DISABLE_VECTOR_DB=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
|
||||
EOF
|
||||
|
||||
# Start only the services needed for Onyx Lite (Postgres + API server)
|
||||
- name: Start Docker containers (onyx-lite)
|
||||
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
|
||||
- name: Start Docker containers (no-vectordb)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_onyx_lite
|
||||
id: start_docker_no_vectordb
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script (onyx-lite)..."
|
||||
echo "Starting wait-for-service script (no-vectordb)..."
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
@@ -546,14 +552,14 @@ jobs:
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Run Onyx Lite Integration Tests
|
||||
- name: Run No-VectorDB Integration Tests
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running onyx-lite integration tests..."
|
||||
echo "Running no-vectordb integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -564,38 +570,39 @@ jobs:
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/tests/no_vectordb
|
||||
|
||||
- name: Dump API server logs (onyx-lite)
|
||||
- name: Dump API server logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
|
||||
|
||||
- name: Dump all-container logs (onyx-lite)
|
||||
- name: Dump all-container logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
|
||||
|
||||
- name: Upload logs (onyx-lite)
|
||||
- name: Upload logs (no-vectordb)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-onyx-lite
|
||||
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
|
||||
name: docker-all-logs-no-vectordb
|
||||
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
|
||||
|
||||
- name: Stop Docker containers (onyx-lite)
|
||||
- name: Stop Docker containers (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
@@ -737,7 +744,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests, onyx-lite-tests, multitenant-tests]
|
||||
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
39
AGENTS.md
39
AGENTS.md
@@ -617,45 +617,6 @@ Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Error Handling
|
||||
|
||||
**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# ✅ Good
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
# ✅ Good — no extra message needed
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
# ✅ Good — upstream service with dynamic status code
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
|
||||
|
||||
# ❌ Bad — using HTTPException directly
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# ❌ Bad — starlette constant
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
```
|
||||
|
||||
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
|
||||
category is needed, add it there first — do not invent ad-hoc codes.
|
||||
|
||||
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
|
||||
status code is dynamic (comes from the upstream response), use `status_code_override`:
|
||||
|
||||
```python
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""add cache_store table
|
||||
|
||||
Revision ID: 2664261bfaab
|
||||
Revises: 4a1e4b1c89d2
|
||||
Create Date: 2026-02-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2664261bfaab"
|
||||
down_revision = "4a1e4b1c89d2"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"cache_store",
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("key"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_cache_store_expires",
|
||||
"cache_store",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("expires_at IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cache_store_expires", table_name="cache_store")
|
||||
op.drop_table("cache_store")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""make scim_user_mapping.external_id nullable
|
||||
|
||||
Revision ID: a3b8d9e2f1c4
|
||||
Revises: 2664261bfaab
|
||||
Create Date: 2026-03-02
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3b8d9e2f1c4"
|
||||
down_revision = "2664261bfaab"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete any rows where external_id is NULL before re-applying NOT NULL
|
||||
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -11,10 +11,11 @@ from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -141,7 +142,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
|
||||
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata from cache.
|
||||
Get license metadata from Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
@@ -149,34 +150,38 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
|
||||
Returns:
|
||||
LicenseMetadata if cached, None otherwise
|
||||
"""
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
cached = cache.get(LICENSE_METADATA_KEY)
|
||||
if not cached:
|
||||
return None
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_replica_client(tenant_id=tenant)
|
||||
|
||||
try:
|
||||
cached_str = (
|
||||
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
|
||||
)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
cached = redis_client.get(LICENSE_METADATA_KEY)
|
||||
if cached:
|
||||
try:
|
||||
cached_str: str
|
||||
if isinstance(cached, bytes):
|
||||
cached_str = cached.decode("utf-8")
|
||||
else:
|
||||
cached_str = str(cached)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def invalidate_license_cache(tenant_id: str | None = None) -> None:
|
||||
"""
|
||||
Invalidate the license metadata cache (not the license itself).
|
||||
|
||||
Deletes the cached LicenseMetadata. The actual license in the database
|
||||
is not affected. Delete is idempotent — if the key doesn't exist, this
|
||||
is a no-op.
|
||||
This deletes the cached LicenseMetadata from Redis. The actual license
|
||||
in the database is not affected. Redis delete is idempotent - if the
|
||||
key doesn't exist, this is a no-op.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
"""
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
cache.delete(LICENSE_METADATA_KEY)
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
redis_client.delete(LICENSE_METADATA_KEY)
|
||||
logger.info("License cache invalidated")
|
||||
|
||||
|
||||
@@ -187,7 +192,7 @@ def update_license_cache(
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata:
|
||||
"""
|
||||
Update the cache with license metadata.
|
||||
Update the Redis cache with license metadata.
|
||||
|
||||
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
|
||||
1. Frontend needs status to show appropriate UI/banners
|
||||
@@ -206,7 +211,7 @@ def update_license_cache(
|
||||
from ee.onyx.utils.license import get_license_status
|
||||
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
|
||||
used_seats = get_used_seats(tenant)
|
||||
status = get_license_status(payload, grace_period_end)
|
||||
@@ -225,7 +230,7 @@ def update_license_cache(
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
cache.set(
|
||||
redis_client.set(
|
||||
LICENSE_METADATA_KEY,
|
||||
metadata.model_dump_json(),
|
||||
ex=LICENSE_CACHE_TTL_SECONDS,
|
||||
|
||||
@@ -126,16 +126,12 @@ class ScimDAL(DAL):
|
||||
|
||||
def create_user_mapping(
|
||||
self,
|
||||
external_id: str | None,
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a SCIM mapping for a user.
|
||||
|
||||
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
|
||||
allows this). The mapping still marks the user as SCIM-managed.
|
||||
"""
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
@@ -274,13 +270,8 @@ class ScimDAL(DAL):
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
|
||||
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
|
||||
# unless they were explicitly linked via SCIM provisioning.
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
query = select(User).where(
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -330,37 +321,34 @@ class ScimDAL(DAL):
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Sync the SCIM mapping for a user.
|
||||
|
||||
If a mapping already exists, its fields are updated (including
|
||||
setting ``external_id`` to ``None`` when the IdP omits it).
|
||||
If no mapping exists and ``new_external_id`` is provided, a new
|
||||
mapping is created. A mapping is never deleted here — SCIM-managed
|
||||
users must retain their mapping to remain visible in ``GET /Users``.
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
elif new_external_id:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
|
||||
@@ -26,6 +26,7 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -41,6 +42,7 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_checkout_session as create_checkout_service,
|
||||
)
|
||||
@@ -56,8 +58,6 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -169,23 +169,26 @@ async def create_checkout_session(
|
||||
if seats is not None:
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if seats < used_seats:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot subscribe with fewer seats than current usage. "
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot subscribe with fewer seats than current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {seats} seats.",
|
||||
)
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
|
||||
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
try:
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
@@ -203,15 +206,18 @@ async def create_customer_portal_session(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
try:
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
@@ -234,9 +240,9 @@ async def get_billing_information(
|
||||
|
||||
# Check circuit breaker (self-hosted only)
|
||||
if _is_billing_circuit_open():
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.SERVICE_UNAVAILABLE,
|
||||
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -244,11 +250,11 @@ async def get_billing_information(
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except OnyxError as e:
|
||||
except BillingServiceError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
_open_billing_circuit()
|
||||
raise
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
@@ -268,25 +274,31 @@ async def update_seats(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
# Validate that new seat count is not less than current used seats
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if request.new_seat_count < used_seats:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot reduce seats below current usage. "
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot reduce seats below current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
return await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
try:
|
||||
result = await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
|
||||
return result
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -317,18 +329,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -339,17 +351,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,6 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -33,6 +31,15 @@ logger = setup_logger()
|
||||
_REQUEST_TIMEOUT = 30.0
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Exception raised for billing service errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Build headers for proxy requests (self-hosted).
|
||||
|
||||
@@ -94,7 +101,7 @@ async def _make_billing_request(
|
||||
Response JSON as dict
|
||||
|
||||
Raises:
|
||||
OnyxError: If request fails
|
||||
BillingServiceError: If request fails
|
||||
"""
|
||||
|
||||
base_url = _get_base_url()
|
||||
@@ -121,17 +128,11 @@ async def _make_billing_request(
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=e.response.status_code,
|
||||
)
|
||||
raise BillingServiceError(detail, e.response.status_code)
|
||||
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to billing service")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
|
||||
)
|
||||
raise BillingServiceError("Failed to connect to billing service", 502)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
|
||||
@@ -14,6 +14,7 @@ import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -34,8 +35,6 @@ from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -128,9 +127,9 @@ async def claim_license(
|
||||
2. Without session_id: Re-claim using existing license for auth
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License claiming is only available for self-hosted deployments",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -147,16 +146,15 @@ async def claim_license(
|
||||
# Re-claim using existing license for auth
|
||||
metadata = get_license_metadata(db_session)
|
||||
if not metadata or not metadata.tenant_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No license found. Provide session_id after checkout.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No license found. Provide session_id after checkout.",
|
||||
)
|
||||
|
||||
license_row = get_license(db_session)
|
||||
if not license_row or not license_row.license_data:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No license found in database",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No license found in database"
|
||||
)
|
||||
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
|
||||
@@ -175,7 +173,7 @@ async def claim_license(
|
||||
license_data = data.get("license")
|
||||
|
||||
if not license_data:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
@@ -201,14 +199,12 @@ async def claim_license(
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
)
|
||||
|
||||
|
||||
@@ -225,9 +221,9 @@ async def upload_license(
|
||||
The license file must be cryptographically signed by Onyx.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License upload is only available for self-hosted deployments",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -238,14 +234,14 @@ async def upload_license(
|
||||
# Remove any stray whitespace/newlines from user input
|
||||
license_data = license_data.strip()
|
||||
except UnicodeDecodeError:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
@@ -301,9 +297,9 @@ async def delete_license(
|
||||
Admin only - removes license from database and invalidates cache.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"License deletion is only available for self-hosted deployments",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -46,6 +46,7 @@ from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
@@ -55,7 +56,6 @@ from ee.onyx.configs.license_enforcement_config import (
|
||||
)
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
|
||||
"[license_enforcement] No license, allowing community features"
|
||||
)
|
||||
is_gated = False
|
||||
except CACHE_TRANSIENT_ERRORS as e:
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to cache connectivity issues
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
|
||||
@@ -423,63 +423,15 @@ def create_user(
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# Check for existing user — if they exist but aren't SCIM-managed yet,
|
||||
# link them to the IdP rather than rejecting with 409.
|
||||
external_id: str | None = user_resource.externalId
|
||||
scim_username: str = user_resource.userName.strip()
|
||||
fields: ScimMappingFields = _fields_from_resource(user_resource)
|
||||
|
||||
existing_user = dal.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
|
||||
if existing_mapping:
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Adopt pre-existing user into SCIM management.
|
||||
# Reactivating a deactivated user consumes a seat, so enforce the
|
||||
# seat limit the same way replace_user does.
|
||||
if user_resource.active and not existing_user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
dal.update_user(
|
||||
existing_user,
|
||||
is_active=user_resource.active,
|
||||
**({"personal_name": personal_name} if personal_name else {}),
|
||||
)
|
||||
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=existing_user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
existing_user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
# Only enforce seat limit for net-new users — adopting a pre-existing
|
||||
# user doesn't consume a new seat.
|
||||
# Enforce seat limit
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
@@ -497,21 +449,21 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Always create a SCIM mapping so that the user is marked as
|
||||
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
|
||||
try:
|
||||
# Create SCIM mapping when externalId is provided — this is how the IdP
|
||||
# correlates this user on subsequent requests. Per RFC 7643, externalId
|
||||
# is optional and assigned by the provisioning client.
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
if external_id:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
|
||||
@@ -170,10 +170,7 @@ class ScimProvider(ABC):
|
||||
formatted=user.personal_name or "",
|
||||
)
|
||||
if not user.personal_name:
|
||||
# Derive a reasonable name from the email so that SCIM spec tests
|
||||
# see non-empty givenName / familyName for every user resource.
|
||||
local = user.email.split("@")[0] if user.email else ""
|
||||
return ScimName(givenName=local, familyName="", formatted=local)
|
||||
return ScimName(givenName="", familyName="", formatted="")
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
|
||||
@@ -6,7 +6,6 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
@@ -126,7 +125,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# syncing) means indexed data may need protection.
|
||||
settings.application_status = _BLOCKING_STATUS
|
||||
settings.ee_features_enabled = False
|
||||
except CACHE_TRANSIENT_ERRORS as e:
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
# Fail closed - disable EE features if we can't verify license
|
||||
settings.ee_features_enabled = False
|
||||
|
||||
@@ -21,6 +21,7 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
@@ -42,8 +43,6 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -117,14 +116,9 @@ async def create_customer_portal_session(
|
||||
try:
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"stripe_customer_portal_url": portal_url}
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create customer portal session",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
@@ -140,14 +134,9 @@ async def create_checkout_session(
|
||||
try:
|
||||
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
|
||||
return {"stripe_checkout_url": checkout_url}
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create checkout session")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create checkout session",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
@@ -158,20 +147,15 @@ async def create_subscription_session(
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create subscription session",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -202,18 +186,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -224,15 +208,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -120,6 +120,7 @@ from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -200,14 +201,13 @@ def user_needs_to_be_verified() -> bool:
|
||||
|
||||
|
||||
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
cache = get_cache_backend(tenant_id=tenant_id)
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
transform_vespa_chunks_to_opensearch_chunks,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -48,7 +47,6 @@ from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -148,12 +146,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
with (
|
||||
get_session_with_current_tenant() as db_session,
|
||||
get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client,
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
@@ -168,7 +161,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=vespa_client,
|
||||
)
|
||||
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
|
||||
@@ -520,7 +520,6 @@ def process_user_file_impl(
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
@@ -676,7 +675,6 @@ def delete_user_file_impl(
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
@@ -851,7 +849,6 @@ def project_sync_user_file_impl(
|
||||
task_logger.exception(
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
@@ -59,12 +59,6 @@ def _run_auto_llm_update() -> None:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_cache_cleanup() -> None:
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
@@ -106,26 +100,12 @@ def _run_scheduled_eval() -> None:
|
||||
)
|
||||
|
||||
|
||||
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="cache-cleanup",
|
||||
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
|
||||
run_fn=_run_cache_cleanup,
|
||||
)
|
||||
)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
|
||||
@@ -75,41 +75,31 @@ def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
).scalar_one_or_none()
|
||||
# Commit to release the row lock promptly.
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
stmt = (
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
@@ -123,10 +113,7 @@ def _claim_next_sync_file(
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
@@ -148,14 +135,11 @@ def drain_processing_loop(tenant_id: str) -> None:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process user file {file_id}")
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
@@ -165,21 +149,16 @@ def drain_delete_loop(tenant_id: str) -> None:
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
|
||||
file_id = _claim_next_deleting_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to delete user file {file_id}")
|
||||
failed.add(file_id)
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
@@ -189,18 +168,13 @@ def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session, exclude_ids=failed)
|
||||
file_id = _claim_next_sync_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
try:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to sync user file {file_id}")
|
||||
failed.add(file_id)
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
8
backend/onyx/cache/factory.py
vendored
8
backend/onyx/cache/factory.py
vendored
@@ -12,15 +12,9 @@ def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
|
||||
return PostgresCacheBackend(tenant_id)
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
CacheBackendType.POSTGRES: _build_postgres_backend,
|
||||
# CacheBackendType.POSTGRES will be added in a follow-up PR.
|
||||
}
|
||||
|
||||
|
||||
|
||||
28
backend/onyx/cache/interface.py
vendored
28
backend/onyx/cache/interface.py
vendored
@@ -1,20 +1,6 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)
|
||||
"""Exception types that represent transient cache connectivity / operational
|
||||
failures. Callers that want to fail-open (or fail-closed) on cache errors
|
||||
should catch this tuple instead of bare ``Exception``.
|
||||
|
||||
When adding a new ``CacheBackend`` implementation, add its transient error
|
||||
base class(es) here so all call-sites pick it up automatically."""
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
@@ -40,14 +26,6 @@ class CacheLock(abc.ABC):
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CacheLock":
|
||||
if not self.acquire():
|
||||
raise RuntimeError("Failed to acquire lock")
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
@@ -87,11 +65,7 @@ class CacheBackend(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds.
|
||||
|
||||
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
|
||||
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
|
||||
"""
|
||||
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
323
backend/onyx/cache/postgres_backend.py
vendored
323
backend/onyx/cache/postgres_backend.py
vendored
@@ -1,323 +0,0 @@
|
||||
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
|
||||
|
||||
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
|
||||
for distributed locking, and a polling loop for the BLPOP pattern.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
_LIST_KEY_PREFIX = "_q:"
|
||||
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
|
||||
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
|
||||
# lists whose names share a prefix (e.g. _q:mylist2:...).
|
||||
_LIST_KEY_RANGE_TERMINATOR = ";"
|
||||
_LIST_ITEM_TTL_SECONDS = 3600
|
||||
_LOCK_POLL_INTERVAL = 0.1
|
||||
_BLPOP_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
def _list_item_key(key: str) -> str:
|
||||
"""Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
|
||||
collision when concurrent rpush calls occur within the same nanosecond.
|
||||
"""
|
||||
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _to_bytes(value: str | bytes | int | float) -> bytes:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheLock(CacheLock):
|
||||
"""Advisory-lock-based distributed lock.
|
||||
|
||||
Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
|
||||
to the session's connection; releasing or closing the session frees it.
|
||||
|
||||
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
|
||||
``timeout`` seconds. They are released when ``release()`` is
|
||||
called or when the session is closed.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
|
||||
self._lock_id = lock_id
|
||||
self._timeout = timeout
|
||||
self._tenant_id = tenant_id
|
||||
self._session_cm: AbstractContextManager[Session] | None = None
|
||||
self._session: Session | None = None
|
||||
self._acquired = False
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
|
||||
self._session = self._session_cm.__enter__()
|
||||
try:
|
||||
if not blocking:
|
||||
return self._try_lock()
|
||||
|
||||
effective_timeout = blocking_timeout or self._timeout
|
||||
deadline = (
|
||||
(time.monotonic() + effective_timeout) if effective_timeout else None
|
||||
)
|
||||
while True:
|
||||
if self._try_lock():
|
||||
return True
|
||||
if deadline is not None and time.monotonic() >= deadline:
|
||||
return False
|
||||
time.sleep(_LOCK_POLL_INTERVAL)
|
||||
finally:
|
||||
if not self._acquired:
|
||||
self._close_session()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._acquired or self._session is None:
|
||||
return
|
||||
try:
|
||||
self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
|
||||
finally:
|
||||
self._acquired = False
|
||||
self._close_session()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
def _close_session(self) -> None:
|
||||
if self._session_cm is not None:
|
||||
try:
|
||||
self._session_cm.__exit__(None, None, None)
|
||||
finally:
|
||||
self._session_cm = None
|
||||
self._session = None
|
||||
|
||||
def _try_lock(self) -> bool:
|
||||
assert self._session is not None
|
||||
result = self._session.execute(
|
||||
select(func.pg_try_advisory_lock(self._lock_id))
|
||||
).scalar()
|
||||
if result:
|
||||
self._acquired = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backend
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
|
||||
|
||||
Each operation opens and closes its own database session so the backend
|
||||
is safe to share across threads. Tenant isolation is handled by
|
||||
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.value).where(
|
||||
CacheStore.key == key,
|
||||
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
value = session.execute(stmt).scalar_one_or_none()
|
||||
if value is None:
|
||||
return None
|
||||
return bytes(value)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
value_bytes = _to_bytes(value)
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=ex)
|
||||
if ex is not None
|
||||
else None
|
||||
)
|
||||
stmt = (
|
||||
pg_insert(CacheStore)
|
||||
.values(key=key, value=value_bytes, expires_at=expires_at)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[CacheStore.key],
|
||||
set_={"value": value_bytes, "expires_at": expires_at},
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(delete(CacheStore).where(CacheStore.key == key))
|
||||
session.commit()
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = (
|
||||
select(CacheStore.key)
|
||||
.where(
|
||||
CacheStore.key == key,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
return session.execute(stmt).first() is not None
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
stmt = (
|
||||
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
result = session.execute(stmt).first()
|
||||
if result is None:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
expires_at: datetime | None = result[0]
|
||||
if expires_at is None:
|
||||
return TTL_NO_EXPIRY
|
||||
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
|
||||
if remaining <= 0:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
return int(remaining)
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return PostgresCacheLock(
|
||||
self._lock_id_for(name), timeout, tenant_id=self._tenant_id
|
||||
)
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
if timeout <= 0:
|
||||
raise ValueError(
|
||||
"PostgresCacheBackend.blpop requires timeout > 0. "
|
||||
"timeout=0 would block the calling thread indefinitely "
|
||||
"with no way to interrupt short of process termination."
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while True:
|
||||
for key in keys:
|
||||
lower = f"{_LIST_KEY_PREFIX}{key}:"
|
||||
upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
|
||||
stmt = (
|
||||
select(CacheStore)
|
||||
.where(
|
||||
CacheStore.key >= lower,
|
||||
CacheStore.key < upper,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.order_by(CacheStore.key)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
row = session.execute(stmt).scalars().first()
|
||||
if row is not None:
|
||||
value = bytes(row.value) if row.value else b""
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return (key.encode(), value)
|
||||
if time.monotonic() >= deadline:
|
||||
return None
|
||||
time.sleep(_BLPOP_POLL_INTERVAL)
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_expired_cache_entries() -> None:
|
||||
"""Delete rows whose ``expires_at`` is in the past.
|
||||
|
||||
Called by the periodic poller every 5 minutes.
|
||||
"""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
delete(CacheStore).where(
|
||||
CacheStore.expires_at.is_not(None),
|
||||
CacheStore.expires_at < func.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
@@ -1,52 +1,57 @@
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""Generate the cache key for a chat session processing fence.
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, cache: CacheBackend, value: bool
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
) -> None:
|
||||
"""Set or clear the fence for a chat session processing a message.
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, a message is being processed.
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: The Redis client to use
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
cache.delete(fence_key)
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session is processing a message.
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: The Redis client to use
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
return cache.exists(_get_fence_key(chat_session_id))
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
|
||||
@@ -52,7 +52,6 @@ from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -967,13 +966,6 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
# Extract generated_files if this is a code interpreter response
|
||||
generated_files = None
|
||||
if isinstance(tool_response.rich_response, PythonToolRichResponse):
|
||||
generated_files = (
|
||||
tool_response.rich_response.generated_files or None
|
||||
)
|
||||
|
||||
# Persist memory if this is a memory tool response
|
||||
memory_snapshot: MemoryToolResponseSnapshot | None = None
|
||||
if isinstance(tool_response.rich_response, MemoryToolResponse):
|
||||
@@ -1025,7 +1017,6 @@ def run_llm_loop(
|
||||
tool_call_response=saved_response,
|
||||
search_docs=displayed_docs or search_docs,
|
||||
generated_images=generated_images,
|
||||
generated_files=generated_files,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
@@ -11,10 +11,9 @@ from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
@@ -80,6 +79,7 @@ from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
cache: CacheBackend | None = None
|
||||
redis_client: Redis | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
cache = get_cache_backend()
|
||||
redis_client = get_redis_client()
|
||||
|
||||
reset_cancel_status(
|
||||
chat_session.id,
|
||||
cache,
|
||||
redis_client,
|
||||
)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
redis_client=redis_client,
|
||||
value=True,
|
||||
)
|
||||
|
||||
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
|
||||
reset_llm_mock_response(mock_response_token)
|
||||
|
||||
try:
|
||||
if cache is not None and chat_session is not None:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
redis_client=redis_client,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import mimetypes
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -13,41 +12,14 @@ from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.tools import create_tool_call_no_commit
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_referenced_file_descriptors(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
message_text: str,
|
||||
) -> list[FileDescriptor]:
|
||||
"""Extract FileDescriptors for code interpreter files referenced in the message text."""
|
||||
descriptors: list[FileDescriptor] = []
|
||||
for tool_call_info in tool_calls:
|
||||
if not tool_call_info.generated_files:
|
||||
continue
|
||||
for gen_file in tool_call_info.generated_files:
|
||||
file_id = (
|
||||
gen_file.file_link.rsplit("/", 1)[-1] if gen_file.file_link else ""
|
||||
)
|
||||
if file_id and file_id in message_text:
|
||||
mime_type, _ = mimetypes.guess_type(gen_file.filename)
|
||||
descriptors.append(
|
||||
FileDescriptor(
|
||||
id=file_id,
|
||||
type=mime_type_to_chat_file_type(mime_type),
|
||||
name=gen_file.filename,
|
||||
)
|
||||
)
|
||||
return descriptors
|
||||
|
||||
|
||||
def _create_and_link_tool_calls(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
assistant_message: ChatMessage,
|
||||
@@ -325,14 +297,5 @@ def save_chat_turn(
|
||||
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
|
||||
)
|
||||
|
||||
# 8. Attach code interpreter generated files that the assistant actually
|
||||
# referenced in its response, so they are available via load_all_chat_files
|
||||
# on subsequent turns. Files not mentioned are intermediate artifacts.
|
||||
if message_text:
|
||||
referenced = _extract_referenced_file_descriptors(tool_calls, message_text)
|
||||
if referenced:
|
||||
existing_files = assistant_message.files or []
|
||||
assistant_message.files = existing_files + referenced
|
||||
|
||||
# Finally save the messages, tool calls, and docs
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,58 +1,65 @@
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 10 * 60 # 10 minutes
|
||||
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""Generate the cache key for a chat session stop signal fence.
|
||||
"""
|
||||
Generate the Redis key for a chat session stop signal fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
|
||||
"""Set or clear the stop signal fence for a chat session.
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
"""
|
||||
Set or clear the stop signal fence for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
value: True to set the fence (stop signal), False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
if not value:
|
||||
cache.delete(fence_key)
|
||||
redis_client.delete(fence_key)
|
||||
return
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session should continue (not stopped).
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session should continue (not stopped).
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session to check
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
|
||||
|
||||
Returns:
|
||||
True if the session should continue, False if it should stop
|
||||
"""
|
||||
return not cache.exists(_get_fence_key(chat_session_id))
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return not bool(redis_client.exists(fence_key))
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
|
||||
"""Clear the stop signal for a chat session.
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
|
||||
"""
|
||||
Clear the stop signal for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
cache: Tenant-aware cache backend
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
"""
|
||||
cache.delete(_get_fence_key(chat_session_id))
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
@@ -819,9 +819,7 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
@@ -902,9 +900,6 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
|
||||
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
|
||||
)
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
|
||||
@@ -532,7 +532,6 @@ def fetch_default_model(
|
||||
) -> ModelConfiguration | None:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
|
||||
@@ -2822,17 +2822,8 @@ class LLMProvider(Base):
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Deprecated: use LLMModelFlow with CHAT flow type instead
|
||||
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
# Auto mode: models, visibility, and defaults are managed by GitHub config
|
||||
@@ -4926,9 +4917,7 @@ class ScimUserMapping(Base):
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str | None] = mapped_column(
|
||||
String, unique=True, index=True, nullable=True
|
||||
)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
@@ -4985,25 +4974,3 @@ class CodeInterpreterServer(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class CacheStore(Base):
|
||||
"""Key-value cache table used by ``PostgresCacheBackend``.
|
||||
|
||||
Replaces Redis for simple KV caching, locks, and list operations
|
||||
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
|
||||
|
||||
Intentionally separate from ``KVStore``:
|
||||
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
|
||||
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
|
||||
- Holds ephemeral data (tokens, stop signals, lock state) not
|
||||
persistent application config, so cleanup can be aggressive.
|
||||
"""
|
||||
|
||||
__tablename__ = "cache_store"
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
@@ -19,7 +18,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
)
|
||||
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
@@ -340,18 +338,12 @@ def get_all_chunks_paginated(
|
||||
params["continuation"] = continuation_token
|
||||
|
||||
response: httpx.Response | None = None
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
with get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as http_client:
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = (
|
||||
f"Failed to get chunks from Vespa slice {slice_id} with continuation token "
|
||||
f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds."
|
||||
)
|
||||
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
|
||||
logger.exception(
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
|
||||
@@ -52,9 +52,7 @@ def replace_invalid_doc_id_characters(text: str) -> str:
|
||||
return text.replace("'", "_")
|
||||
|
||||
|
||||
def get_vespa_http_client(
|
||||
no_timeout: bool = False, http2: bool = True, timeout: int | None = None
|
||||
) -> httpx.Client:
|
||||
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
|
||||
"""
|
||||
Configures and returns an HTTP client for communicating with Vespa,
|
||||
including authentication if needed.
|
||||
@@ -66,7 +64,7 @@ def get_vespa_http_client(
|
||||
else None
|
||||
),
|
||||
verify=False if not MANAGED_VESPA else True,
|
||||
timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
|
||||
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
|
||||
http2=http2,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
"""
|
||||
Standardized error codes for the Onyx backend.
|
||||
|
||||
Usage:
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired")
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OnyxErrorCode(Enum):
|
||||
"""
|
||||
Each member is a tuple of (error_code_string, http_status_code).
|
||||
|
||||
The error_code_string is a stable, machine-readable identifier that
|
||||
API consumers can match on. The http_status_code is the default HTTP
|
||||
status to return.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication (401)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
|
||||
INVALID_TOKEN = ("INVALID_TOKEN", 401)
|
||||
TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
|
||||
CSRF_FAILURE = ("CSRF_FAILURE", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authorization (403)
|
||||
# ------------------------------------------------------------------
|
||||
UNAUTHORIZED = ("UNAUTHORIZED", 403)
|
||||
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
|
||||
ADMIN_ONLY = ("ADMIN_ONLY", 403)
|
||||
EE_REQUIRED = ("EE_REQUIRED", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation / Bad Request (400)
|
||||
# ------------------------------------------------------------------
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
# ------------------------------------------------------------------
|
||||
NOT_FOUND = ("NOT_FOUND", 404)
|
||||
CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
|
||||
CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
|
||||
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
|
||||
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
|
||||
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
|
||||
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conflict (409)
|
||||
# ------------------------------------------------------------------
|
||||
CONFLICT = ("CONFLICT", 409)
|
||||
DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rate Limiting / Quotas (429 / 402)
|
||||
# ------------------------------------------------------------------
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
|
||||
CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
|
||||
CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server Errors (5xx)
|
||||
# ------------------------------------------------------------------
|
||||
INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
|
||||
NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
|
||||
def detail(self, message: str | None = None) -> dict[str, str]:
|
||||
"""Build a structured error detail dict.
|
||||
|
||||
Returns a dict like:
|
||||
{"error_code": "UNAUTHENTICATED", "message": "Token expired"}
|
||||
|
||||
If no message is supplied, the error code itself is used as the message.
|
||||
"""
|
||||
return {
|
||||
"error_code": self.code,
|
||||
"message": message or self.code,
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
"""OnyxError — the single exception type for all Onyx business errors.
|
||||
|
||||
Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global
|
||||
FastAPI exception handler (registered via ``register_onyx_exception_handlers``)
|
||||
converts it into a JSON response with the standard
|
||||
``{"error_code": "...", "message": "..."}`` shape.
|
||||
|
||||
Usage::
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
For upstream errors with a dynamic HTTP status (e.g. billing service),
|
||||
use ``status_code_override``::
|
||||
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=upstream_status,
|
||||
)
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxError(Exception):
|
||||
"""Structured error that maps to a specific ``OnyxErrorCode``.
|
||||
|
||||
Attributes:
|
||||
error_code: The ``OnyxErrorCode`` enum member.
|
||||
message: Human-readable message (defaults to the error code string).
|
||||
status_code: HTTP status — either overridden or from the error code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: OnyxErrorCode,
|
||||
message: str | None = None,
|
||||
*,
|
||||
status_code_override: int | None = None,
|
||||
) -> None:
|
||||
self.error_code = error_code
|
||||
self.message = message or error_code.code
|
||||
self._status_code_override = status_code_override
|
||||
super().__init__(self.message)
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self._status_code_override or self.error_code.status_code
|
||||
|
||||
|
||||
def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register a global handler that converts ``OnyxError`` to JSON responses.
|
||||
|
||||
Must be called *after* the app is created but *before* it starts serving.
|
||||
The handler logs at WARNING for 4xx and ERROR for 5xx.
|
||||
"""
|
||||
|
||||
@app.exception_handler(OnyxError)
|
||||
async def _handle_onyx_error(
|
||||
request: Request, # noqa: ARG001
|
||||
exc: OnyxError,
|
||||
) -> JSONResponse:
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.message}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.message),
|
||||
)
|
||||
@@ -4,33 +4,39 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key prefix for OAuth state
|
||||
OAUTH_STATE_PREFIX = "federated_oauth"
|
||||
OAUTH_STATE_TTL = 300 # 5 minutes
|
||||
# Default TTL for OAuth state (5 minutes)
|
||||
OAUTH_STATE_TTL = 300
|
||||
|
||||
|
||||
class OAuthSession:
|
||||
"""Represents an OAuth session stored in the cache backend."""
|
||||
"""Represents an OAuth session stored in Redis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.federated_connector_id = federated_connector_id
|
||||
self.user_id = user_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.additional_data = additional_data or {}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for Redis storage."""
|
||||
return {
|
||||
"federated_connector_id": self.federated_connector_id,
|
||||
"user_id": self.user_id,
|
||||
@@ -39,7 +45,8 @@ class OAuthSession:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
|
||||
"""Create from dictionary retrieved from Redis."""
|
||||
return cls(
|
||||
federated_connector_id=data["federated_connector_id"],
|
||||
user_id=data["user_id"],
|
||||
@@ -51,27 +58,31 @@ class OAuthSession:
|
||||
def generate_oauth_state(
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
ttl: int = OAUTH_STATE_TTL,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a secure state parameter and store session data in the cache backend.
|
||||
Generate a secure state parameter and store session data in Redis.
|
||||
|
||||
Args:
|
||||
federated_connector_id: ID of the federated connector
|
||||
user_id: ID of the user initiating OAuth
|
||||
redirect_uri: Optional redirect URI after OAuth completion
|
||||
additional_data: Any additional data to store with the session
|
||||
ttl: Time-to-live in seconds for the cache key
|
||||
ttl: Time-to-live in seconds for the Redis key
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter
|
||||
"""
|
||||
# Generate a random UUID for the state
|
||||
state_uuid = uuid.uuid4()
|
||||
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Convert UUID to base64 for URL-safe state parameter
|
||||
state_bytes = state_uuid.bytes
|
||||
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Create session object
|
||||
session = OAuthSession(
|
||||
federated_connector_id=federated_connector_id,
|
||||
user_id=user_id,
|
||||
@@ -79,9 +90,15 @@ def generate_oauth_state(
|
||||
additional_data=additional_data,
|
||||
)
|
||||
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
|
||||
# Store in Redis with TTL
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
redis_client.set(
|
||||
redis_key,
|
||||
json.dumps(session.to_dict()),
|
||||
ex=ttl,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
|
||||
@@ -108,15 +125,18 @@ def verify_oauth_state(state: str) -> OAuthSession:
|
||||
state_bytes = base64.urlsafe_b64decode(padded_state)
|
||||
state_uuid = uuid.UUID(bytes=state_bytes)
|
||||
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
# Look up in Redis
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
session_data = cache.get(cache_key)
|
||||
session_data = cast(bytes, redis_client.get(redis_key))
|
||||
if not session_data:
|
||||
raise ValueError(f"OAuth state not found: {state}")
|
||||
raise ValueError(f"OAuth state not found in Redis: {state}")
|
||||
|
||||
cache.delete(cache_key)
|
||||
# Delete the key after retrieval (one-time use)
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
# Parse and return session
|
||||
session_dict = json.loads(session_data)
|
||||
return OAuthSession.from_dict(session_dict)
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -18,27 +20,22 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self, cache: CacheBackend | None = None) -> None:
|
||||
self._cache = cache
|
||||
|
||||
def _get_cache(self) -> CacheBackend:
|
||||
if self._cache is None:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
self._cache = get_cache_backend()
|
||||
return self._cache
|
||||
def __init__(self, redis_client: Redis | None = None) -> None:
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
self.redis_client = get_redis_client()
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
try:
|
||||
self._get_cache().set(
|
||||
self.redis_client.set(
|
||||
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback gracefully to Postgres if Cache backend fails
|
||||
logger.error(
|
||||
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
|
||||
)
|
||||
# Fallback gracefully to Postgres if Redis fails
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
@@ -56,12 +53,16 @@ class PgRedisKVStore(KeyValueStore):
|
||||
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
|
||||
if not refresh_cache:
|
||||
try:
|
||||
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
|
||||
if cached is not None:
|
||||
return json.loads(cached.decode("utf-8"))
|
||||
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
|
||||
if redis_value:
|
||||
if not isinstance(redis_value, bytes):
|
||||
raise ValueError(
|
||||
f"Redis value for key '{key}' is not a bytes object"
|
||||
)
|
||||
return json.loads(redis_value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get value from cache for key '{key}': {str(e)}"
|
||||
f"Failed to get value from Redis for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -78,21 +79,21 @@ class PgRedisKVStore(KeyValueStore):
|
||||
value = None
|
||||
|
||||
try:
|
||||
self._get_cache().set(
|
||||
self.redis_client.set(
|
||||
REDIS_KEY_PREFIX + key,
|
||||
json.dumps(value),
|
||||
ex=KV_REDIS_KEY_EXPIRATION,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
|
||||
return cast(JSON_ro, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
try:
|
||||
self._get_cache().delete(REDIS_KEY_PREFIX + key)
|
||||
self.redis_client.delete(REDIS_KEY_PREFIX + key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete()
|
||||
|
||||
@@ -67,18 +67,6 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -737,44 +725,6 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -786,7 +736,6 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -794,4 +743,3 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -13,38 +13,44 @@ from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
|
||||
# Redis key for caching the last updated timestamp (per-tenant)
|
||||
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
|
||||
|
||||
def _get_cached_last_updated_at() -> datetime | None:
|
||||
"""Get the cached last_updated_at timestamp from Redis."""
|
||||
try:
|
||||
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
if value is not None:
|
||||
redis_client = get_redis_client()
|
||||
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
if value and isinstance(value, bytes):
|
||||
# Value is bytes, decode to string then parse as ISO format
|
||||
return datetime.fromisoformat(value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cached last_updated_at: {e}")
|
||||
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_last_updated_at(updated_at: datetime) -> None:
|
||||
"""Set the cached last_updated_at timestamp in Redis."""
|
||||
try:
|
||||
get_cache_backend().set(
|
||||
_CACHE_KEY_LAST_UPDATED_AT,
|
||||
redis_client = get_redis_client()
|
||||
# Store as ISO format string, with 24 hour expiration
|
||||
redis_client.set(
|
||||
_REDIS_KEY_LAST_UPDATED_AT,
|
||||
updated_at.isoformat(),
|
||||
ex=_CACHE_TTL_SECONDS,
|
||||
ex=60 * 60 * 24, # 24 hours
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set cached last_updated_at: {e}")
|
||||
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
|
||||
|
||||
|
||||
def fetch_llm_recommendations_from_github(
|
||||
@@ -142,8 +148,9 @@ def sync_llm_models_from_github(
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the cache timestamp. Useful for testing."""
|
||||
"""Reset the cache timestamp in Redis. Useful for testing."""
|
||||
try:
|
||||
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
redis_client = get_redis_client()
|
||||
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset cache: {e}")
|
||||
logger.warning(f"Failed to reset cache in Redis: {e}")
|
||||
|
||||
@@ -59,7 +59,6 @@ from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
|
||||
from onyx.db.engine.connection_warmup import warm_up_connections
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
@@ -445,8 +444,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error
|
||||
)
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -92,7 +92,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
@@ -573,43 +572,6 @@ def _normalize_file_names_for_backwards_compatibility(
|
||||
return file_names + file_locations[len(file_names) :]
|
||||
|
||||
|
||||
def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
require_editable: bool,
|
||||
) -> ConnectorCredentialPair:
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
has_requested_access = verify_user_has_access_to_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=require_editable,
|
||||
)
|
||||
if has_requested_access:
|
||||
return cc_pair
|
||||
|
||||
# Special case: global curators should be able to manage files
|
||||
# for public file connectors even when they are not the creator.
|
||||
if (
|
||||
require_editable
|
||||
and user.role == UserRole.GLOBAL_CURATOR
|
||||
and cc_pair.access_type == AccessType.PUBLIC
|
||||
):
|
||||
return cc_pair
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied. User cannot manage files for this connector.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
@@ -621,7 +583,7 @@ def upload_files_api(
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
def list_connector_files(
|
||||
connector_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorFilesResponse:
|
||||
"""List all files in a file connector."""
|
||||
@@ -634,13 +596,6 @@ def list_connector_files(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
_ = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=False,
|
||||
)
|
||||
|
||||
file_locations = connector.connector_specific_config.get("file_locations", [])
|
||||
file_names = connector.connector_specific_config.get("file_names", [])
|
||||
|
||||
@@ -690,7 +645,7 @@ def update_connector_files(
|
||||
connector_id: int,
|
||||
files: list[UploadFile] | None = File(None),
|
||||
file_ids_to_remove: str = Form("[]"),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
@@ -708,13 +663,12 @@ def update_connector_files(
|
||||
)
|
||||
|
||||
# Get the connector-credential pair for indexing/pruning triggers
|
||||
# and validate user permissions for file management.
|
||||
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=True,
|
||||
)
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
# Parse file IDs to remove
|
||||
try:
|
||||
|
||||
@@ -1133,8 +1133,7 @@ done
|
||||
# Already deleted
|
||||
service_deleted = True
|
||||
else:
|
||||
logger.error(f"Error deleting Service {service_name}: {e}")
|
||||
raise
|
||||
logger.warning(f"Error deleting Service {service_name}: {e}")
|
||||
|
||||
pod_deleted = False
|
||||
try:
|
||||
@@ -1149,8 +1148,7 @@ done
|
||||
# Already deleted
|
||||
pod_deleted = True
|
||||
else:
|
||||
logger.error(f"Error deleting Pod {pod_name}: {e}")
|
||||
raise
|
||||
logger.warning(f"Error deleting Pod {pod_name}: {e}")
|
||||
|
||||
# Wait for resources to be fully deleted to prevent 409 conflicts
|
||||
# on immediate re-provisioning
|
||||
|
||||
@@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
|
||||
# Prevent overlapping runs of this task
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
llm = get_default_llm()
|
||||
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -8,10 +8,10 @@ import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.cache.factory import get_shared_cache_backend
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
|
||||
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
|
||||
@@ -113,46 +113,60 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
|
||||
|
||||
def get_cached_etag() -> str | None:
|
||||
cache = get_shared_cache_backend()
|
||||
"""Get the cached GitHub ETag from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
try:
|
||||
etag = cache.get(REDIS_KEY_ETAG)
|
||||
etag = redis_client.get(REDIS_KEY_ETAG)
|
||||
if etag:
|
||||
return etag.decode("utf-8")
|
||||
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached etag: {e}")
|
||||
logger.error(f"Failed to get cached etag from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_last_fetch_time() -> datetime | None:
|
||||
cache = get_shared_cache_backend()
|
||||
"""Get the last fetch timestamp from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
try:
|
||||
raw = cache.get(REDIS_KEY_FETCHED_AT)
|
||||
if not raw:
|
||||
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
|
||||
if not fetched_at_str:
|
||||
return None
|
||||
|
||||
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
|
||||
decoded = (
|
||||
fetched_at_str.decode("utf-8")
|
||||
if isinstance(fetched_at_str, bytes)
|
||||
else str(fetched_at_str)
|
||||
)
|
||||
|
||||
last_fetch = datetime.fromisoformat(decoded)
|
||||
|
||||
# Defensively ensure timezone awareness
|
||||
# fromisoformat() returns naive datetime if input lacks timezone
|
||||
if last_fetch.tzinfo is None:
|
||||
# Assume UTC for naive datetimes
|
||||
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Convert to UTC if timezone-aware
|
||||
last_fetch = last_fetch.astimezone(timezone.utc)
|
||||
|
||||
return last_fetch
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get last fetch time from cache: {e}")
|
||||
logger.error(f"Failed to get last fetch time from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_fetch_metadata(etag: str | None) -> None:
|
||||
cache = get_shared_cache_backend()
|
||||
"""Save ETag and fetch timestamp to Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
if etag:
|
||||
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save fetch metadata to cache: {e}")
|
||||
logger.error(f"Failed to save fetch metadata to Redis: {e}")
|
||||
|
||||
|
||||
def is_cache_stale() -> bool:
|
||||
@@ -182,10 +196,11 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
|
||||
if not is_cache_stale():
|
||||
return
|
||||
|
||||
cache = get_shared_cache_backend()
|
||||
lock = cache.lock(
|
||||
# Acquire lock to prevent concurrent fetches
|
||||
redis_client = get_shared_redis_client()
|
||||
lock = redis_client.lock(
|
||||
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
|
||||
timeout=90,
|
||||
timeout=90, # 90 second timeout for the lock
|
||||
)
|
||||
|
||||
# Non-blocking acquire - if we can't get the lock, another request is handling it
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -10,8 +11,6 @@ from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.search_settings import get_current_db_embedding_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.indexing.models import EmbeddingModelDetail
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@@ -60,7 +59,7 @@ def test_embedding_configuration(
|
||||
except Exception as e:
|
||||
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
|
||||
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.get("", response_model=list[EmbeddingModelDetail])
|
||||
@@ -94,9 +93,8 @@ def delete_embedding_provider(
|
||||
embedding_provider is not None
|
||||
and provider_type == embedding_provider.provider_type
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"You can't delete a currently active model",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete a currently active model"
|
||||
)
|
||||
|
||||
remove_embedding_provider(db_session, provider_type=provider_type)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -14,8 +15,6 @@ from onyx.db.llm import remove_llm_provider__no_commit
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.image_gen.exceptions import ImageProviderCredentialsError
|
||||
from onyx.image_gen.factory import get_image_generation_provider
|
||||
from onyx.image_gen.factory import validate_credentials
|
||||
@@ -75,9 +74,9 @@ def _build_llm_provider_request(
|
||||
# Clone mode: Only use API key from source provider
|
||||
source_provider = db_session.get(LLMProviderModel, source_llm_provider_id)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -111,9 +110,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
credentials = ImageGenerationProviderCredentials(
|
||||
@@ -125,9 +124,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not validate_credentials(provider, credentials):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Incorrect credentials for {provider}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Incorrect credentials for {provider}",
|
||||
)
|
||||
|
||||
return LLMProviderUpsertRequest(
|
||||
@@ -216,9 +215,9 @@ def test_image_generation(
|
||||
LLMProviderModel, test_request.source_llm_provider_id
|
||||
)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -237,9 +236,9 @@ def test_image_generation(
|
||||
provider = source_provider.provider
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -258,14 +257,14 @@ def test_image_generation(
|
||||
),
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Invalid image generation provider: {provider}",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Invalid image generation provider: {provider}",
|
||||
)
|
||||
except ImageProviderCredentialsError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
"Invalid image generation credentials",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid image generation credentials",
|
||||
)
|
||||
|
||||
quality = _get_test_quality_for_model(test_request.model_name)
|
||||
@@ -277,15 +276,15 @@ def test_image_generation(
|
||||
n=1,
|
||||
quality=quality,
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log only exception type to avoid exposing sensitive data
|
||||
# (LiteLLM errors may contain URLs with API keys or auth tokens)
|
||||
logger.warning(f"Image generation test failed: {type(e).__name__}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Image generation test failed: {type(e).__name__}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image generation test failed: {type(e).__name__}",
|
||||
)
|
||||
|
||||
|
||||
@@ -310,9 +309,9 @@ def create_config(
|
||||
db_session, config_create.image_provider_id
|
||||
)
|
||||
if existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -346,10 +345,10 @@ def create_config(
|
||||
db_session.commit()
|
||||
db_session.refresh(config)
|
||||
return ImageGenerationConfigView.from_model(config)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.get("/config")
|
||||
@@ -374,9 +373,9 @@ def get_config_credentials(
|
||||
"""
|
||||
config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
return ImageGenerationCredentials.from_model(config)
|
||||
@@ -402,9 +401,9 @@ def update_config(
|
||||
# 1. Get existing config
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
old_llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -473,10 +472,10 @@ def update_config(
|
||||
db_session.refresh(existing_config)
|
||||
return ImageGenerationConfigView.from_model(existing_config)
|
||||
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}")
|
||||
@@ -490,9 +489,9 @@ def delete_config(
|
||||
# Get the config first to find the associated LLM provider
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -504,10 +503,10 @@ def delete_config(
|
||||
remove_llm_provider__no_commit(db_session, llm_provider_id)
|
||||
|
||||
db_session.commit()
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/config/{image_provider_id}/default")
|
||||
@@ -520,7 +519,7 @@ def set_config_as_default(
|
||||
try:
|
||||
set_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}/default")
|
||||
@@ -533,4 +532,4 @@ def unset_config_as_default(
|
||||
try:
|
||||
unset_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -11,6 +11,7 @@ from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -37,8 +38,6 @@ from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.llm import validate_persona_ids_exist
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import user_can_access_persona
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
@@ -187,7 +186,7 @@ def _validate_llm_provider_change(
|
||||
Only enforced in MULTI_TENANT mode.
|
||||
|
||||
Raises:
|
||||
OnyxError: If api_base or custom_config changed without changing API key
|
||||
HTTPException: If api_base or custom_config changed without changing API key
|
||||
"""
|
||||
if not MULTI_TENANT or api_key_changed:
|
||||
return
|
||||
@@ -201,9 +200,9 @@ def _validate_llm_provider_change(
|
||||
)
|
||||
|
||||
if api_base_changed or custom_config_changed:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base and/or custom config cannot be changed without changing the API key",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API base and/or custom config cannot be changed without changing the API key",
|
||||
)
|
||||
|
||||
|
||||
@@ -223,7 +222,7 @@ def fetch_llm_provider_options(
|
||||
for well_known_llm in well_known_llms:
|
||||
if well_known_llm.name == provider_name:
|
||||
return well_known_llm
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found")
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
||||
|
||||
|
||||
@admin_router.post("/test")
|
||||
@@ -282,7 +281,7 @@ def test_llm_configuration(
|
||||
error_msg = test_llm(llm)
|
||||
|
||||
if error_msg:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.post("/test/default")
|
||||
@@ -293,11 +292,11 @@ def test_default_provider(
|
||||
llm = get_default_llm()
|
||||
except ValueError:
|
||||
logger.exception("Failed to fetch default LLM Provider")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup")
|
||||
raise HTTPException(status_code=400, detail="No LLM Provider setup")
|
||||
|
||||
error = test_llm(llm)
|
||||
if error:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error))
|
||||
raise HTTPException(status_code=400, detail=str(error))
|
||||
|
||||
|
||||
@admin_router.get("/provider")
|
||||
@@ -363,31 +362,35 @@ def put_llm_provider(
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Renaming providers is not currently supported",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -412,9 +415,9 @@ def put_llm_provider(
|
||||
db_session, persona_ids
|
||||
)
|
||||
if missing_personas:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
|
||||
)
|
||||
# Remove duplicates while preserving order
|
||||
seen: set[int] = set()
|
||||
@@ -470,29 +473,19 @@ def put_llm_provider(
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to upsert LLM Provider")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
def delete_llm_provider(
|
||||
provider_id: int,
|
||||
force: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
try:
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@@ -532,9 +525,9 @@ def get_auto_config(
|
||||
"""
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if not config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Failed to fetch configuration from GitHub",
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch configuration from GitHub",
|
||||
)
|
||||
return config.model_dump()
|
||||
|
||||
@@ -691,13 +684,13 @@ def list_llm_providers_for_persona(
|
||||
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found")
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
|
||||
# Verify user has access to this persona
|
||||
if not user_can_access_persona(db_session, persona_id, user, get_editable=False):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You don't have access to this assistant",
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You don't have access to this assistant",
|
||||
)
|
||||
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
@@ -851,9 +844,9 @@ def get_bedrock_available_models(
|
||||
try:
|
||||
bedrock = session.client("bedrock")
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
|
||||
)
|
||||
|
||||
# Build model info dict from foundation models (modelId -> metadata)
|
||||
@@ -972,14 +965,14 @@ def get_bedrock_available_models(
|
||||
return results
|
||||
|
||||
except (ClientError, NoCredentialsError, BotoCoreError) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Failed to connect to AWS Bedrock: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to connect to AWS Bedrock: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"Unexpected error fetching Bedrock models: {e}",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Unexpected error fetching Bedrock models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@@ -991,9 +984,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]:
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch Ollama models: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch Ollama models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
@@ -1010,9 +1003,9 @@ def get_ollama_available_models(
|
||||
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
if not cleaned_api_base:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API base URL is required to fetch Ollama models.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API base URL is required to fetch Ollama models.",
|
||||
)
|
||||
|
||||
# NOTE: most people run Ollama locally, so we don't disallow internal URLs
|
||||
@@ -1021,9 +1014,9 @@ def get_ollama_available_models(
|
||||
# with the same response format
|
||||
model_names = _get_ollama_available_model_names(cleaned_api_base)
|
||||
if not model_names:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Ollama server",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your Ollama server",
|
||||
)
|
||||
|
||||
all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
|
||||
@@ -1125,9 +1118,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch OpenRouter models: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch OpenRouter models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@@ -1148,9 +1141,9 @@ def get_openrouter_available_models(
|
||||
|
||||
data = response_json.get("data", [])
|
||||
if not isinstance(data, list) or len(data) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenRouter endpoint",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your OpenRouter endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenRouterFinalModelResponse] = []
|
||||
@@ -1185,9 +1178,8 @@ def get_openrouter_available_models(
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenRouter",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No compatible models found from OpenRouter"
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -19,8 +21,6 @@ from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import update_unstructured_api_key
|
||||
@@ -48,9 +48,9 @@ def set_new_search_settings(
|
||||
# NOTE Enable integration external dependency tests in test_search_settings.py
|
||||
# when this is reenabled. They are currently skipped
|
||||
logger.error("Setting new search settings is temporarily disabled.")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_IMPLEMENTED,
|
||||
"Setting new search settings is temporarily disabled.",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Setting new search settings is temporarily disabled.",
|
||||
)
|
||||
# if search_settings_new.index_name:
|
||||
# logger.warning("Index name was specified by request, this is not suggested")
|
||||
@@ -191,7 +191,7 @@ def delete_search_settings_endpoint(
|
||||
search_settings_id=deletion_request.search_settings_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/get-current-search-settings")
|
||||
@@ -241,9 +241,9 @@ def update_saved_search_settings(
|
||||
) -> None:
|
||||
# Disallow contextual RAG for cloud deployments
|
||||
if MULTI_TENANT and search_settings.enable_contextual_rag:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Contextual RAG disabled in Onyx Cloud",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
@@ -297,7 +297,7 @@ def validate_contextual_rag_model(
|
||||
model_name=model_name,
|
||||
db_session=db_session,
|
||||
):
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
|
||||
|
||||
|
||||
def _validate_contextual_rag_model(
|
||||
|
||||
@@ -13,13 +13,13 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import convert_chat_history_basic
|
||||
@@ -67,6 +67,7 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
|
||||
from onyx.server.api_key_usage import check_api_key_usage
|
||||
from onyx.server.query_and_chat.models import ChatFeedbackRequest
|
||||
@@ -329,7 +330,7 @@ def get_chat_session(
|
||||
]
|
||||
|
||||
try:
|
||||
is_processing = is_chat_session_processing(session_id, get_cache_backend())
|
||||
is_processing = is_chat_session_processing(session_id, get_redis_client())
|
||||
# Edit the last message to indicate loading (Overriding default message value)
|
||||
if is_processing and chat_message_details:
|
||||
last_msg = chat_message_details[-1]
|
||||
@@ -926,10 +927,11 @@ async def search_chats(
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User = Depends(current_user), # noqa: ARG001
|
||||
redis_client: Redis = Depends(get_redis_client),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal.
|
||||
Stop a chat session by setting a stop signal in Redis.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
set_fence(chat_session_id, get_cache_backend(), True)
|
||||
set_fence(chat_session_id, redis_client, True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
@@ -7,8 +6,11 @@ from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -31,22 +33,30 @@ def load_settings() -> Settings:
|
||||
logger.error(f"Error loading settings from KV store: {str(e)}")
|
||||
settings = Settings()
|
||||
|
||||
cache = get_cache_backend()
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
if value is not None:
|
||||
assert isinstance(value, bytes)
|
||||
anonymous_user_enabled = int(value.decode("utf-8")) == 1
|
||||
else:
|
||||
# Default to False
|
||||
anonymous_user_enabled = False
|
||||
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
|
||||
# Optionally store the default back to Redis
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
|
||||
# Log the error and reset to default
|
||||
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
|
||||
# Override user knowledge setting if disabled via environment variable
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
@@ -56,10 +66,11 @@ def load_settings() -> Settings:
|
||||
|
||||
|
||||
def store_settings(settings: Settings) -> None:
|
||||
cache = get_cache_backend()
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if settings.anonymous_user_enabled is not None:
|
||||
cache.set(
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
|
||||
"1" if settings.anonymous_user_enabled else "0",
|
||||
ex=SETTINGS_TTL,
|
||||
|
||||
@@ -93,8 +93,6 @@ class ToolResponse(BaseModel):
|
||||
# | WebContentResponse
|
||||
# This comes from custom tools, tool result needs to be saved
|
||||
| CustomToolCallSummary
|
||||
# This comes from code interpreter, carries generated files
|
||||
| PythonToolRichResponse
|
||||
# If the rich response is a string, this is what's saved to the tool call in the DB
|
||||
| str
|
||||
| None # If nothing needs to be persisted outside of the string value passed to the LLM
|
||||
@@ -195,12 +193,6 @@ class ChatFile(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class PythonToolRichResponse(BaseModel):
|
||||
"""Rich response from the Python tool carrying generated files."""
|
||||
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
|
||||
|
||||
class PythonToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for the Python/Code Interpreter tool."""
|
||||
|
||||
@@ -253,7 +245,6 @@ class ToolCallInfo(BaseModel):
|
||||
tool_call_response: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
generated_files: list[PythonExecutionFile] | None = None
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
@@ -13,9 +12,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HEALTH_CACHE_TTL_SECONDS = 30
|
||||
_health_cache: dict[str, tuple[float, bool]] = {}
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
@@ -102,32 +98,16 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self, use_cache: bool = False) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy
|
||||
|
||||
Args:
|
||||
use_cache: When True, return a cached result if available and
|
||||
within the TTL window. The cache is always populated
|
||||
after a live request regardless of this flag.
|
||||
"""
|
||||
if use_cache:
|
||||
cached = _health_cache.get(self.base_url)
|
||||
if cached is not None:
|
||||
cached_at, cached_result = cached
|
||||
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
|
||||
return cached_result
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
result = response.json().get("status") == "ok"
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
result = False
|
||||
|
||||
_health_cache[self.base_url] = (time.monotonic(), result)
|
||||
return result
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import LlmPythonExecutionResult
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
@@ -108,11 +107,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
client = CodeInterpreterClient()
|
||||
return client.health(use_cache=True)
|
||||
return server.server_enabled
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
@@ -330,9 +325,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=PythonToolRichResponse(
|
||||
generated_files=generated_files,
|
||||
),
|
||||
rich_response=None, # No rich response needed for Python tool
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
@@ -56,30 +57,6 @@ def _sanitize_query(query: str) -> str:
|
||||
return " ".join(sanitized.split())
|
||||
|
||||
|
||||
def _normalize_queries_input(raw: Any) -> list[str]:
|
||||
"""Coerce LLM output to a list of sanitized query strings.
|
||||
|
||||
Accepts a bare string or a list (possibly with non-string elements).
|
||||
Sanitizes each query (strip control chars, normalize whitespace) and
|
||||
drops empty or whitespace-only entries.
|
||||
"""
|
||||
if isinstance(raw, str):
|
||||
raw = raw.strip()
|
||||
if not raw:
|
||||
return []
|
||||
raw = [raw]
|
||||
elif not isinstance(raw, list):
|
||||
return []
|
||||
result: list[str] = []
|
||||
for q in raw:
|
||||
if q is None:
|
||||
continue
|
||||
sanitized = _sanitize_query(str(q))
|
||||
if sanitized:
|
||||
result.append(sanitized)
|
||||
return result
|
||||
|
||||
|
||||
class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
NAME = "web_search"
|
||||
DESCRIPTION = "Search the web for information."
|
||||
@@ -212,7 +189,13 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
f'like: {{"queries": ["your search query here"]}}'
|
||||
),
|
||||
)
|
||||
queries = _normalize_queries_input(llm_kwargs[QUERIES_FIELD])
|
||||
raw_queries = cast(list[str], llm_kwargs[QUERIES_FIELD])
|
||||
|
||||
# Normalize queries:
|
||||
# - remove control characters (null bytes, etc.) that LLMs sometimes produce
|
||||
# - collapse whitespace and strip
|
||||
# - drop empty/whitespace-only queries
|
||||
queries = [sanitized for q in raw_queries if (sanitized := _sanitize_query(q))]
|
||||
if not queries:
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
|
||||
@@ -13,11 +13,9 @@ the correct files.
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
@@ -57,32 +55,6 @@ def _create_user_file(
|
||||
return uf
|
||||
|
||||
|
||||
def _fake_delete_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: delete the row so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id)))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _fake_sync_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: clear sync flags so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == UUID(user_file_id))
|
||||
.values(needs_project_sync=False, needs_persona_sync=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
@@ -153,9 +125,9 @@ class TestRecoverDeletingFiles:
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
# Row is deleted by _fake_delete_impl, so no cleanup needed.
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_delete_impl)
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
@@ -183,7 +155,7 @@ class TestRecoverSyncFiles:
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
@@ -207,7 +179,7 @@ class TestRecoverSyncFiles:
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
@@ -245,108 +217,3 @@ class TestRecoveryMultipleFiles:
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestTransientFailures:
|
||||
"""Drain loops skip failed files, process the rest, and terminate."""
|
||||
|
||||
def test_processing_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_proc")
|
||||
uf_fail = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_delete_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_del")
|
||||
uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf_fail)
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_delete_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_sync_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_sync")
|
||||
uf_fail = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
uf_ok = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_sync_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Fixtures for cache backend tests.
|
||||
|
||||
Requires a running PostgreSQL instance (and Redis for parity tests).
|
||||
Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db() -> Generator[None, None, None]:
|
||||
"""Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tenant_context() -> Generator[None, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_cache() -> PostgresCacheBackend:
|
||||
return PostgresCacheBackend(TEST_TENANT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_cache() -> RedisCacheBackend:
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
|
||||
|
||||
|
||||
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
|
||||
def cache(
|
||||
request: pytest.FixtureRequest,
|
||||
pg_cache: PostgresCacheBackend,
|
||||
redis_cache: RedisCacheBackend,
|
||||
) -> CacheBackend:
|
||||
if request.param == "postgres":
|
||||
return pg_cache
|
||||
return redis_cache
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Parameterized tests that run the same CacheBackend operations against
|
||||
both Redis and PostgreSQL, asserting identical return values.
|
||||
|
||||
Each test runs twice (once per backend) via the ``cache`` fixture defined
|
||||
in conftest.py.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"parity_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
class TestKVParity:
|
||||
def test_get_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.get(_key()) is None
|
||||
|
||||
def test_get_set(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"value")
|
||||
assert cache.get(k) == b"value"
|
||||
|
||||
def test_overwrite(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"a")
|
||||
cache.set(k, b"b")
|
||||
assert cache.get(k) == b"b"
|
||||
|
||||
def test_set_string(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, "hello")
|
||||
assert cache.get(k) == b"hello"
|
||||
|
||||
def test_set_int(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, 42)
|
||||
assert cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
cache.delete(k)
|
||||
assert cache.get(k) is None
|
||||
|
||||
def test_exists(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not cache.exists(k)
|
||||
cache.set(k, b"x")
|
||||
assert cache.exists(k)
|
||||
|
||||
|
||||
class TestTTLParity:
|
||||
def test_ttl_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
assert cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_remaining(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=10)
|
||||
remaining = cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=1)
|
||||
assert cache.get(k) == b"x"
|
||||
time.sleep(1.5)
|
||||
assert cache.get(k) is None
|
||||
|
||||
|
||||
class TestLockParity:
|
||||
def test_acquire_release(self, cache: CacheBackend) -> None:
|
||||
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
|
||||
class TestListParity:
|
||||
def test_rpush_blpop(self, cache: CacheBackend) -> None:
|
||||
k = f"parity_list_{uuid4().hex[:8]}"
|
||||
cache.rpush(k, b"item")
|
||||
result = cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result[1] == b"item"
|
||||
|
||||
def test_blpop_timeout(self, cache: CacheBackend) -> None:
|
||||
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
|
||||
|
||||
Verifies that the KV store correctly uses the CacheBackend for caching
|
||||
in front of PostgreSQL: cache hits, cache misses falling through to PG,
|
||||
cache population after PG reads, cache invalidation on delete, and
|
||||
graceful degradation when the cache backend raises.
|
||||
|
||||
Requires running PostgreSQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import CacheStore
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from onyx.key_value_store.store import REDIS_KEY_PREFIX
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_kv() -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
|
||||
session.execute(delete(KVStore))
|
||||
session.execute(delete(CacheStore))
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
|
||||
return PgRedisKVStore(cache=pg_cache)
|
||||
|
||||
|
||||
class TestStoreAndLoad:
|
||||
def test_store_populates_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k1", {"hello": "world"})
|
||||
|
||||
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
|
||||
assert cached is not None
|
||||
assert json.loads(cached) == {"hello": "world"}
|
||||
|
||||
loaded = kv_store.load("k1")
|
||||
assert loaded == {"hello": "world"}
|
||||
|
||||
def test_load_returns_cached_value_without_pg_hit(
|
||||
self, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
"""If the cache already has the value, PG should not be queried."""
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
|
||||
kv = PgRedisKVStore(cache=pg_cache)
|
||||
assert kv.load("cached_only") == {"from": "cache"}
|
||||
|
||||
def test_load_falls_through_to_pg_on_cache_miss(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k2", [1, 2, 3])
|
||||
|
||||
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
|
||||
|
||||
loaded = kv_store.load("k2")
|
||||
assert loaded == [1, 2, 3]
|
||||
|
||||
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
|
||||
assert repopulated is not None
|
||||
assert json.loads(repopulated) == [1, 2, 3]
|
||||
|
||||
def test_load_with_refresh_cache_skips_cache(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k3", "original")
|
||||
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
|
||||
|
||||
loaded = kv_store.load("k3", refresh_cache=True)
|
||||
assert loaded == "original"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete_removes_from_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("del_me", "bye")
|
||||
kv_store.delete("del_me")
|
||||
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
|
||||
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.load("del_me")
|
||||
|
||||
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.delete("nonexistent")
|
||||
|
||||
|
||||
class TestCacheFailureGracefulDegradation:
|
||||
def test_store_succeeds_when_cache_set_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("resilient", {"data": True})
|
||||
|
||||
working_cache = MagicMock(spec=CacheBackend)
|
||||
working_cache.get.return_value = None
|
||||
kv_reader = PgRedisKVStore(cache=working_cache)
|
||||
loaded = kv_reader.load("resilient")
|
||||
assert loaded == {"data": True}
|
||||
|
||||
def test_load_falls_through_when_cache_get_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.get.side_effect = ConnectionError("cache down")
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("survive", 42)
|
||||
loaded = kv.load("survive")
|
||||
assert loaded == 42
|
||||
@@ -1,229 +0,0 @@
|
||||
"""Tests for PostgresCacheBackend against real PostgreSQL.
|
||||
|
||||
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
|
||||
locks (acquire / release / contention), list operations (rpush / blpop),
|
||||
and the periodic cleanup function.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"test_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Basic KV
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKV:
|
||||
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"hello")
|
||||
assert pg_cache.get(k) == b"hello"
|
||||
|
||||
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.get(_key()) is None
|
||||
|
||||
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"first")
|
||||
pg_cache.set(k, b"second")
|
||||
assert pg_cache.get(k) == b"second"
|
||||
|
||||
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, "string_val")
|
||||
assert pg_cache.get(k) == b"string_val"
|
||||
|
||||
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, 42)
|
||||
assert pg_cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"to_delete")
|
||||
pg_cache.delete(k)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
pg_cache.delete(_key())
|
||||
|
||||
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not pg_cache.exists(k)
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTL
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"ephemeral", ex=1)
|
||||
assert pg_cache.get(k) == b"ephemeral"
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"forever")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=10)
|
||||
remaining = pg_cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
pg_cache.expire(k, 10)
|
||||
assert 8 <= pg_cache.ttl(k) <= 10
|
||||
|
||||
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
assert pg_cache.exists(k)
|
||||
time.sleep(1.5)
|
||||
assert not pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Locks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"contention_{uuid4().hex[:8]}"
|
||||
lock1 = pg_cache.lock(name)
|
||||
lock2 = pg_cache.lock(name)
|
||||
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
|
||||
assert lock.owned()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"timeout_{uuid4().hex[:8]}"
|
||||
holder = pg_cache.lock(name)
|
||||
holder.acquire(blocking=False)
|
||||
|
||||
waiter = pg_cache.lock(name, timeout=0.3)
|
||||
start = time.monotonic()
|
||||
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 0.25
|
||||
|
||||
holder.release()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# List (rpush / blpop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestList:
|
||||
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"list_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"item1")
|
||||
result = pg_cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k.encode(), b"item1")
|
||||
|
||||
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
|
||||
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"fifo_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"first")
|
||||
time.sleep(0.01)
|
||||
pg_cache.rpush(k, b"second")
|
||||
|
||||
r1 = pg_cache.blpop([k], timeout=1)
|
||||
r2 = pg_cache.blpop([k], timeout=1)
|
||||
assert r1 is not None and r1[1] == b"first"
|
||||
assert r2 is not None and r2[1] == b"second"
|
||||
|
||||
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k1 = f"mk1_{uuid4().hex[:8]}"
|
||||
k2 = f"mk2_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k2, b"from_k2")
|
||||
|
||||
result = pg_cache.blpop([k1, k2], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k2.encode(), b"from_k2")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
k = _key()
|
||||
pg_cache.set(k, b"stale", ex=1)
|
||||
time.sleep(1.5)
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
stmt = select(CacheStore.key).where(CacheStore.key == k)
|
||||
with get_session_with_current_tenant() as session:
|
||||
row = session.execute(stmt).first()
|
||||
assert row is None, "expired row should be physically deleted"
|
||||
|
||||
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"fresh", ex=300)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"fresh"
|
||||
|
||||
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"permanent")
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"permanent"
|
||||
@@ -11,6 +11,7 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
@@ -19,8 +20,6 @@ from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.server.manage.llm.api import (
|
||||
@@ -123,16 +122,16 @@ class TestLLMConfigurationEndpoint:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
def test_failed_llm_test_raises_onyx_error(
|
||||
def test_failed_llm_test_raises_http_exception(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str, # noqa: ARG002
|
||||
) -> None:
|
||||
"""
|
||||
Test that a failed LLM test raises an OnyxError with VALIDATION_ERROR.
|
||||
Test that a failed LLM test raises an HTTPException with status 400.
|
||||
|
||||
When test_llm returns an error message, the endpoint should raise
|
||||
an OnyxError with the error details.
|
||||
an HTTPException with the error details.
|
||||
"""
|
||||
error_message = "Invalid API key: Authentication failed"
|
||||
|
||||
@@ -144,7 +143,7 @@ class TestLLMConfigurationEndpoint:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
|
||||
):
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -157,8 +156,9 @@ class TestLLMConfigurationEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
# Verify the exception details
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -536,11 +536,11 @@ class TestDefaultProviderEndpoint:
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
# Now run_test_default_provider should fail
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert "No LLM Provider setup" in exc_info.value.message
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "No LLM Provider setup" in exc_info.value.detail
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -581,11 +581,11 @@ class TestDefaultProviderEndpoint:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
|
||||
):
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_default_provider(_=_create_mock_admin())
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == error_message
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == error_message
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -16,14 +16,13 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.api import _mask_string
|
||||
from onyx.server.manage.llm.api import put_llm_provider
|
||||
@@ -101,7 +100,7 @@ class TestLLMProviderChanges:
|
||||
api_base="https://attacker.example.com",
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -109,9 +108,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -237,7 +236,7 @@ class TestLLMProviderChanges:
|
||||
api_base=None,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -245,9 +244,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -340,7 +339,7 @@ class TestLLMProviderChanges:
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -348,9 +347,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -376,7 +375,7 @@ class TestLLMProviderChanges:
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=update_request,
|
||||
is_creation=False,
|
||||
@@ -384,9 +383,9 @@ class TestLLMProviderChanges:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "cannot be changed without changing the API key" in str(
|
||||
exc_info.value.message
|
||||
exc_info.value.detail
|
||||
)
|
||||
finally:
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -1027,13 +1027,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self._capture("GET", b"")
|
||||
if self.path == "/health":
|
||||
self._respond_json(200, {"status": "ok"})
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_DELETE(self) -> None:
|
||||
self._capture("DELETE", b"")
|
||||
self.send_response(200)
|
||||
@@ -1114,14 +1107,6 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _attach_python_tool_to_default_persona(db_session: Session) -> None:
|
||||
"""Ensure the default persona (id=0) has the PythonTool attached."""
|
||||
|
||||
@@ -114,8 +114,8 @@ def test_create_duplicate_config_fails(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["message"]
|
||||
assert response.status_code == 400
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_get_all_configs(
|
||||
@@ -292,7 +292,7 @@ def test_update_config_source_provider_not_found(
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["message"]
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_config(
|
||||
@@ -468,7 +468,7 @@ def test_create_config_missing_credentials(
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "No provider or source llm provided" in response.json()["message"]
|
||||
assert "No provider or source llm provided" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_create_config_source_provider_not_found(
|
||||
@@ -488,4 +488,4 @@ def test_create_config_source_provider_not_found(
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["message"]
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
@@ -386,261 +386,6 @@ def test_delete_llm_provider(
|
||||
assert provider_data is None
|
||||
|
||||
|
||||
def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
"""Deleting the default LLM provider should return 400."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-default-delete",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Set this provider as the default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": created_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Attempt to delete the default provider — should be rejected
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["message"]
|
||||
|
||||
# Verify provider still exists
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
|
||||
|
||||
def test_delete_non_default_llm_provider_with_default_set(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a non-default provider should succeed even when a default is set."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create two providers
|
||||
response_default = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "default-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response_default.status_code == 200
|
||||
default_provider = response_default.json()
|
||||
|
||||
response_other = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "other-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response_other.status_code == 200
|
||||
other_provider = response_other.json()
|
||||
|
||||
# Set the first provider as default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": default_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Delete the non-default provider — should succeed
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{other_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
# Verify the non-default provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, other_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
# Verify the default provider still exists
|
||||
default_data = _get_provider_by_id(admin_user, default_provider["id"])
|
||||
assert default_data is not None
|
||||
|
||||
|
||||
def test_force_delete_default_llm_provider(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Force-deleting the default LLM provider should succeed."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-force-delete",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Set this provider as the default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": created_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Attempt to delete without force — should be rejected
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
|
||||
# Force delete — should succeed
|
||||
force_delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}?force=true",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert force_delete_response.status_code == 200
|
||||
|
||||
# Verify provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
|
||||
def test_delete_default_vision_provider_clears_vision_default(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting the default vision provider should succeed and clear the vision default."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a text provider and set it as default (so we have a default text provider)
|
||||
text_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "text-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert text_response.status_code == 200
|
||||
text_provider = text_response.json()
|
||||
_set_default_provider(admin_user, text_provider["id"], "gpt-4o-mini")
|
||||
|
||||
# Create a vision provider and set it as default vision
|
||||
vision_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "vision-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000002",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
supports_image_input=True,
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert vision_response.status_code == 200
|
||||
vision_provider = vision_response.json()
|
||||
_set_default_vision_provider(admin_user, vision_provider["id"], "gpt-4o")
|
||||
|
||||
# Verify vision default is set
|
||||
data = _get_providers_admin(admin_user)
|
||||
assert data is not None
|
||||
_, _, vision_default = _unpack_data(data)
|
||||
assert vision_default is not None
|
||||
assert vision_default["provider_id"] == vision_provider["id"]
|
||||
|
||||
# Delete the vision provider — should succeed (only text default is protected)
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{vision_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
# Verify the vision provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, vision_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
# Verify there is no default vision provider
|
||||
data = _get_providers_admin(admin_user)
|
||||
assert data is not None
|
||||
_, text_default, vision_default = _unpack_data(data)
|
||||
assert vision_default is None
|
||||
|
||||
# Verify the text default is still intact
|
||||
assert text_default is not None
|
||||
assert text_default["provider_id"] == text_provider["id"]
|
||||
|
||||
|
||||
def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
"""Creating a provider with a name that already exists should return 400."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
@@ -673,8 +418,8 @@ def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
headers=admin_user.headers,
|
||||
json=base_payload,
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["message"]
|
||||
assert response.status_code == 400
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
@@ -711,7 +456,7 @@ def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "not currently supported" in response.json()["message"]
|
||||
assert "not currently supported" in response.json()["detail"]
|
||||
|
||||
# Verify no duplicate was created — only the original provider should exist
|
||||
provider = _get_provider_by_id(admin_user, provider_id)
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_unauthorized_persona_access_returns_403(
|
||||
|
||||
# Should return 403 Forbidden
|
||||
assert response.status_code == 403
|
||||
assert "don't have access to this assistant" in response.json()["message"]
|
||||
assert "don't have access to this assistant" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_authorized_persona_access_returns_filtered_providers(
|
||||
@@ -245,4 +245,4 @@ def test_nonexistent_persona_returns_404(
|
||||
|
||||
# Should return 404
|
||||
assert response.status_code == 404
|
||||
assert "Persona not found" in response.json()["message"]
|
||||
assert "Persona not found" in response.json()["detail"]
|
||||
|
||||
@@ -1,234 +0,0 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
def _upload_connector_file(
|
||||
*,
|
||||
user_performing_action: DATestUser,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
) -> tuple[str, str]:
|
||||
headers = user_performing_action.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
|
||||
files=[("files", (file_name, io.BytesIO(content), "text/plain"))],
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return payload["file_paths"][0], payload["file_names"][0]
|
||||
|
||||
|
||||
def _update_connector_files(
|
||||
*,
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
file_ids_to_remove: list[str],
|
||||
new_file_name: str,
|
||||
new_file_content: bytes,
|
||||
) -> requests.Response:
|
||||
headers = user_performing_action.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files/update",
|
||||
data={"file_ids_to_remove": json.dumps(file_ids_to_remove)},
|
||||
files=[("files", (new_file_name, io.BytesIO(new_file_content), "text/plain"))],
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
def _list_connector_files(
|
||||
*,
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
@pytest.mark.usefixtures("reset")
|
||||
def test_only_global_curator_can_update_public_file_connector_files() -> None:
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
global_curator_creator = UserManager.create(name="global_curator_creator")
|
||||
global_curator_creator = UserManager.set_role(
|
||||
user_to_set=global_curator_creator,
|
||||
target_role=UserRole.GLOBAL_CURATOR,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
global_curator_editor = UserManager.create(name="global_curator_editor")
|
||||
global_curator_editor = UserManager.set_role(
|
||||
user_to_set=global_curator_editor,
|
||||
target_role=UserRole.GLOBAL_CURATOR,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
curator_user = UserManager.create(name="curator_user")
|
||||
curator_group = UserGroupManager.create(
|
||||
name="curator_group",
|
||||
user_ids=[curator_user.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[curator_group],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.set_curator_status(
|
||||
test_user_group=curator_group,
|
||||
user_to_set_as_curator=curator_user,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
initial_file_id, initial_file_name = _upload_connector_file(
|
||||
user_performing_action=global_curator_creator,
|
||||
file_name="initial-file.txt",
|
||||
content=b"initial file content",
|
||||
)
|
||||
|
||||
connector = ConnectorManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
name="public_file_connector",
|
||||
source=DocumentSource.FILE,
|
||||
connector_specific_config={
|
||||
"file_locations": [initial_file_id],
|
||||
"file_names": [initial_file_name],
|
||||
"zip_metadata_file_id": None,
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name="public_file_connector_credential",
|
||||
)
|
||||
CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
user_performing_action=global_curator_creator,
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
name="public_file_connector_cc_pair",
|
||||
)
|
||||
|
||||
curator_list_response = _list_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=curator_user,
|
||||
)
|
||||
curator_list_response.raise_for_status()
|
||||
curator_list_payload = curator_list_response.json()
|
||||
assert any(f["file_id"] == initial_file_id for f in curator_list_payload["files"])
|
||||
|
||||
global_curator_list_response = _list_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
)
|
||||
global_curator_list_response.raise_for_status()
|
||||
global_curator_list_payload = global_curator_list_response.json()
|
||||
assert any(
|
||||
f["file_id"] == initial_file_id for f in global_curator_list_payload["files"]
|
||||
)
|
||||
|
||||
denied_response = _update_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=curator_user,
|
||||
file_ids_to_remove=[initial_file_id],
|
||||
new_file_name="curator-file.txt",
|
||||
new_file_content=b"curator updated file",
|
||||
)
|
||||
assert denied_response.status_code == 403
|
||||
|
||||
allowed_response = _update_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
file_ids_to_remove=[initial_file_id],
|
||||
new_file_name="global-curator-file.txt",
|
||||
new_file_content=b"global curator updated file",
|
||||
)
|
||||
allowed_response.raise_for_status()
|
||||
|
||||
payload = allowed_response.json()
|
||||
assert initial_file_id not in payload["file_paths"]
|
||||
assert "global-curator-file.txt" in payload["file_names"]
|
||||
|
||||
creator_group = UserGroupManager.create(
|
||||
name="creator_group",
|
||||
user_ids=[global_curator_creator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[creator_group],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
private_file_id, private_file_name = _upload_connector_file(
|
||||
user_performing_action=global_curator_creator,
|
||||
file_name="private-initial-file.txt",
|
||||
content=b"private initial file content",
|
||||
)
|
||||
|
||||
private_connector = ConnectorManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
name="private_file_connector",
|
||||
source=DocumentSource.FILE,
|
||||
connector_specific_config={
|
||||
"file_locations": [private_file_id],
|
||||
"file_names": [private_file_name],
|
||||
"zip_metadata_file_id": None,
|
||||
},
|
||||
access_type=AccessType.PRIVATE,
|
||||
groups=[creator_group.id],
|
||||
)
|
||||
private_credential = CredentialManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=False,
|
||||
groups=[creator_group.id],
|
||||
name="private_file_connector_credential",
|
||||
)
|
||||
CCPairManager.create(
|
||||
connector_id=private_connector.id,
|
||||
credential_id=private_credential.id,
|
||||
user_performing_action=global_curator_creator,
|
||||
access_type=AccessType.PRIVATE,
|
||||
groups=[creator_group.id],
|
||||
name="private_file_connector_cc_pair",
|
||||
)
|
||||
|
||||
private_denied_response = _update_connector_files(
|
||||
connector_id=private_connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
file_ids_to_remove=[private_file_id],
|
||||
new_file_name="global-curator-private-file.txt",
|
||||
new_file_content=b"global curator private update",
|
||||
)
|
||||
assert private_denied_response.status_code == 403
|
||||
@@ -300,7 +300,7 @@ def test_update_contextual_rag_nonexistent_provider(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Provider nonexistent-provider not found" in response.json()["message"]
|
||||
assert "Provider nonexistent-provider not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_contextual_rag_nonexistent_model(
|
||||
@@ -322,7 +322,7 @@ def test_update_contextual_rag_nonexistent_model(
|
||||
assert response.status_code == 400
|
||||
assert (
|
||||
f"Model nonexistent-model not found in provider {llm_provider.name}"
|
||||
in response.json()["message"]
|
||||
in response.json()["detail"]
|
||||
)
|
||||
|
||||
|
||||
@@ -342,7 +342,7 @@ def test_update_contextual_rag_missing_provider_name(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Provider name and model name are required" in response.json()["message"]
|
||||
assert "Provider name and model name are required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_contextual_rag_missing_model_name(
|
||||
@@ -362,7 +362,7 @@ def test_update_contextual_rag_missing_model_name(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Provider name and model name are required" in response.json()["message"]
|
||||
assert "Provider name and model name are required" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
|
||||
@@ -11,8 +11,7 @@ from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
|
||||
|
||||
class TestCreateCheckoutSession:
|
||||
@@ -89,25 +88,22 @@ class TestCreateCheckoutSession:
|
||||
mock_get_tenant: MagicMock,
|
||||
mock_service: AsyncMock,
|
||||
) -> None:
|
||||
"""Should propagate OnyxError when service fails."""
|
||||
"""Should raise HTTPException when service fails."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import create_checkout_session
|
||||
|
||||
mock_get_license.return_value = None
|
||||
mock_get_tenant.return_value = "tenant_123"
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Stripe error",
|
||||
status_code_override=502,
|
||||
)
|
||||
mock_service.side_effect = BillingServiceError("Stripe error", 502)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_checkout_session(
|
||||
request=None, _=MagicMock(), db_session=MagicMock()
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Stripe error"
|
||||
assert "Stripe error" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCreateCustomerPortalSession:
|
||||
@@ -125,19 +121,20 @@ class TestCreateCustomerPortalSession:
|
||||
mock_service: AsyncMock, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Should reject self-hosted without license."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import create_customer_portal_session
|
||||
|
||||
mock_get_license.return_value = None
|
||||
mock_get_tenant.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_customer_portal_session(
|
||||
request=None, _=MagicMock(), db_session=MagicMock()
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
assert "No license found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.create_portal_service")
|
||||
@@ -230,6 +227,8 @@ class TestUpdateSeats:
|
||||
mock_get_tenant: MagicMock,
|
||||
) -> None:
|
||||
"""Should reject self-hosted without license."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import update_seats
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
|
||||
@@ -238,12 +237,11 @@ class TestUpdateSeats:
|
||||
|
||||
request = SeatUpdateRequest(new_seat_count=10)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.VALIDATION_ERROR
|
||||
assert exc_info.value.message == "No license found"
|
||||
assert "No license found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.get_used_seats")
|
||||
@@ -297,27 +295,26 @@ class TestUpdateSeats:
|
||||
mock_service: AsyncMock,
|
||||
mock_get_used_seats: MagicMock,
|
||||
) -> None:
|
||||
"""Should propagate OnyxError from service layer."""
|
||||
"""Should convert BillingServiceError to HTTPException."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import update_seats
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_get_used_seats.return_value = 0
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Cannot reduce below 10 seats",
|
||||
status_code_override=400,
|
||||
mock_service.side_effect = BillingServiceError(
|
||||
"Cannot reduce below 10 seats", 400
|
||||
)
|
||||
|
||||
request = SeatUpdateRequest(new_seat_count=5)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_seats(request=request, _=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.message == "Cannot reduce below 10 seats"
|
||||
assert "Cannot reduce below 10 seats" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@@ -335,18 +332,19 @@ class TestCircuitBreaker:
|
||||
mock_circuit_open: MagicMock,
|
||||
) -> None:
|
||||
"""Should return 503 when circuit breaker is open."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open.return_value = True
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.error_code is OnyxErrorCode.SERVICE_UNAVAILABLE
|
||||
assert "Connect to Stripe" in exc_info.value.message
|
||||
assert "Connect to Stripe" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.api.MULTI_TENANT", False)
|
||||
@@ -364,18 +362,16 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should open circuit breaker on 502 error."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Connection failed",
|
||||
status_code_override=502,
|
||||
)
|
||||
mock_service.side_effect = BillingServiceError("Connection failed", 502)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
@@ -397,18 +393,16 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should open circuit breaker on 503 error."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Service unavailable",
|
||||
status_code_override=503,
|
||||
)
|
||||
mock_service.side_effect = BillingServiceError("Service unavailable", 503)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
@@ -430,18 +424,16 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should open circuit breaker on 504 error."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Gateway timeout",
|
||||
status_code_override=504,
|
||||
)
|
||||
mock_service.side_effect = BillingServiceError("Gateway timeout", 504)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 504
|
||||
@@ -463,18 +455,16 @@ class TestCircuitBreaker:
|
||||
mock_open_circuit: MagicMock,
|
||||
) -> None:
|
||||
"""Should NOT open circuit breaker on 400 error (client error)."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.billing.api import get_billing_information
|
||||
|
||||
mock_get_license.return_value = "license_blob"
|
||||
mock_get_tenant.return_value = None
|
||||
mock_circuit_open_check.return_value = False
|
||||
mock_service.side_effect = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Bad request",
|
||||
status_code_override=400,
|
||||
)
|
||||
mock_service.side_effect = BillingServiceError("Bad request", 400)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_billing_information(_=MagicMock(), db_session=MagicMock())
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
@@ -14,8 +14,7 @@ from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
|
||||
|
||||
class TestMakeBillingRequest:
|
||||
@@ -79,7 +78,7 @@ class TestMakeBillingRequest:
|
||||
mock_base_url: MagicMock,
|
||||
mock_headers: MagicMock,
|
||||
) -> None:
|
||||
"""Should raise OnyxError on HTTP error."""
|
||||
"""Should raise BillingServiceError on HTTP error."""
|
||||
from ee.onyx.server.billing.service import _make_billing_request
|
||||
|
||||
mock_base_url.return_value = "https://api.example.com"
|
||||
@@ -92,7 +91,7 @@ class TestMakeBillingRequest:
|
||||
mock_client = make_mock_http_client("post", side_effect=error)
|
||||
|
||||
with patch("httpx.AsyncClient", mock_client):
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(BillingServiceError) as exc_info:
|
||||
await _make_billing_request(
|
||||
method="POST",
|
||||
path="/test",
|
||||
@@ -100,7 +99,6 @@ class TestMakeBillingRequest:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Bad request" in exc_info.value.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -138,7 +136,7 @@ class TestMakeBillingRequest:
|
||||
mock_base_url: MagicMock,
|
||||
mock_headers: MagicMock,
|
||||
) -> None:
|
||||
"""Should raise OnyxError on connection error."""
|
||||
"""Should raise BillingServiceError on connection error."""
|
||||
from ee.onyx.server.billing.service import _make_billing_request
|
||||
|
||||
mock_base_url.return_value = "https://api.example.com"
|
||||
@@ -147,11 +145,10 @@ class TestMakeBillingRequest:
|
||||
mock_client = make_mock_http_client("post", side_effect=error)
|
||||
|
||||
with patch("httpx.AsyncClient", mock_client):
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(BillingServiceError) as exc_info:
|
||||
await _make_billing_request(method="POST", path="/test")
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code is OnyxErrorCode.BAD_GATEWAY
|
||||
assert "Failed to connect" in exc_info.value.message
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,6 @@ from unittest.mock import patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
class TestGetStripePublishableKey:
|
||||
"""Tests for get_stripe_publishable_key endpoint."""
|
||||
@@ -65,14 +62,15 @@ class TestGetStripePublishableKey:
|
||||
)
|
||||
async def test_rejects_invalid_env_var_key_format(self) -> None:
|
||||
"""Should reject keys that don't start with pk_."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -82,6 +80,8 @@ class TestGetStripePublishableKey:
|
||||
)
|
||||
async def test_rejects_invalid_s3_key_format(self) -> None:
|
||||
"""Should reject keys from S3 that don't start with pk_."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
mock_response = MagicMock()
|
||||
@@ -92,12 +92,11 @@ class TestGetStripePublishableKey:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Invalid Stripe publishable key format"
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@@ -107,32 +106,34 @@ class TestGetStripePublishableKey:
|
||||
)
|
||||
async def test_handles_s3_fetch_error(self) -> None:
|
||||
"""Should return error when S3 fetch fails."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
side_effect=httpx.HTTPError("Connection failed")
|
||||
)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert exc_info.value.message == "Failed to fetch Stripe publishable key"
|
||||
assert "Failed to fetch Stripe publishable key" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL", None)
|
||||
async def test_error_when_no_config(self) -> None:
|
||||
"""Should return error when neither env var nor S3 URL is configured."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
|
||||
assert "not configured" in exc_info.value.message
|
||||
assert "not configured" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
"""Tests for _extract_referenced_file_descriptors in save_chat.py.
|
||||
|
||||
Verifies that only code interpreter generated files actually referenced
|
||||
in the assistant's message text are extracted as FileDescriptors for
|
||||
cross-turn persistence.
|
||||
"""
|
||||
|
||||
from onyx.chat.save_chat import _extract_referenced_file_descriptors
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
|
||||
|
||||
def _make_tool_call_info(
|
||||
generated_files: list[PythonExecutionFile] | None = None,
|
||||
tool_name: str = "python",
|
||||
) -> ToolCallInfo:
|
||||
return ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=0,
|
||||
tab_index=0,
|
||||
tool_name=tool_name,
|
||||
tool_call_id="tc_1",
|
||||
tool_id=1,
|
||||
reasoning_tokens=None,
|
||||
tool_call_arguments={"code": "print('hi')"},
|
||||
tool_call_response="{}",
|
||||
generated_files=generated_files,
|
||||
)
|
||||
|
||||
|
||||
def test_returns_empty_when_no_generated_files() -> None:
|
||||
tool_call = _make_tool_call_info(generated_files=None)
|
||||
result = _extract_referenced_file_descriptors([tool_call], "some message")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_returns_empty_when_file_not_referenced() -> None:
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="chart.png",
|
||||
file_link="http://localhost/api/chat/file/abc-123",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
result = _extract_referenced_file_descriptors([tool_call], "Here is your answer.")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extracts_referenced_file() -> None:
|
||||
file_id = "abc-123-def"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="chart.png",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = (
|
||||
f"Here is the chart: [chart.png](http://localhost/api/chat/file/{file_id})"
|
||||
)
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == file_id
|
||||
assert result[0]["type"] == ChatFileType.IMAGE
|
||||
assert result[0]["name"] == "chart.png"
|
||||
|
||||
|
||||
def test_filters_unreferenced_files() -> None:
|
||||
referenced_id = "ref-111"
|
||||
unreferenced_id = "unref-222"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="chart.png",
|
||||
file_link=f"http://localhost/api/chat/file/{referenced_id}",
|
||||
),
|
||||
PythonExecutionFile(
|
||||
filename="data.csv",
|
||||
file_link=f"http://localhost/api/chat/file/{unreferenced_id}",
|
||||
),
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = f"Here is the chart: [chart.png](http://localhost/api/chat/file/{referenced_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == referenced_id
|
||||
assert result[0]["name"] == "chart.png"
|
||||
|
||||
|
||||
def test_extracts_from_multiple_tool_calls() -> None:
|
||||
id_1 = "file-aaa"
|
||||
id_2 = "file-bbb"
|
||||
tc1 = _make_tool_call_info(
|
||||
generated_files=[
|
||||
PythonExecutionFile(
|
||||
filename="plot.png",
|
||||
file_link=f"http://localhost/api/chat/file/{id_1}",
|
||||
)
|
||||
]
|
||||
)
|
||||
tc2 = _make_tool_call_info(
|
||||
generated_files=[
|
||||
PythonExecutionFile(
|
||||
filename="report.csv",
|
||||
file_link=f"http://localhost/api/chat/file/{id_2}",
|
||||
)
|
||||
]
|
||||
)
|
||||
message = (
|
||||
f"[plot.png](http://localhost/api/chat/file/{id_1}) "
|
||||
f"and [report.csv](http://localhost/api/chat/file/{id_2})"
|
||||
)
|
||||
|
||||
result = _extract_referenced_file_descriptors([tc1, tc2], message)
|
||||
|
||||
assert len(result) == 2
|
||||
ids = {d["id"] for d in result}
|
||||
assert ids == {id_1, id_2}
|
||||
|
||||
|
||||
def test_csv_file_type() -> None:
|
||||
file_id = "csv-123"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="data.csv",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = f"[data.csv](http://localhost/api/chat/file/{file_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == ChatFileType.CSV
|
||||
|
||||
|
||||
def test_unknown_extension_defaults_to_plain_text() -> None:
|
||||
file_id = "bin-456"
|
||||
files = [
|
||||
PythonExecutionFile(
|
||||
filename="output.xyz",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
tool_call = _make_tool_call_info(generated_files=files)
|
||||
message = f"[output.xyz](http://localhost/api/chat/file/{file_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == ChatFileType.PLAIN_TEXT
|
||||
|
||||
|
||||
def test_skips_tool_calls_without_generated_files() -> None:
|
||||
file_id = "img-789"
|
||||
tc_no_files = _make_tool_call_info(generated_files=None)
|
||||
tc_empty = _make_tool_call_info(generated_files=[])
|
||||
tc_with_files = _make_tool_call_info(
|
||||
generated_files=[
|
||||
PythonExecutionFile(
|
||||
filename="result.png",
|
||||
file_link=f"http://localhost/api/chat/file/{file_id}",
|
||||
)
|
||||
]
|
||||
)
|
||||
message = f"[result.png](http://localhost/api/chat/file/{file_id})"
|
||||
|
||||
result = _extract_referenced_file_descriptors(
|
||||
[tc_no_files, tc_empty, tc_with_files], message
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == file_id
|
||||
@@ -1,166 +0,0 @@
|
||||
"""Unit tests for stop_signal_checker and chat_processing_checker.
|
||||
|
||||
These modules are safety-critical — they control whether a chat stream
|
||||
continues or stops. The tests use a simple in-memory CacheBackend stub
|
||||
so no external services are needed.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.stop_signal_checker import FENCE_TTL
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
|
||||
|
||||
class _MemoryCacheBackend(CacheBackend):
|
||||
"""Minimal in-memory CacheBackend for unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, bytes] = {}
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None, # noqa: ARG002
|
||||
) -> None:
|
||||
if isinstance(value, bytes):
|
||||
self._store[key] = value
|
||||
else:
|
||||
self._store[key] = str(value).encode()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self._store
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
pass
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return -2 if key not in self._store else -1
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ── stop_signal_checker ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetFence:
|
||||
def test_set_fence_true_creates_key(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
assert not is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_false_removes_key(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
set_fence(sid, cache, False)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_false_noop_when_absent(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, False)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_uses_ttl(self) -> None:
|
||||
"""Verify set_fence passes ex=FENCE_TTL to cache.set."""
|
||||
calls: list[dict[str, object]] = []
|
||||
cache = _MemoryCacheBackend()
|
||||
original_set = cache.set
|
||||
|
||||
def tracking_set(
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
calls.append({"key": key, "ex": ex})
|
||||
original_set(key, value, ex=ex)
|
||||
|
||||
cache.set = tracking_set # type: ignore[method-assign]
|
||||
|
||||
set_fence(uuid4(), cache, True)
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["ex"] == FENCE_TTL
|
||||
|
||||
|
||||
class TestIsConnected:
|
||||
def test_connected_when_no_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
assert is_connected(uuid4(), cache)
|
||||
|
||||
def test_disconnected_when_fence_set(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
assert not is_connected(sid, cache)
|
||||
|
||||
def test_sessions_are_isolated(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid1, sid2 = uuid4(), uuid4()
|
||||
set_fence(sid1, cache, True)
|
||||
assert not is_connected(sid1, cache)
|
||||
assert is_connected(sid2, cache)
|
||||
|
||||
|
||||
class TestResetCancelStatus:
|
||||
def test_clears_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
reset_cancel_status(sid, cache)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_noop_when_no_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
reset_cancel_status(uuid4(), cache)
|
||||
|
||||
|
||||
# ── chat_processing_checker ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetProcessingStatus:
|
||||
def test_set_true_marks_processing(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_processing_status(sid, cache, True)
|
||||
assert is_chat_session_processing(sid, cache)
|
||||
|
||||
def test_set_false_clears_processing(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_processing_status(sid, cache, True)
|
||||
set_processing_status(sid, cache, False)
|
||||
assert not is_chat_session_processing(sid, cache)
|
||||
|
||||
|
||||
class TestIsChatSessionProcessing:
|
||||
def test_not_processing_by_default(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
assert not is_chat_session_processing(uuid4(), cache)
|
||||
|
||||
def test_sessions_are_isolated(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid1, sid2 = uuid4(), uuid4()
|
||||
set_processing_status(sid1, cache, True)
|
||||
assert is_chat_session_processing(sid1, cache)
|
||||
assert not is_chat_session_processing(sid2, cache)
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Tests for OnyxError and the global exception handler."""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
|
||||
|
||||
class TestOnyxError:
|
||||
"""Unit tests for OnyxError construction and properties."""
|
||||
|
||||
def test_basic_construction(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
assert err.error_code is OnyxErrorCode.NOT_FOUND
|
||||
assert err.message == "Session not found"
|
||||
assert err.status_code == 404
|
||||
|
||||
def test_message_defaults_to_code(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
assert err.message == "UNAUTHENTICATED"
|
||||
assert str(err) == "UNAUTHENTICATED"
|
||||
|
||||
def test_status_code_override(self) -> None:
|
||||
err = OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"upstream failed",
|
||||
status_code_override=503,
|
||||
)
|
||||
assert err.status_code == 503
|
||||
# error_code still reports its own default
|
||||
assert err.error_code.status_code == 502
|
||||
|
||||
def test_no_override_uses_error_code_status(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.RATE_LIMITED, "slow down")
|
||||
assert err.status_code == 429
|
||||
|
||||
def test_is_exception(self) -> None:
|
||||
err = OnyxError(OnyxErrorCode.INTERNAL_ERROR)
|
||||
assert isinstance(err, Exception)
|
||||
|
||||
|
||||
class TestExceptionHandler:
|
||||
"""Integration test: OnyxError → JSON response via FastAPI TestClient."""
|
||||
|
||||
@pytest.fixture()
|
||||
def client(self) -> TestClient:
|
||||
app = FastAPI()
|
||||
register_onyx_exception_handlers(app)
|
||||
|
||||
@app.get("/boom")
|
||||
def _boom() -> None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Thing not found")
|
||||
|
||||
@app.get("/boom-override")
|
||||
def _boom_override() -> None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"upstream 503",
|
||||
status_code_override=503,
|
||||
)
|
||||
|
||||
@app.get("/boom-default-msg")
|
||||
def _boom_default() -> None:
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def test_returns_correct_status_and_body(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom")
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "NOT_FOUND"
|
||||
assert body["message"] == "Thing not found"
|
||||
|
||||
def test_status_code_override_in_response(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-override")
|
||||
assert resp.status_code == 503
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "BAD_GATEWAY"
|
||||
assert body["message"] == "upstream 503"
|
||||
|
||||
def test_default_message(self, client: TestClient) -> None:
|
||||
resp = client.get("/boom-default-msg")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "UNAUTHENTICATED"
|
||||
assert body["message"] == "UNAUTHENTICATED"
|
||||
@@ -1,163 +0,0 @@
|
||||
"""Unit tests for federated OAuth state generation and verification.
|
||||
|
||||
Uses unittest.mock to patch get_cache_backend so no external services
|
||||
are needed. Verifies the generate -> verify round-trip, one-time-use
|
||||
semantics, TTL propagation, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.federated_connectors.oauth_utils import generate_oauth_state
|
||||
from onyx.federated_connectors.oauth_utils import OAUTH_STATE_TTL
|
||||
from onyx.federated_connectors.oauth_utils import OAuthSession
|
||||
from onyx.federated_connectors.oauth_utils import verify_oauth_state
|
||||
|
||||
|
||||
class _MemoryCacheBackend(CacheBackend):
|
||||
"""Minimal in-memory CacheBackend for unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, bytes] = {}
|
||||
self.set_calls: list[dict[str, object]] = []
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self.set_calls.append({"key": key, "ex": ex})
|
||||
if isinstance(value, bytes):
|
||||
self._store[key] = value
|
||||
else:
|
||||
self._store[key] = str(value).encode()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self._store
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
pass
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return -2 if key not in self._store else -1
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _patched(cache: _MemoryCacheBackend): # type: ignore[no-untyped-def]
|
||||
return patch(
|
||||
"onyx.federated_connectors.oauth_utils.get_cache_backend",
|
||||
return_value=cache,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateAndVerifyRoundTrip:
|
||||
def test_round_trip_basic(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(
|
||||
federated_connector_id=42,
|
||||
user_id="user-abc",
|
||||
)
|
||||
session = verify_oauth_state(state)
|
||||
|
||||
assert session.federated_connector_id == 42
|
||||
assert session.user_id == "user-abc"
|
||||
assert session.redirect_uri is None
|
||||
assert session.additional_data == {}
|
||||
|
||||
def test_round_trip_with_all_fields(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(
|
||||
federated_connector_id=7,
|
||||
user_id="user-xyz",
|
||||
redirect_uri="https://example.com/callback",
|
||||
additional_data={"scope": "read"},
|
||||
)
|
||||
session = verify_oauth_state(state)
|
||||
|
||||
assert session.federated_connector_id == 7
|
||||
assert session.user_id == "user-xyz"
|
||||
assert session.redirect_uri == "https://example.com/callback"
|
||||
assert session.additional_data == {"scope": "read"}
|
||||
|
||||
|
||||
class TestOneTimeUse:
|
||||
def test_verify_deletes_state(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
verify_oauth_state(state)
|
||||
|
||||
with pytest.raises(ValueError, match="OAuth state not found"):
|
||||
verify_oauth_state(state)
|
||||
|
||||
|
||||
class TestTTLPropagation:
|
||||
def test_default_ttl(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
|
||||
assert len(cache.set_calls) == 1
|
||||
assert cache.set_calls[0]["ex"] == OAUTH_STATE_TTL
|
||||
|
||||
def test_custom_ttl(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
generate_oauth_state(federated_connector_id=1, user_id="u", ttl=600)
|
||||
|
||||
assert cache.set_calls[0]["ex"] == 600
|
||||
|
||||
|
||||
class TestVerifyInvalidState:
|
||||
def test_missing_state_raises(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
# Manually clear the cache to simulate expiration
|
||||
cache._store.clear()
|
||||
|
||||
with pytest.raises(ValueError, match="OAuth state not found"):
|
||||
verify_oauth_state(state)
|
||||
|
||||
|
||||
class TestOAuthSessionSerialization:
|
||||
def test_to_dict_from_dict_round_trip(self) -> None:
|
||||
session = OAuthSession(
|
||||
federated_connector_id=5,
|
||||
user_id="u-123",
|
||||
redirect_uri="https://redir.example.com",
|
||||
additional_data={"key": "val"},
|
||||
)
|
||||
d = session.to_dict()
|
||||
restored = OAuthSession.from_dict(d)
|
||||
|
||||
assert restored.federated_connector_id == 5
|
||||
assert restored.user_id == "u-123"
|
||||
assert restored.redirect_uri == "https://redir.example.com"
|
||||
assert restored.additional_data == {"key": "val"}
|
||||
|
||||
def test_from_dict_defaults(self) -> None:
|
||||
minimal = {"federated_connector_id": 1, "user_id": "u"}
|
||||
session = OAuthSession.from_dict(minimal)
|
||||
assert session.redirect_uri is None
|
||||
assert session.additional_data == {}
|
||||
@@ -117,10 +117,7 @@ class TestOktaProvider:
|
||||
user = _make_mock_user(personal_name=None)
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
# Falls back to deriving name from email local part
|
||||
assert result.name == ScimName(
|
||||
givenName="test", familyName="", formatted="test"
|
||||
)
|
||||
assert result.name == ScimName(givenName="", familyName="", formatted="")
|
||||
assert result.displayName is None
|
||||
|
||||
def test_build_user_resource_scim_username_preserves_case(self) -> None:
|
||||
|
||||
@@ -215,7 +215,7 @@ class TestCreateUser:
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_missing_external_id_still_creates_mapping(
|
||||
def test_missing_external_id_creates_user_without_mapping(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
@@ -223,7 +223,6 @@ class TestCreateUser:
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Mapping is always created to mark user as SCIM-managed."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
@@ -237,11 +236,11 @@ class TestCreateUser:
|
||||
parsed = parse_scim_user(result, status=201)
|
||||
assert parsed.userName is not None
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
mock_dal.create_user_mapping.assert_not_called()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_duplicate_scim_managed_email_returns_409(
|
||||
def test_duplicate_email_returns_409(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
@@ -249,12 +248,7 @@ class TestCreateUser:
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""409 only when the existing user already has a SCIM mapping."""
|
||||
existing = make_db_user()
|
||||
mock_dal.get_user_by_email.return_value = existing
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = make_user_mapping(
|
||||
user_id=existing.id
|
||||
)
|
||||
mock_dal.get_user_by_email.return_value = make_db_user()
|
||||
resource = make_scim_user()
|
||||
|
||||
result = create_user(
|
||||
@@ -266,40 +260,6 @@ class TestCreateUser:
|
||||
|
||||
assert_scim_error(result, 409)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_existing_user_without_mapping_gets_linked(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Pre-existing user without SCIM mapping gets adopted (linked)."""
|
||||
existing = make_db_user(email="admin@example.com", personal_name=None)
|
||||
mock_dal.get_user_by_email.return_value = existing
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = None
|
||||
resource = make_scim_user(userName="admin@example.com", externalId="ext-admin")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_user(result, status=201)
|
||||
assert parsed.userName == "admin@example.com"
|
||||
# Should NOT create a new user — reuse existing
|
||||
mock_dal.add_user.assert_not_called()
|
||||
# Should sync is_active and personal_name from the SCIM request
|
||||
mock_dal.update_user.assert_called_once_with(
|
||||
existing, is_active=True, personal_name="Test User"
|
||||
)
|
||||
# Should create a SCIM mapping for the existing user
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_integrity_error_returns_409(
|
||||
self,
|
||||
|
||||
@@ -1,37 +1,25 @@
|
||||
"""Tests for PythonTool availability based on server_enabled flag and health check.
|
||||
"""Tests for PythonTool availability based on server_enabled flag.
|
||||
|
||||
Verifies that PythonTool reports itself as unavailable when either:
|
||||
- CODE_INTERPRETER_BASE_URL is not set, or
|
||||
- CodeInterpreterServer.server_enabled is False in the database, or
|
||||
- The Code Interpreter service health check fails.
|
||||
|
||||
Also verifies that the health check result is cached with a TTL.
|
||||
- CodeInterpreterServer.server_enabled is False in the database.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
|
||||
CLIENT_MODULE = "onyx.tools.tool_implementations.python.code_interpreter_client"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", None)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
None,
|
||||
)
|
||||
def test_python_tool_unavailable_without_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -39,7 +27,10 @@ def test_python_tool_unavailable_without_base_url() -> None:
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "")
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"",
|
||||
)
|
||||
def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -52,8 +43,13 @@ def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
@@ -68,15 +64,18 @@ def test_python_tool_unavailable_when_server_disabled(
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check determines availability when URL + server are OK
|
||||
# Available when both conditions are met
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_available_when_health_check_passes(
|
||||
mock_client_cls: MagicMock,
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
@@ -85,120 +84,5 @@ def test_python_tool_available_when_health_check_passes(
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = True
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_unavailable_when_health_check_fails(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = False
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check is NOT reached when preconditions fail
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_health_check_not_called_when_server_disabled(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = False
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client_cls.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check caching (tested at the client level)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_health_check_cached_on_second_call() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health(use_cache=True) is True
|
||||
assert client.health(use_cache=True) is True
|
||||
# Only one HTTP call — the second used the cache
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
@patch(f"{CLIENT_MODULE}.time")
|
||||
def test_health_check_refreshed_after_ttl_expires(mock_time: MagicMock) -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
_HEALTH_CACHE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
# First call at t=0 — cache miss
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Second call within TTL — cache hit
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS - 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Third call after TTL — cache miss, fresh request
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS + 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
def test_health_check_no_cache_by_default() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health() is True
|
||||
assert client.health() is True
|
||||
# Both calls hit the network when use_cache=False (default)
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
_normalize_queries_input,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
|
||||
|
||||
def _make_result(
|
||||
title: str = "Title", link: str = "https://example.com"
|
||||
) -> WebSearchResult:
|
||||
return WebSearchResult(title=title, link=link, snippet="snippet")
|
||||
|
||||
|
||||
def _make_tool(mock_provider: Any) -> WebSearchTool:
|
||||
"""Instantiate WebSearchTool with all DB/provider deps mocked out."""
|
||||
provider_model = MagicMock()
|
||||
provider_model.provider_type = "brave"
|
||||
provider_model.api_key = MagicMock()
|
||||
provider_model.api_key.get_value.return_value = "fake-key"
|
||||
provider_model.config = {}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.get_session_with_current_tenant"
|
||||
) as mock_session_ctx,
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.fetch_active_web_search_provider",
|
||||
return_value=provider_model,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.build_search_provider_from_config",
|
||||
return_value=mock_provider,
|
||||
),
|
||||
):
|
||||
mock_session_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
tool = WebSearchTool(tool_id=1, emitter=MagicMock())
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
def _run(tool: WebSearchTool, queries: Any) -> list[str]:
|
||||
"""Call tool.run() and return the list of query strings passed to provider.search."""
|
||||
placement = Placement(turn_index=0, tab_index=0)
|
||||
override_kwargs = WebSearchToolOverrideKwargs(starting_citation_num=1)
|
||||
tool.run(placement=placement, override_kwargs=override_kwargs, queries=queries)
|
||||
search_mock = cast(MagicMock, tool._provider.search) # noqa: SLF001
|
||||
return [call.args[0] for call in search_mock.call_args_list]
|
||||
|
||||
|
||||
class TestNormalizeQueriesInput:
|
||||
"""Unit tests for _normalize_queries_input (coercion + sanitization)."""
|
||||
|
||||
def test_bare_string_returns_single_element_list(self) -> None:
|
||||
assert _normalize_queries_input("hello") == ["hello"]
|
||||
|
||||
def test_bare_string_stripped_and_sanitized(self) -> None:
|
||||
assert _normalize_queries_input(" hello ") == ["hello"]
|
||||
# Control chars (e.g. null) removed; no space inserted
|
||||
assert _normalize_queries_input("hello\x00world") == ["helloworld"]
|
||||
|
||||
def test_empty_string_returns_empty_list(self) -> None:
|
||||
assert _normalize_queries_input("") == []
|
||||
assert _normalize_queries_input(" ") == []
|
||||
|
||||
def test_list_of_strings_returned_sanitized(self) -> None:
|
||||
assert _normalize_queries_input(["a", "b"]) == ["a", "b"]
|
||||
# Leading/trailing space stripped; control chars (e.g. tab) removed
|
||||
assert _normalize_queries_input([" a ", "b\tb"]) == ["a", "bb"]
|
||||
|
||||
def test_list_none_skipped(self) -> None:
|
||||
assert _normalize_queries_input(["a", None, "b"]) == ["a", "b"]
|
||||
|
||||
def test_list_non_string_coerced(self) -> None:
|
||||
assert _normalize_queries_input([1, "two"]) == ["1", "two"]
|
||||
|
||||
def test_list_whitespace_only_dropped(self) -> None:
|
||||
assert _normalize_queries_input(["a", "", " ", "b"]) == ["a", "b"]
|
||||
|
||||
def test_non_list_non_string_returns_empty_list(self) -> None:
|
||||
assert _normalize_queries_input(42) == []
|
||||
assert _normalize_queries_input({}) == []
|
||||
|
||||
|
||||
class TestWebSearchToolRunQueryCoercion:
|
||||
def test_list_of_strings_dispatches_each_query(self) -> None:
|
||||
"""Normal case: list of queries → one search call per query."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, ["python decorators", "python generators"])
|
||||
|
||||
# run_functions_tuples_in_parallel uses a thread pool; call_args_list order is non-deterministic.
|
||||
assert sorted(dispatched) == ["python decorators", "python generators"]
|
||||
|
||||
def test_bare_string_dispatches_as_single_query(self) -> None:
|
||||
"""LLM returns a bare string instead of an array — must NOT be split char-by-char."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, "what is the capital of France")
|
||||
|
||||
assert len(dispatched) == 1
|
||||
assert dispatched[0] == "what is the capital of France"
|
||||
|
||||
def test_bare_string_does_not_search_individual_characters(self) -> None:
|
||||
"""Regression: single-char searches must not occur."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, "hi")
|
||||
for query_arg in dispatched:
|
||||
assert (
|
||||
len(query_arg) > 1
|
||||
), f"Single-character query dispatched: {query_arg!r}"
|
||||
|
||||
def test_control_characters_sanitized_before_dispatch(self) -> None:
|
||||
"""Queries with control chars have those chars removed before dispatch."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.search.return_value = [_make_result()]
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
|
||||
dispatched = _run(tool, ["foo\x00bar", "baz\tbaz"])
|
||||
|
||||
# run_functions_tuples_in_parallel uses a thread pool; call_args_list is in
|
||||
# execution order, not submission order, so compare in sorted order.
|
||||
assert sorted(dispatched) == ["bazbaz", "foobar"]
|
||||
|
||||
def test_all_empty_or_whitespace_raises_tool_call_exception(self) -> None:
|
||||
"""When normalization yields no valid queries, run() raises ToolCallException."""
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.supports_site_filter = False
|
||||
tool = _make_tool(mock_provider)
|
||||
placement = Placement(turn_index=0, tab_index=0)
|
||||
override_kwargs = WebSearchToolOverrideKwargs(starting_citation_num=1)
|
||||
|
||||
with pytest.raises(ToolCallException) as exc_info:
|
||||
tool.run(
|
||||
placement=placement,
|
||||
override_kwargs=override_kwargs,
|
||||
queries=" ",
|
||||
)
|
||||
|
||||
assert "No valid" in str(exc_info.value)
|
||||
cast(MagicMock, mock_provider.search).assert_not_called()
|
||||
@@ -126,9 +126,7 @@ Resources:
|
||||
- Effect: Allow
|
||||
Action:
|
||||
- secretsmanager:GetSecretValue
|
||||
Resource:
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
|
||||
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
|
||||
Outputs:
|
||||
OutputEcsCluster:
|
||||
|
||||
@@ -167,12 +167,10 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: basic
|
||||
Value: disabled
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
|
||||
@@ -166,11 +166,9 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: basic
|
||||
Value: disabled
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
@@ -1,32 +1,30 @@
|
||||
# =============================================================================
|
||||
# ONYX LITE — MINIMAL DEPLOYMENT OVERLAY
|
||||
# ONYX NO-VECTOR-DB OVERLAY
|
||||
# =============================================================================
|
||||
# Overlay to run Onyx in a minimal configuration: no vector database (Vespa),
|
||||
# no Redis, no model servers, and no background workers. Only PostgreSQL is
|
||||
# required. In this mode, connectors and RAG search are disabled, but the core
|
||||
# chat experience (LLM conversations, tools, user file uploads, Projects,
|
||||
# Agent knowledge, code interpreter) still works.
|
||||
# Overlay to run Onyx without a vector database (Vespa), model servers, or
|
||||
# code interpreter. In this mode, connectors and RAG search are disabled, but
|
||||
# the core chat experience (LLM conversations, tools, user file uploads,
|
||||
# Projects, Agent knowledge) still works.
|
||||
#
|
||||
# Usage:
|
||||
# docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml up -d
|
||||
# docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml up -d
|
||||
#
|
||||
# With dev ports:
|
||||
# docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml \
|
||||
# docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml \
|
||||
# -f docker-compose.dev.yml up -d --wait
|
||||
#
|
||||
# This overlay:
|
||||
# - Moves Vespa (index), both model servers, code-interpreter, Redis (cache),
|
||||
# and the background worker to profiles so they do not start by default
|
||||
# - Makes depends_on references to removed services optional
|
||||
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
|
||||
# so they do not start by default
|
||||
# - Moves the background worker to the "background" profile (the API server
|
||||
# handles all background work via FastAPI BackgroundTasks)
|
||||
# - Makes the depends_on references to removed services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on the api_server
|
||||
# - Uses PostgreSQL for caching and auth instead of Redis
|
||||
# - Uses PostgreSQL for file storage instead of S3/MinIO
|
||||
#
|
||||
# To selectively bring services back:
|
||||
# --profile vectordb Vespa + indexing model server
|
||||
# --profile inference Inference model server
|
||||
# --profile background Background worker (Celery) — also needs redis
|
||||
# --profile redis Redis cache
|
||||
# --profile background Background worker (Celery)
|
||||
# --profile code-interpreter Code interpreter
|
||||
# =============================================================================
|
||||
|
||||
@@ -38,9 +36,6 @@ services:
|
||||
index:
|
||||
condition: service_started
|
||||
required: false
|
||||
cache:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
@@ -50,11 +45,9 @@ services:
|
||||
environment:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
- CACHE_BACKEND=postgres
|
||||
- AUTH_BACKEND=postgres
|
||||
|
||||
# Move the background worker to a profile so it does not start by default.
|
||||
# The API server handles all background work in lite mode.
|
||||
# The API server handles all background work in NO_VECTOR_DB mode.
|
||||
background:
|
||||
profiles: ["background"]
|
||||
depends_on:
|
||||
@@ -68,11 +61,6 @@ services:
|
||||
condition: service_started
|
||||
required: false
|
||||
|
||||
# Move Redis to a profile so it does not start by default.
|
||||
# The Postgres cache backend replaces Redis in lite mode.
|
||||
cache:
|
||||
profiles: ["redis"]
|
||||
|
||||
# Move Vespa and indexing model server to a profile so they do not start.
|
||||
index:
|
||||
profiles: ["vectordb"]
|
||||
@@ -1,31 +0,0 @@
|
||||
# =============================================================================
|
||||
# ONYX LITE — MINIMAL DEPLOYMENT VALUES
|
||||
# =============================================================================
|
||||
# Minimal Onyx deployment: no vector database, no Redis, no model servers.
|
||||
# Only PostgreSQL is required. Connectors and RAG search are disabled, but the
|
||||
# core chat experience (LLM conversations, tools, user file uploads, Projects,
|
||||
# Agent knowledge) still works.
|
||||
#
|
||||
# Usage:
|
||||
# helm install onyx ./deployment/helm/charts/onyx \
|
||||
# -f ./deployment/helm/charts/onyx/values-lite.yaml
|
||||
#
|
||||
# Or merged with your own overrides:
|
||||
# helm install onyx ./deployment/helm/charts/onyx \
|
||||
# -f ./deployment/helm/charts/onyx/values-lite.yaml \
|
||||
# -f my-overrides.yaml
|
||||
# =============================================================================
|
||||
|
||||
vectorDB:
|
||||
enabled: false
|
||||
|
||||
vespa:
|
||||
enabled: false
|
||||
|
||||
redis:
|
||||
enabled: false
|
||||
|
||||
configMap:
|
||||
CACHE_BACKEND: "postgres"
|
||||
AUTH_BACKEND: "postgres"
|
||||
FILE_STORE_BACKEND: "postgres"
|
||||
@@ -1,93 +0,0 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 2,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "greptile.json\n",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"customContext": {
|
||||
"other": [
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of TODO(name): ... or TODO(1234): ..."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/STANDARDS.md file."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Remove temporary debugging code before merging to production, especially tenant-specific debugging logs."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,7 +143,6 @@ module.exports = {
|
||||
"**/src/app/**/utils/*.test.ts",
|
||||
"**/src/app/**/hooks/*.test.ts", // Pure packet processor tests
|
||||
"**/src/refresh-components/**/*.test.ts",
|
||||
"**/src/sections/**/*.test.ts",
|
||||
// Add more patterns here as you add more unit tests
|
||||
],
|
||||
},
|
||||
@@ -157,8 +156,6 @@ module.exports = {
|
||||
"**/src/components/**/*.test.tsx",
|
||||
"**/src/lib/**/*.test.tsx",
|
||||
"**/src/refresh-components/**/*.test.tsx",
|
||||
"**/src/hooks/**/*.test.tsx",
|
||||
"**/src/sections/**/*.test.tsx",
|
||||
// Add more patterns here as you add more integration tests
|
||||
],
|
||||
},
|
||||
|
||||
@@ -18,10 +18,6 @@
|
||||
"types": "./src/icons/index.ts",
|
||||
"default": "./src/icons/index.ts"
|
||||
},
|
||||
"./illustrations": {
|
||||
"types": "./src/illustrations/index.ts",
|
||||
"default": "./src/illustrations/index.ts"
|
||||
},
|
||||
"./types": {
|
||||
"types": "./src/types.ts",
|
||||
"default": "./src/types.ts"
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
# SVG-to-TSX Conversion Scripts
|
||||
|
||||
## Overview
|
||||
|
||||
Integrating `@svgr/webpack` into the TypeScript compiler was not working via the recommended route (Next.js webpack configuration).
|
||||
The automatic SVG-to-React component conversion was causing compilation issues and import resolution problems.
|
||||
Therefore, we manually convert each SVG into a TSX file using SVGR CLI with a custom template.
|
||||
|
||||
All scripts in this directory should be run from the **opal package root** (`web/lib/opal/`).
|
||||
|
||||
## Directory Layout
|
||||
|
||||
```
|
||||
web/lib/opal/
|
||||
├── scripts/ # SVG conversion tooling (this directory)
|
||||
│ ├── convert-svg.sh # Converts SVGs into React components
|
||||
│ └── icon-template.js # Shared SVGR template (used for both icons and illustrations)
|
||||
├── src/
|
||||
│ ├── icons/ # Small, single-colour icons (stroke = currentColor)
|
||||
│ └── illustrations/ # Larger, multi-colour illustrations (colours preserved)
|
||||
└── package.json
|
||||
```
|
||||
|
||||
## Icons vs Illustrations
|
||||
|
||||
| | Icons | Illustrations |
|
||||
|---|---|---|
|
||||
| **Import path** | `@opal/icons` | `@opal/illustrations` |
|
||||
| **Location** | `src/icons/` | `src/illustrations/` |
|
||||
| **Colour** | Overridable via `currentColor` | Fixed — original SVG colours preserved |
|
||||
| **Script flag** | (none) | `--illustration` |
|
||||
|
||||
## Files in This Directory
|
||||
|
||||
### `icon-template.js`
|
||||
|
||||
A custom SVGR template that generates components with the following features:
|
||||
- Imports `IconProps` from `@opal/types` for consistent typing
|
||||
- Supports the `size` prop for controlling icon dimensions
|
||||
- Includes `width` and `height` attributes bound to the `size` prop
|
||||
- Maintains all standard SVG props (className, color, title, etc.)
|
||||
|
||||
### `convert-svg.sh`
|
||||
|
||||
Converts an SVG into a React component. Behaviour depends on the mode:
|
||||
|
||||
**Icon mode** (default):
|
||||
- Strips `stroke`, `stroke-opacity`, `width`, and `height` attributes
|
||||
- Adds `width={size}`, `height={size}`, and `stroke="currentColor"`
|
||||
- Result is colour-overridable via CSS `color` property
|
||||
|
||||
**Illustration mode** (`--illustration`):
|
||||
- Strips only `width` and `height` attributes (all colours preserved)
|
||||
- Adds `width={size}` and `height={size}`
|
||||
- Does **not** add `stroke="currentColor"` — illustrations keep their original colours
|
||||
|
||||
Both modes automatically delete the source SVG file after successful conversion.
|
||||
|
||||
## Adding New SVGs
|
||||
|
||||
### Icons
|
||||
|
||||
```sh
|
||||
# From web/lib/opal/
|
||||
./scripts/convert-svg.sh src/icons/my-icon.svg
|
||||
```
|
||||
|
||||
Then add the export to `src/icons/index.ts`:
|
||||
```ts
|
||||
export { default as SvgMyIcon } from "@opal/icons/my-icon";
|
||||
```
|
||||
|
||||
### Illustrations
|
||||
|
||||
```sh
|
||||
# From web/lib/opal/
|
||||
./scripts/convert-svg.sh --illustration src/illustrations/my-illustration.svg
|
||||
```
|
||||
|
||||
Then add the export to `src/illustrations/index.ts`:
|
||||
```ts
|
||||
export { default as SvgMyIllustration } from "@opal/illustrations/my-illustration";
|
||||
```
|
||||
|
||||
## Manual Conversion
|
||||
|
||||
If you prefer to run the SVGR command directly:
|
||||
|
||||
**For icons** (strips colours):
|
||||
```sh
|
||||
bunx @svgr/cli <file>.svg --typescript --svgo-config '{"plugins":[{"name":"removeAttrs","params":{"attrs":["stroke","stroke-opacity","width","height"]}}]}' --template scripts/icon-template.js > <file>.tsx
|
||||
```
|
||||
|
||||
**For illustrations** (preserves colours):
|
||||
```sh
|
||||
bunx @svgr/cli <file>.svg --typescript --svgo-config '{"plugins":[{"name":"removeAttrs","params":{"attrs":["width","height"]}}]}' --template scripts/icon-template.js > <file>.tsx
|
||||
```
|
||||
|
||||
After running either manual command, remember to delete the original SVG file.
|
||||
@@ -1,123 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Convert an SVG file to a TypeScript React component.
|
||||
#
|
||||
# By default, converts to a colour-overridable icon (stroke colours stripped, replaced with currentColor).
|
||||
# With --illustration, converts to a fixed-colour illustration (all original colours preserved).
|
||||
#
|
||||
# Usage (from the opal package root — web/lib/opal/):
|
||||
# ./scripts/convert-svg.sh src/icons/<filename.svg>
|
||||
# ./scripts/convert-svg.sh --illustration src/illustrations/<filename.svg>
|
||||
|
||||
ILLUSTRATION=false
|
||||
|
||||
# Parse flags
|
||||
while [[ "$1" == --* ]]; do
|
||||
case "$1" in
|
||||
--illustration)
|
||||
ILLUSTRATION=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown flag: $1" >&2
|
||||
echo "Usage: ./scripts/convert-svg.sh [--illustration] <filename.svg>" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: ./scripts/convert-svg.sh [--illustration] <filename.svg>" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SVG_FILE="$1"
|
||||
|
||||
# Check if file exists
|
||||
if [ ! -f "$SVG_FILE" ]; then
|
||||
echo "Error: File '$SVG_FILE' not found" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if it's an SVG file
|
||||
if [[ ! "$SVG_FILE" == *.svg ]]; then
|
||||
echo "Error: File must have .svg extension" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the base name without extension
|
||||
BASE_NAME="${SVG_FILE%.svg}"
|
||||
|
||||
# Build the SVGO config based on mode
|
||||
if [ "$ILLUSTRATION" = true ]; then
|
||||
# Illustrations: only strip width and height (preserve all colours)
|
||||
SVGO_CONFIG='{"plugins":[{"name":"removeAttrs","params":{"attrs":["width","height"]}}]}'
|
||||
else
|
||||
# Icons: strip stroke, stroke-opacity, width, and height
|
||||
SVGO_CONFIG='{"plugins":[{"name":"removeAttrs","params":{"attrs":["stroke","stroke-opacity","width","height"]}}]}'
|
||||
fi
|
||||
|
||||
# Resolve the template path relative to this script (not the caller's CWD)
|
||||
SCRIPT_DIR="$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
# Run the conversion into a temp file so a failed run doesn't destroy an existing .tsx
|
||||
TMPFILE="${BASE_NAME}.tsx.tmp"
|
||||
bunx @svgr/cli "$SVG_FILE" --typescript --svgo-config "$SVGO_CONFIG" --template "${SCRIPT_DIR}/icon-template.js" > "$TMPFILE"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
# Verify the temp file has content before replacing the destination
|
||||
if [ ! -s "$TMPFILE" ]; then
|
||||
rm -f "$TMPFILE"
|
||||
echo "Error: Output file was not created or is empty" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mv "$TMPFILE" "${BASE_NAME}.tsx" || { echo "Error: Failed to move temp file" >&2; exit 1; }
|
||||
|
||||
# Post-process the file to add width and height attributes bound to the size prop
|
||||
# Using perl for cross-platform compatibility (works on macOS, Linux, Windows with WSL)
|
||||
# Note: perl -i returns 0 even on some failures, so we validate the output
|
||||
|
||||
perl -i -pe 's/<svg/<svg width={size} height={size}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to add width/height attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Icons additionally get stroke="currentColor"
|
||||
if [ "$ILLUSTRATION" = false ]; then
|
||||
perl -i -pe 's/\{\.\.\.props\}/stroke="currentColor" {...props}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to add stroke attribute" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Verify the file still exists and has content after post-processing
|
||||
if [ ! -s "${BASE_NAME}.tsx" ]; then
|
||||
echo "Error: Output file corrupted during post-processing" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify required attributes are present in the output
|
||||
if ! grep -q 'width={size}' "${BASE_NAME}.tsx" || ! grep -q 'height={size}' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Post-processing did not add required attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# For icons, also verify stroke="currentColor" was added
|
||||
if [ "$ILLUSTRATION" = false ]; then
|
||||
if ! grep -q 'stroke="currentColor"' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Post-processing did not add stroke=\"currentColor\"" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Created ${BASE_NAME}.tsx"
|
||||
rm "$SVG_FILE"
|
||||
echo "Deleted $SVG_FILE"
|
||||
else
|
||||
rm -f "$TMPFILE"
|
||||
echo "Error: Conversion failed" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -11,8 +11,6 @@ export {
|
||||
Interactive,
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveBaseSidebarVariantProps,
|
||||
type InteractiveBaseSidebarProminenceTypes,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
} from "@opal/core/interactive/components";
|
||||
|
||||
@@ -104,44 +104,6 @@ The foundational layer for all clickable surfaces in the design system. Defines
|
||||
| **Active** | `action-link-05` | `action-link-05` |
|
||||
| **Disabled** | `action-link-03` | `action-link-03` |
|
||||
|
||||
### Sidebar (unselected)
|
||||
|
||||
> No CSS `:active` state — only hover/transient and selected.
|
||||
|
||||
**Background**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **Rest** | `transparent` |
|
||||
| **Hover / Transient** | `background-tint-03` |
|
||||
| **Disabled** | `transparent` |
|
||||
|
||||
**Foreground**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **Rest** | `text-03` |
|
||||
| **Hover / Transient** | `text-04` |
|
||||
| **Disabled** | `text-01` |
|
||||
|
||||
### Sidebar (selected)
|
||||
|
||||
> Completely static — hover and transient have no effect.
|
||||
|
||||
**Background**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **All states** | `background-tint-00` |
|
||||
| **Disabled** | `transparent` |
|
||||
|
||||
**Foreground**
|
||||
|
||||
| | Light |
|
||||
|---|---|
|
||||
| **All states** | `text-03` (icon: `text-02`) |
|
||||
| **Disabled** | `text-01` |
|
||||
|
||||
## Sub-components
|
||||
|
||||
| Sub-component | Role |
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import Link from "next/link";
|
||||
import type { Route } from "next";
|
||||
import "@opal/core/interactive/styles.css";
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
@@ -28,28 +26,18 @@ type InteractiveBaseSelectVariantProps = {
|
||||
selected?: boolean;
|
||||
};
|
||||
|
||||
type InteractiveBaseSidebarProminenceTypes = "light";
|
||||
type InteractiveBaseSidebarVariantProps = {
|
||||
variant: "sidebar";
|
||||
prominence?: InteractiveBaseSidebarProminenceTypes;
|
||||
selected?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Discriminated union tying `variant` to `prominence`.
|
||||
*
|
||||
* - `"none"` accepts no prominence (`prominence` must not be provided)
|
||||
* - `"select"` accepts an optional prominence (defaults to `"light"`) and
|
||||
* an optional `selected` boolean that switches foreground to action-link colours
|
||||
* - `"sidebar"` accepts an optional prominence (defaults to `"light"`) and
|
||||
* an optional `selected` boolean for the focused/active-item state
|
||||
* - `"default"`, `"action"`, and `"danger"` accept an optional prominence
|
||||
* (defaults to `"primary"`)
|
||||
*/
|
||||
type InteractiveBaseVariantProps =
|
||||
| { variant?: "none"; prominence?: never; selected?: never }
|
||||
| InteractiveBaseSelectVariantProps
|
||||
| InteractiveBaseSidebarVariantProps
|
||||
| {
|
||||
variant?: InteractiveBaseVariantTypes;
|
||||
prominence?: InteractiveBaseProminenceTypes;
|
||||
@@ -230,8 +218,7 @@ function InteractiveBase({
|
||||
...props
|
||||
}: InteractiveBaseProps) {
|
||||
const effectiveProminence =
|
||||
prominence ??
|
||||
(variant === "select" || variant === "sidebar" ? "light" : "primary");
|
||||
prominence ?? (variant === "select" ? "light" : "primary");
|
||||
const classes = cn(
|
||||
"interactive",
|
||||
!props.onClick && !href && "!cursor-default !select-auto",
|
||||
@@ -430,9 +417,9 @@ function InteractiveContainer({
|
||||
// so all styling (backgrounds, rounding, overflow) lives on one element.
|
||||
if (href) {
|
||||
return (
|
||||
<Link
|
||||
<a
|
||||
ref={ref as React.Ref<HTMLAnchorElement>}
|
||||
href={href as Route}
|
||||
href={href}
|
||||
target={target}
|
||||
rel={rel}
|
||||
{...(sharedProps as React.HTMLAttributes<HTMLAnchorElement>)}
|
||||
@@ -495,8 +482,6 @@ export {
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveBaseSelectVariantProps,
|
||||
type InteractiveBaseSidebarVariantProps,
|
||||
type InteractiveBaseSidebarProminenceTypes,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
};
|
||||
|
||||
@@ -419,23 +419,3 @@
|
||||
) {
|
||||
@apply bg-background-tint-00;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Sidebar + Light
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"] {
|
||||
@apply bg-transparent;
|
||||
}
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"]:hover:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"][data-transient="true"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-tint-03;
|
||||
}
|
||||
.interactive[data-interactive-base-variant="sidebar"][data-interactive-base-prominence="light"][data-selected="true"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-tint-00;
|
||||
}
|
||||
|
||||
58
web/lib/opal/src/icons/README.md
Normal file
58
web/lib/opal/src/icons/README.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Compilation of SVGs into TypeScript React Components
|
||||
|
||||
## Overview
|
||||
|
||||
Integrating `@svgr/webpack` into the TypeScript compiler was not working via the recommended route (Next.js webpack configuration).
|
||||
The automatic SVG-to-React component conversion was causing compilation issues and import resolution problems.
|
||||
Therefore, we manually convert each SVG into a TSX file using SVGR CLI with a custom template.
|
||||
|
||||
## Files in This Directory
|
||||
|
||||
### `scripts/icon-template.js`
|
||||
|
||||
A custom SVGR template that generates icon components with the following features:
|
||||
- Imports `IconProps` from `@opal/types` for consistent typing
|
||||
- Supports the `size` prop for controlling icon dimensions
|
||||
- Includes `width` and `height` attributes bound to the `size` prop
|
||||
- Maintains all standard SVG props (className, color, title, etc.)
|
||||
|
||||
This ensures all generated icons have a consistent API and type definitions.
|
||||
|
||||
### `scripts/convert-svg.sh`
|
||||
|
||||
A convenience script that automates the SVG-to-TSX conversion process. It:
|
||||
- Validates the input file
|
||||
- Runs SVGR with the correct configuration and template
|
||||
- Post-processes the output to add `width`, `height`, and `stroke` attributes using perl (cross-platform compatible)
|
||||
- Automatically deletes the source SVG file after successful conversion
|
||||
- Provides error handling and user feedback
|
||||
|
||||
**Usage:**
|
||||
```sh
|
||||
./scripts/convert-svg.sh <filename.svg>
|
||||
```
|
||||
|
||||
## Adding New SVGs
|
||||
|
||||
**Recommended Method:**
|
||||
|
||||
Use the conversion script for the easiest experience:
|
||||
|
||||
```sh
|
||||
./scripts/convert-svg.sh my-icon.svg
|
||||
```
|
||||
|
||||
**Manual Method:**
|
||||
|
||||
If you prefer to run the command directly:
|
||||
|
||||
```sh
|
||||
bunx @svgr/cli ${SVG_FILE_NAME}.svg --typescript --svgo-config '{"plugins":[{"name":"removeAttrs","params":{"attrs":["stroke","stroke-opacity","width","height"]}}]}' --template scripts/icon-template.js > ${SVG_FILE_NAME}.tsx
|
||||
```
|
||||
|
||||
This command:
|
||||
- Converts SVG files to TypeScript React components (`--typescript`)
|
||||
- Removes `stroke`, `stroke-opacity`, `width`, and `height` attributes from SVG elements (`--svgo-config` with `removeAttrs` plugin)
|
||||
- Uses the custom template (`icon-template.js`) to generate components with `IconProps` and `size` prop support
|
||||
|
||||
After running the manual command, remember to delete the original SVG file.
|
||||
72
web/lib/opal/src/icons/scripts/convert-svg.sh
Executable file
72
web/lib/opal/src/icons/scripts/convert-svg.sh
Executable file
@@ -0,0 +1,72 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Convert an SVG file to a TypeScript React component
|
||||
# Usage: ./convert-svg.sh <filename.svg>
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: ./convert-svg.sh <filename.svg>" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SVG_FILE="$1"
|
||||
|
||||
# Check if file exists
|
||||
if [ ! -f "$SVG_FILE" ]; then
|
||||
echo "Error: File '$SVG_FILE' not found" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if it's an SVG file
|
||||
if [[ ! "$SVG_FILE" == *.svg ]]; then
|
||||
echo "Error: File must have .svg extension" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the base name without extension
|
||||
BASE_NAME="${SVG_FILE%.svg}"
|
||||
|
||||
# Run the conversion with relative path to template
|
||||
bunx @svgr/cli "$SVG_FILE" --typescript --svgo-config '{"plugins":[{"name":"removeAttrs","params":{"attrs":["stroke","stroke-opacity","width","height"]}}]}' --template "scripts/icon-template.js" > "${BASE_NAME}.tsx"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
# Verify the output file was created and has content
|
||||
if [ ! -s "${BASE_NAME}.tsx" ]; then
|
||||
echo "Error: Output file was not created or is empty" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Post-process the file to add width, height, and stroke attributes
|
||||
# Using perl for cross-platform compatibility (works on macOS, Linux, Windows with WSL)
|
||||
# Note: perl -i returns 0 even on some failures, so we validate the output
|
||||
|
||||
perl -i -pe 's/<svg/<svg width={size} height={size}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to add width/height attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
perl -i -pe 's/\{\.\.\.props\}/stroke="currentColor" {...props}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to add stroke attribute" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify the file still exists and has content after post-processing
|
||||
if [ ! -s "${BASE_NAME}.tsx" ]; then
|
||||
echo "Error: Output file corrupted during post-processing" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify required attributes are present in the output
|
||||
if ! grep -q 'width={size}' "${BASE_NAME}.tsx" || ! grep -q 'stroke="currentColor"' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Post-processing did not add required attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Created ${BASE_NAME}.tsx"
|
||||
rm "$SVG_FILE"
|
||||
echo "Deleted $SVG_FILE"
|
||||
else
|
||||
echo "Error: Conversion failed" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgBrokenKey = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 120 120"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M54.375 43.125H43.125M69.375 28.125V16.875M58.125 31.875L48.75 22.5"
|
||||
stroke="#EC5B13"
|
||||
strokeWidth={3.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M108.75 18.75L98.5535 24.6369M98.5535 24.6369L104.044 34.1465L91.7404 41.25L86.25 31.7404M98.5535 24.6369L86.25 31.7404M86.25 31.7404L78.7499 36.0705M49.6599 62.8401C45.5882 58.7684 39.9632 56.25 33.75 56.25C21.3236 56.25 11.25 66.3236 11.25 78.75C11.25 91.1764 21.3236 101.25 33.75 101.25C46.1764 101.25 56.25 91.1764 56.25 78.75C56.25 72.5368 53.7316 66.9118 49.6599 62.8401ZM49.6599 62.8401L49.6406 62.8594M49.6599 62.8401L60 52.5"
|
||||
stroke="#A4A4A4"
|
||||
strokeWidth={3.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgBrokenKey;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user