mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 22:55:46 +00:00
Compare commits
35 Commits
v3.0.0-clo
...
action_too
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ab14dcb06 | ||
|
|
ca8729203a | ||
|
|
f3ffd6b9d4 | ||
|
|
add42b4109 | ||
|
|
d8532e7b6c | ||
|
|
3e67ea9df7 | ||
|
|
98e3602dd6 | ||
|
|
4fded5b0a1 | ||
|
|
328c305d26 | ||
|
|
f902727215 | ||
|
|
69c8aa08b3 | ||
|
|
c98aa486e4 | ||
|
|
03553114c5 | ||
|
|
6532c94230 | ||
|
|
1b32a7d94e | ||
|
|
5fd0fe192b | ||
|
|
1de522f9ae | ||
|
|
60fe3e9ad6 | ||
|
|
6aa56821d6 | ||
|
|
eda436de01 | ||
|
|
07915a6c01 | ||
|
|
2c3e9aecd1 | ||
|
|
fa29cc3849 | ||
|
|
24ac8b37d3 | ||
|
|
be8b108ae4 | ||
|
|
f380a75df3 | ||
|
|
21ec93663b | ||
|
|
d789c74024 | ||
|
|
fe014776f7 | ||
|
|
700ca0e0fc | ||
|
|
a84f8238ec | ||
|
|
4fc802e19d | ||
|
|
6cfd49439a | ||
|
|
71a1faa47e | ||
|
|
1a65217baf |
@@ -15,6 +15,7 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
secrets: inherit
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add cache_store table
|
||||
|
||||
Revision ID: 2664261bfaab
|
||||
Revises: 4a1e4b1c89d2
|
||||
Create Date: 2026-02-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2664261bfaab"
|
||||
down_revision = "4a1e4b1c89d2"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"cache_store",
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("key"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_cache_store_expires",
|
||||
"cache_store",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("expires_at IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cache_store_expires", table_name="cache_store")
|
||||
op.drop_table("cache_store")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""make scim_user_mapping.external_id nullable
|
||||
|
||||
Revision ID: a3b8d9e2f1c4
|
||||
Revises: 2664261bfaab
|
||||
Create Date: 2026-03-02
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3b8d9e2f1c4"
|
||||
down_revision = "2664261bfaab"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete any rows where external_id is NULL before re-applying NOT NULL
|
||||
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
|
||||
op.alter_column(
|
||||
"scim_user_mapping",
|
||||
"external_id",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -126,12 +126,16 @@ class ScimDAL(DAL):
|
||||
|
||||
def create_user_mapping(
|
||||
self,
|
||||
external_id: str,
|
||||
external_id: str | None,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
"""Create a SCIM mapping for a user.
|
||||
|
||||
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
|
||||
allows this). The mapping still marks the user as SCIM-managed.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
@@ -270,8 +274,13 @@ class ScimDAL(DAL):
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
query = select(User).where(
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
|
||||
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
|
||||
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
|
||||
# unless they were explicitly linked via SCIM provisioning.
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -321,34 +330,37 @@ class ScimDAL(DAL):
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
"""Sync the SCIM mapping for a user.
|
||||
|
||||
If a mapping already exists, its fields are updated (including
|
||||
setting ``external_id`` to ``None`` when the IdP omits it).
|
||||
If no mapping exists and ``new_external_id`` is provided, a new
|
||||
mapping is created. A mapping is never deleted here — SCIM-managed
|
||||
users must retain their mapping to remain visible in ``GET /Users``.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
elif new_external_id:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
|
||||
@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import register_scim_exception_handlers
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
@@ -167,6 +168,7 @@ def get_application() -> FastAPI:
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
register_scim_exception_handlers(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
@@ -15,7 +15,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@@ -24,6 +26,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
@@ -77,6 +80,22 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def register_scim_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register SCIM-specific exception handlers on the FastAPI app.
|
||||
|
||||
Call this after ``app.include_router(scim_router)`` so that auth
|
||||
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
|
||||
envelopes (with ``schemas`` and ``status`` fields) instead of
|
||||
FastAPI's default ``{"detail": "..."}`` format.
|
||||
"""
|
||||
|
||||
@app.exception_handler(ScimAuthError)
|
||||
async def _handle_scim_auth_error(
|
||||
_request: Request, exc: ScimAuthError
|
||||
) -> ScimJSONResponse:
|
||||
return _scim_error_response(exc.status_code, exc.detail)
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
@@ -404,21 +423,63 @@ def create_user(
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
# Check for existing user — if they exist but aren't SCIM-managed yet,
|
||||
# link them to the IdP rather than rejecting with 409.
|
||||
external_id: str | None = user_resource.externalId
|
||||
scim_username: str = user_resource.userName.strip()
|
||||
fields: ScimMappingFields = _fields_from_resource(user_resource)
|
||||
|
||||
# Enforce seat limit
|
||||
existing_user = dal.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
|
||||
if existing_mapping:
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Adopt pre-existing user into SCIM management.
|
||||
# Reactivating a deactivated user consumes a seat, so enforce the
|
||||
# seat limit the same way replace_user does.
|
||||
if user_resource.active and not existing_user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
dal.update_user(
|
||||
existing_user,
|
||||
is_active=user_resource.active,
|
||||
**({"personal_name": personal_name} if personal_name else {}),
|
||||
)
|
||||
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=existing_user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
existing_user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
# Only enforce seat limit for net-new users — adopting a pre-existing
|
||||
# user doesn't consume a new seat.
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
@@ -436,18 +497,21 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
# Always create a SCIM mapping so that the user is marked as
|
||||
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
|
||||
@@ -19,7 +19,6 @@ import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +27,21 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
|
||||
class ScimAuthError(Exception):
|
||||
"""Raised when SCIM bearer token authentication fails.
|
||||
|
||||
Unlike HTTPException, this carries the status and detail so the SCIM
|
||||
exception handler can wrap them in an RFC 7644 §3.12 error envelope
|
||||
with ``schemas`` and ``status`` fields.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
@@ -82,23 +96,14 @@ def verify_scim_token(
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
raise ScimAuthError(401, "Invalid SCIM bearer token")
|
||||
|
||||
if not token.is_active:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
raise ScimAuthError(401, "SCIM token has been revoked")
|
||||
|
||||
return token
|
||||
|
||||
@@ -153,26 +153,31 @@ class ScimProvider(ABC):
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
) -> ScimName:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Always returns a ScimName — Okta's spec tests expect ``name``
|
||||
(with ``givenName``/``familyName``) on every user resource.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
givenName=fields.given_name or "",
|
||||
familyName=fields.family_name or "",
|
||||
formatted=user.personal_name or "",
|
||||
)
|
||||
if not user.personal_name:
|
||||
return None
|
||||
# Derive a reasonable name from the email so that SCIM spec tests
|
||||
# see non-empty givenName / familyName for every user resource.
|
||||
local = user.email.split("@")[0] if user.email else ""
|
||||
return ScimName(givenName=local, familyName="", formatted=local)
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
familyName=parts[1] if len(parts) > 1 else "",
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -520,6 +520,7 @@ def process_user_file_impl(
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
@@ -675,6 +676,7 @@ def delete_user_file_impl(
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
@@ -849,6 +851,7 @@ def project_sync_user_file_impl(
|
||||
task_logger.exception(
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
@@ -59,6 +59,12 @@ def _run_auto_llm_update() -> None:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_cache_cleanup() -> None:
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
@@ -100,12 +106,26 @@ def _run_scheduled_eval() -> None:
|
||||
)
|
||||
|
||||
|
||||
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="cache-cleanup",
|
||||
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
|
||||
run_fn=_run_cache_cleanup,
|
||||
)
|
||||
)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
|
||||
@@ -75,31 +75,41 @@ def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
|
||||
def _claim_next_deleting_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
# Commit to release the row lock promptly.
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
def _claim_next_sync_file(
|
||||
db_session: Session,
|
||||
exclude_ids: set[UUID] | None = None,
|
||||
) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
*exclude_ids* prevents re-processing the same file if the impl fails.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
stmt = (
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
@@ -113,7 +123,10 @@ def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
)
|
||||
if exclude_ids:
|
||||
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
|
||||
file_id = db_session.execute(stmt).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
@@ -135,11 +148,14 @@ def drain_processing_loop(tenant_id: str) -> None:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
try:
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process user file {file_id}")
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
@@ -149,16 +165,21 @@ def drain_delete_loop(tenant_id: str) -> None:
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session)
|
||||
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
try:
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to delete user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
@@ -168,13 +189,18 @@ def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
failed: set[UUID] = set()
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session)
|
||||
file_id = _claim_next_sync_file(session, exclude_ids=failed)
|
||||
if file_id is None:
|
||||
break
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
try:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to sync user file {file_id}")
|
||||
failed.add(file_id)
|
||||
|
||||
8
backend/onyx/cache/factory.py
vendored
8
backend/onyx/cache/factory.py
vendored
@@ -12,9 +12,15 @@ def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
|
||||
return PostgresCacheBackend(tenant_id)
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
# CacheBackendType.POSTGRES will be added in a follow-up PR.
|
||||
CacheBackendType.POSTGRES: _build_postgres_backend,
|
||||
}
|
||||
|
||||
|
||||
|
||||
17
backend/onyx/cache/interface.py
vendored
17
backend/onyx/cache/interface.py
vendored
@@ -1,6 +1,9 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
TTL_KEY_NOT_FOUND = -2
|
||||
TTL_NO_EXPIRY = -1
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
@@ -26,6 +29,14 @@ class CacheLock(abc.ABC):
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CacheLock":
|
||||
if not self.acquire():
|
||||
raise RuntimeError("Failed to acquire lock")
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
@@ -65,7 +76,11 @@ class CacheBackend(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
|
||||
"""Return remaining TTL in seconds.
|
||||
|
||||
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
|
||||
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
323
backend/onyx/cache/postgres_backend.py
vendored
Normal file
323
backend/onyx/cache/postgres_backend.py
vendored
Normal file
@@ -0,0 +1,323 @@
|
||||
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
|
||||
|
||||
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
|
||||
for distributed locking, and a polling loop for the BLPOP pattern.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
_LIST_KEY_PREFIX = "_q:"
|
||||
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
|
||||
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
|
||||
# lists whose names share a prefix (e.g. _q:mylist2:...).
|
||||
_LIST_KEY_RANGE_TERMINATOR = ";"
|
||||
_LIST_ITEM_TTL_SECONDS = 3600
|
||||
_LOCK_POLL_INTERVAL = 0.1
|
||||
_BLPOP_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
def _list_item_key(key: str) -> str:
|
||||
"""Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
|
||||
collision when concurrent rpush calls occur within the same nanosecond.
|
||||
"""
|
||||
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _to_bytes(value: str | bytes | int | float) -> bytes:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheLock(CacheLock):
|
||||
"""Advisory-lock-based distributed lock.
|
||||
|
||||
Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
|
||||
to the session's connection; releasing or closing the session frees it.
|
||||
|
||||
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
|
||||
``timeout`` seconds. They are released when ``release()`` is
|
||||
called or when the session is closed.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
|
||||
self._lock_id = lock_id
|
||||
self._timeout = timeout
|
||||
self._tenant_id = tenant_id
|
||||
self._session_cm: AbstractContextManager[Session] | None = None
|
||||
self._session: Session | None = None
|
||||
self._acquired = False
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
|
||||
self._session = self._session_cm.__enter__()
|
||||
try:
|
||||
if not blocking:
|
||||
return self._try_lock()
|
||||
|
||||
effective_timeout = blocking_timeout or self._timeout
|
||||
deadline = (
|
||||
(time.monotonic() + effective_timeout) if effective_timeout else None
|
||||
)
|
||||
while True:
|
||||
if self._try_lock():
|
||||
return True
|
||||
if deadline is not None and time.monotonic() >= deadline:
|
||||
return False
|
||||
time.sleep(_LOCK_POLL_INTERVAL)
|
||||
finally:
|
||||
if not self._acquired:
|
||||
self._close_session()
|
||||
|
||||
def release(self) -> None:
|
||||
if not self._acquired or self._session is None:
|
||||
return
|
||||
try:
|
||||
self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
|
||||
finally:
|
||||
self._acquired = False
|
||||
self._close_session()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return self._acquired
|
||||
|
||||
def _close_session(self) -> None:
|
||||
if self._session_cm is not None:
|
||||
try:
|
||||
self._session_cm.__exit__(None, None, None)
|
||||
finally:
|
||||
self._session_cm = None
|
||||
self._session = None
|
||||
|
||||
def _try_lock(self) -> bool:
|
||||
assert self._session is not None
|
||||
result = self._session.execute(
|
||||
select(func.pg_try_advisory_lock(self._lock_id))
|
||||
).scalar()
|
||||
if result:
|
||||
self._acquired = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backend
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
|
||||
|
||||
Each operation opens and closes its own database session so the backend
|
||||
is safe to share across threads. Tenant isolation is handled by
|
||||
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.value).where(
|
||||
CacheStore.key == key,
|
||||
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
value = session.execute(stmt).scalar_one_or_none()
|
||||
if value is None:
|
||||
return None
|
||||
return bytes(value)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
value_bytes = _to_bytes(value)
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=ex)
|
||||
if ex is not None
|
||||
else None
|
||||
)
|
||||
stmt = (
|
||||
pg_insert(CacheStore)
|
||||
.values(key=key, value=value_bytes, expires_at=expires_at)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[CacheStore.key],
|
||||
set_={"value": value_bytes, "expires_at": expires_at},
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(delete(CacheStore).where(CacheStore.key == key))
|
||||
session.commit()
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = (
|
||||
select(CacheStore.key)
|
||||
.where(
|
||||
CacheStore.key == key,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
return session.execute(stmt).first() is not None
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
||||
stmt = (
|
||||
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
result = session.execute(stmt).first()
|
||||
if result is None:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
expires_at: datetime | None = result[0]
|
||||
if expires_at is None:
|
||||
return TTL_NO_EXPIRY
|
||||
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
|
||||
if remaining <= 0:
|
||||
return TTL_KEY_NOT_FOUND
|
||||
return int(remaining)
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return PostgresCacheLock(
|
||||
self._lock_id_for(name), timeout, tenant_id=self._tenant_id
|
||||
)
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
if timeout <= 0:
|
||||
raise ValueError(
|
||||
"PostgresCacheBackend.blpop requires timeout > 0. "
|
||||
"timeout=0 would block the calling thread indefinitely "
|
||||
"with no way to interrupt short of process termination."
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while True:
|
||||
for key in keys:
|
||||
lower = f"{_LIST_KEY_PREFIX}{key}:"
|
||||
upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
|
||||
stmt = (
|
||||
select(CacheStore)
|
||||
.where(
|
||||
CacheStore.key >= lower,
|
||||
CacheStore.key < upper,
|
||||
or_(
|
||||
CacheStore.expires_at.is_(None),
|
||||
CacheStore.expires_at > func.now(),
|
||||
),
|
||||
)
|
||||
.order_by(CacheStore.key)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
|
||||
row = session.execute(stmt).scalars().first()
|
||||
if row is not None:
|
||||
value = bytes(row.value) if row.value else b""
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return (key.encode(), value)
|
||||
if time.monotonic() >= deadline:
|
||||
return None
|
||||
time.sleep(_BLPOP_POLL_INTERVAL)
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_expired_cache_entries() -> None:
|
||||
"""Delete rows whose ``expires_at`` is in the past.
|
||||
|
||||
Called by the periodic poller every 5 minutes.
|
||||
"""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
delete(CacheStore).where(
|
||||
CacheStore.expires_at.is_not(None),
|
||||
CacheStore.expires_at < func.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
@@ -1,57 +1,52 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
"""Generate the cache key for a chat session processing fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
chat_session_id: UUID, cache: CacheBackend, value: bool
|
||||
) -> None:
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
"""Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
If the key exists, a message is being processed.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
cache: Tenant-aware cache backend
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
cache: Tenant-aware cache backend
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
return cache.exists(_get_fence_key(chat_session_id))
|
||||
|
||||
@@ -11,9 +11,10 @@ from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
@@ -79,7 +80,6 @@ from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
cache: CacheBackend | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
redis_client = get_redis_client()
|
||||
cache = get_cache_backend()
|
||||
|
||||
reset_cancel_status(
|
||||
chat_session.id,
|
||||
redis_client,
|
||||
cache,
|
||||
)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
cache=cache,
|
||||
value=True,
|
||||
)
|
||||
|
||||
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
|
||||
reset_llm_mock_response(mock_response_token)
|
||||
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
if cache is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
cache=cache,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -1,65 +1,58 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
|
||||
FENCE_TTL = 10 * 60 # 10 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session stop signal fence.
|
||||
"""Generate the cache key for a chat session stop signal fence.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
The fence key string. Tenant isolation is handled automatically
|
||||
by the cache backend (Redis key-prefixing or Postgres schema routing).
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
"""
|
||||
Set or clear the stop signal fence for a chat session.
|
||||
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
|
||||
"""Set or clear the stop signal fence for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
value: True to set the fence (stop signal), False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
if not value:
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(fence_key)
|
||||
return
|
||||
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
cache.set(fence_key, 0, ex=FENCE_TTL)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session should continue (not stopped).
|
||||
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
|
||||
"""Check if the chat session should continue (not stopped).
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session to check
|
||||
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
|
||||
Returns:
|
||||
True if the session should continue, False if it should stop
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return not bool(redis_client.exists(fence_key))
|
||||
return not cache.exists(_get_fence_key(chat_session_id))
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
|
||||
"""
|
||||
Clear the stop signal for a chat session.
|
||||
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
|
||||
"""Clear the stop signal for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
|
||||
cache: Tenant-aware cache backend
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
redis_client.delete(fence_key)
|
||||
cache.delete(_get_fence_key(chat_session_id))
|
||||
|
||||
@@ -819,7 +819,9 @@ RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
# Tool Configs
|
||||
#####
|
||||
# Code Interpreter Service Configuration
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get(
|
||||
"CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
|
||||
os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
|
||||
|
||||
@@ -532,6 +532,7 @@ def fetch_default_model(
|
||||
) -> ModelConfiguration | None:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
|
||||
@@ -4926,7 +4926,9 @@ class ScimUserMapping(Base):
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
external_id: Mapped[str | None] = mapped_column(
|
||||
String, unique=True, index=True, nullable=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
@@ -4983,3 +4985,25 @@ class CodeInterpreterServer(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class CacheStore(Base):
|
||||
"""Key-value cache table used by ``PostgresCacheBackend``.
|
||||
|
||||
Replaces Redis for simple KV caching, locks, and list operations
|
||||
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
|
||||
|
||||
Intentionally separate from ``KVStore``:
|
||||
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
|
||||
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
|
||||
- Holds ephemeral data (tokens, stop signals, lock state) not
|
||||
persistent application config, so cleanup can be aggressive.
|
||||
"""
|
||||
|
||||
__tablename__ = "cache_store"
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
|
||||
@@ -4,39 +4,33 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key prefix for OAuth state
|
||||
OAUTH_STATE_PREFIX = "federated_oauth"
|
||||
# Default TTL for OAuth state (5 minutes)
|
||||
OAUTH_STATE_TTL = 300
|
||||
OAUTH_STATE_TTL = 300 # 5 minutes
|
||||
|
||||
|
||||
class OAuthSession:
|
||||
"""Represents an OAuth session stored in Redis."""
|
||||
"""Represents an OAuth session stored in the cache backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
):
|
||||
self.federated_connector_id = federated_connector_id
|
||||
self.user_id = user_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.additional_data = additional_data or {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for Redis storage."""
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"federated_connector_id": self.federated_connector_id,
|
||||
"user_id": self.user_id,
|
||||
@@ -45,8 +39,7 @@ class OAuthSession:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
|
||||
"""Create from dictionary retrieved from Redis."""
|
||||
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
|
||||
return cls(
|
||||
federated_connector_id=data["federated_connector_id"],
|
||||
user_id=data["user_id"],
|
||||
@@ -58,31 +51,27 @@ class OAuthSession:
|
||||
def generate_oauth_state(
|
||||
federated_connector_id: int,
|
||||
user_id: str,
|
||||
redirect_uri: Optional[str] = None,
|
||||
additional_data: Optional[Dict[str, Any]] = None,
|
||||
redirect_uri: str | None = None,
|
||||
additional_data: dict[str, Any] | None = None,
|
||||
ttl: int = OAUTH_STATE_TTL,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a secure state parameter and store session data in Redis.
|
||||
Generate a secure state parameter and store session data in the cache backend.
|
||||
|
||||
Args:
|
||||
federated_connector_id: ID of the federated connector
|
||||
user_id: ID of the user initiating OAuth
|
||||
redirect_uri: Optional redirect URI after OAuth completion
|
||||
additional_data: Any additional data to store with the session
|
||||
ttl: Time-to-live in seconds for the Redis key
|
||||
ttl: Time-to-live in seconds for the cache key
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter
|
||||
"""
|
||||
# Generate a random UUID for the state
|
||||
state_uuid = uuid.uuid4()
|
||||
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Convert UUID to base64 for URL-safe state parameter
|
||||
state_bytes = state_uuid.bytes
|
||||
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
|
||||
|
||||
# Create session object
|
||||
session = OAuthSession(
|
||||
federated_connector_id=federated_connector_id,
|
||||
user_id=user_id,
|
||||
@@ -90,15 +79,9 @@ def generate_oauth_state(
|
||||
additional_data=additional_data,
|
||||
)
|
||||
|
||||
# Store in Redis with TTL
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
redis_client.set(
|
||||
redis_key,
|
||||
json.dumps(session.to_dict()),
|
||||
ex=ttl,
|
||||
)
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
|
||||
|
||||
logger.info(
|
||||
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
|
||||
@@ -125,18 +108,15 @@ def verify_oauth_state(state: str) -> OAuthSession:
|
||||
state_bytes = base64.urlsafe_b64decode(padded_state)
|
||||
state_uuid = uuid.UUID(bytes=state_bytes)
|
||||
|
||||
# Look up in Redis
|
||||
redis_client = get_redis_client()
|
||||
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
cache = get_cache_backend()
|
||||
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
|
||||
|
||||
session_data = cast(bytes, redis_client.get(redis_key))
|
||||
session_data = cache.get(cache_key)
|
||||
if not session_data:
|
||||
raise ValueError(f"OAuth state not found in Redis: {state}")
|
||||
raise ValueError(f"OAuth state not found: {state}")
|
||||
|
||||
# Delete the key after retrieval (one-time use)
|
||||
redis_client.delete(redis_key)
|
||||
cache.delete(cache_key)
|
||||
|
||||
# Parse and return session
|
||||
session_dict = json.loads(session_data)
|
||||
return OAuthSession.from_dict(session_dict)
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -20,22 +18,27 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self, redis_client: Redis | None = None) -> None:
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
self.redis_client = get_redis_client()
|
||||
def __init__(self, cache: CacheBackend | None = None) -> None:
|
||||
self._cache = cache
|
||||
|
||||
def _get_cache(self) -> CacheBackend:
|
||||
if self._cache is None:
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
|
||||
self._cache = get_cache_backend()
|
||||
return self._cache
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
|
||||
try:
|
||||
self.redis_client.set(
|
||||
self._get_cache().set(
|
||||
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback gracefully to Postgres if Redis fails
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
# Fallback gracefully to Postgres if Cache backend fails
|
||||
logger.error(
|
||||
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
@@ -53,16 +56,12 @@ class PgRedisKVStore(KeyValueStore):
|
||||
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
|
||||
if not refresh_cache:
|
||||
try:
|
||||
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
|
||||
if redis_value:
|
||||
if not isinstance(redis_value, bytes):
|
||||
raise ValueError(
|
||||
f"Redis value for key '{key}' is not a bytes object"
|
||||
)
|
||||
return json.loads(redis_value.decode("utf-8"))
|
||||
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
|
||||
if cached is not None:
|
||||
return json.loads(cached.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get value from Redis for key '{key}': {str(e)}"
|
||||
f"Failed to get value from cache for key '{key}': {str(e)}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -79,21 +78,21 @@ class PgRedisKVStore(KeyValueStore):
|
||||
value = None
|
||||
|
||||
try:
|
||||
self.redis_client.set(
|
||||
self._get_cache().set(
|
||||
REDIS_KEY_PREFIX + key,
|
||||
json.dumps(value),
|
||||
ex=KV_REDIS_KEY_EXPIRATION,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
|
||||
|
||||
return cast(JSON_ro, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
try:
|
||||
self.redis_client.delete(REDIS_KEY_PREFIX + key)
|
||||
self._get_cache().delete(REDIS_KEY_PREFIX + key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete()
|
||||
|
||||
@@ -67,6 +67,18 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -725,6 +737,44 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -736,6 +786,7 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -743,3 +794,4 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -13,44 +13,38 @@ from datetime import datetime
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Redis key for caching the last updated timestamp (per-tenant)
|
||||
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
|
||||
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
|
||||
|
||||
|
||||
def _get_cached_last_updated_at() -> datetime | None:
|
||||
"""Get the cached last_updated_at timestamp from Redis."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
if value and isinstance(value, bytes):
|
||||
# Value is bytes, decode to string then parse as ISO format
|
||||
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
if value is not None:
|
||||
return datetime.fromisoformat(value.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
|
||||
logger.warning(f"Failed to get cached last_updated_at: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_last_updated_at(updated_at: datetime) -> None:
|
||||
"""Set the cached last_updated_at timestamp in Redis."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
# Store as ISO format string, with 24 hour expiration
|
||||
redis_client.set(
|
||||
_REDIS_KEY_LAST_UPDATED_AT,
|
||||
get_cache_backend().set(
|
||||
_CACHE_KEY_LAST_UPDATED_AT,
|
||||
updated_at.isoformat(),
|
||||
ex=60 * 60 * 24, # 24 hours
|
||||
ex=_CACHE_TTL_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
|
||||
logger.warning(f"Failed to set cached last_updated_at: {e}")
|
||||
|
||||
|
||||
def fetch_llm_recommendations_from_github(
|
||||
@@ -148,9 +142,8 @@ def sync_llm_models_from_github(
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the cache timestamp in Redis. Useful for testing."""
|
||||
"""Reset the cache timestamp. Useful for testing."""
|
||||
try:
|
||||
redis_client = get_redis_client()
|
||||
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
|
||||
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset cache in Redis: {e}")
|
||||
logger.warning(f"Failed to reset cache: {e}")
|
||||
|
||||
@@ -1133,7 +1133,8 @@ done
|
||||
# Already deleted
|
||||
service_deleted = True
|
||||
else:
|
||||
logger.warning(f"Error deleting Service {service_name}: {e}")
|
||||
logger.error(f"Error deleting Service {service_name}: {e}")
|
||||
raise
|
||||
|
||||
pod_deleted = False
|
||||
try:
|
||||
@@ -1148,7 +1149,8 @@ done
|
||||
# Already deleted
|
||||
pod_deleted = True
|
||||
else:
|
||||
logger.warning(f"Error deleting Pod {pod_name}: {e}")
|
||||
logger.error(f"Error deleting Pod {pod_name}: {e}")
|
||||
raise
|
||||
|
||||
# Wait for resources to be fully deleted to prevent 409 conflicts
|
||||
# on immediate re-provisioning
|
||||
|
||||
@@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
|
||||
# Prevent overlapping runs of this task
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -7,13 +7,14 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -116,7 +117,9 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -128,11 +131,11 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
llm = get_default_llm()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -8,10 +8,10 @@ import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.cache.factory import get_shared_cache_backend
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
|
||||
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
|
||||
@@ -113,60 +113,46 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
|
||||
|
||||
def get_cached_etag() -> str | None:
|
||||
"""Get the cached GitHub ETag from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
try:
|
||||
etag = redis_client.get(REDIS_KEY_ETAG)
|
||||
etag = cache.get(REDIS_KEY_ETAG)
|
||||
if etag:
|
||||
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
|
||||
return etag.decode("utf-8")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached etag from Redis: {e}")
|
||||
logger.error(f"Failed to get cached etag: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_last_fetch_time() -> datetime | None:
|
||||
"""Get the last fetch timestamp from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
try:
|
||||
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
|
||||
if not fetched_at_str:
|
||||
raw = cache.get(REDIS_KEY_FETCHED_AT)
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
decoded = (
|
||||
fetched_at_str.decode("utf-8")
|
||||
if isinstance(fetched_at_str, bytes)
|
||||
else str(fetched_at_str)
|
||||
)
|
||||
|
||||
last_fetch = datetime.fromisoformat(decoded)
|
||||
|
||||
# Defensively ensure timezone awareness
|
||||
# fromisoformat() returns naive datetime if input lacks timezone
|
||||
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
|
||||
if last_fetch.tzinfo is None:
|
||||
# Assume UTC for naive datetimes
|
||||
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Convert to UTC if timezone-aware
|
||||
last_fetch = last_fetch.astimezone(timezone.utc)
|
||||
|
||||
return last_fetch
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get last fetch time from Redis: {e}")
|
||||
logger.error(f"Failed to get last fetch time from cache: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_fetch_metadata(etag: str | None) -> None:
|
||||
"""Save ETag and fetch timestamp to Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
cache = get_shared_cache_backend()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
if etag:
|
||||
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save fetch metadata to Redis: {e}")
|
||||
logger.error(f"Failed to save fetch metadata to cache: {e}")
|
||||
|
||||
|
||||
def is_cache_stale() -> bool:
|
||||
@@ -196,11 +182,10 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
|
||||
if not is_cache_stale():
|
||||
return
|
||||
|
||||
# Acquire lock to prevent concurrent fetches
|
||||
redis_client = get_shared_redis_client()
|
||||
lock = redis_client.lock(
|
||||
cache = get_shared_cache_backend()
|
||||
lock = cache.lock(
|
||||
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
|
||||
timeout=90, # 90 second timeout for the lock
|
||||
timeout=90,
|
||||
)
|
||||
|
||||
# Non-blocking acquire - if we can't get the lock, another request is handling it
|
||||
|
||||
@@ -479,10 +479,20 @@ def put_llm_provider(
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
def delete_llm_provider(
|
||||
provider_id: int,
|
||||
force: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
if not force:
|
||||
model = fetch_default_llm_model(db_session)
|
||||
|
||||
if model and model.llm_provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the default LLM provider",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -13,13 +13,13 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import convert_chat_history_basic
|
||||
@@ -67,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
|
||||
from onyx.server.api_key_usage import check_api_key_usage
|
||||
from onyx.server.query_and_chat.models import ChatFeedbackRequest
|
||||
@@ -330,7 +329,7 @@ def get_chat_session(
|
||||
]
|
||||
|
||||
try:
|
||||
is_processing = is_chat_session_processing(session_id, get_redis_client())
|
||||
is_processing = is_chat_session_processing(session_id, get_cache_backend())
|
||||
# Edit the last message to indicate loading (Overriding default message value)
|
||||
if is_processing and chat_message_details:
|
||||
last_msg = chat_message_details[-1]
|
||||
@@ -927,11 +926,10 @@ async def search_chats(
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User = Depends(current_user), # noqa: ARG001
|
||||
redis_client: Redis = Depends(get_redis_client),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal in Redis.
|
||||
Stop a chat session by setting a stop signal.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
set_fence(chat_session_id, redis_client, True)
|
||||
set_fence(chat_session_id, get_cache_backend(), True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
@@ -6,11 +7,8 @@ from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -33,30 +31,22 @@ def load_settings() -> Settings:
|
||||
logger.error(f"Error loading settings from KV store: {str(e)}")
|
||||
settings = Settings()
|
||||
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache = get_cache_backend()
|
||||
|
||||
try:
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
if value is not None:
|
||||
assert isinstance(value, bytes)
|
||||
anonymous_user_enabled = int(value.decode("utf-8")) == 1
|
||||
else:
|
||||
# Default to False
|
||||
anonymous_user_enabled = False
|
||||
# Optionally store the default back to Redis
|
||||
redis_client.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
|
||||
)
|
||||
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
|
||||
except Exception as e:
|
||||
# Log the error and reset to default
|
||||
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
|
||||
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
|
||||
# Override user knowledge setting if disabled via environment variable
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
@@ -66,11 +56,10 @@ def load_settings() -> Settings:
|
||||
|
||||
|
||||
def store_settings(settings: Settings) -> None:
|
||||
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache = get_cache_backend()
|
||||
|
||||
if settings.anonymous_user_enabled is not None:
|
||||
redis_client.set(
|
||||
cache.set(
|
||||
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
|
||||
"1" if settings.anonymous_user_enabled else "0",
|
||||
ex=SETTINGS_TTL,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
@@ -12,6 +13,9 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HEALTH_CACHE_TTL_SECONDS = 30
|
||||
_health_cache: dict[str, tuple[float, bool]] = {}
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
@@ -98,16 +102,32 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
def health(self, use_cache: bool = False) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy
|
||||
|
||||
Args:
|
||||
use_cache: When True, return a cached result if available and
|
||||
within the TTL window. The cache is always populated
|
||||
after a live request regardless of this flag.
|
||||
"""
|
||||
if use_cache:
|
||||
cached = _health_cache.get(self.base_url)
|
||||
if cached is not None:
|
||||
cached_at, cached_result = cached
|
||||
if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS:
|
||||
return cached_result
|
||||
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json().get("status") == "ok"
|
||||
result = response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
return False
|
||||
result = False
|
||||
|
||||
_health_cache[self.base_url] = (time.monotonic(), result)
|
||||
return result
|
||||
|
||||
def execute(
|
||||
self,
|
||||
|
||||
@@ -107,7 +107,11 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
return server.server_enabled
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
client = CodeInterpreterClient()
|
||||
return client.health(use_cache=True)
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.7.4
|
||||
pypdf==6.7.5
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -13,9 +13,11 @@ the correct files.
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
@@ -55,6 +57,32 @@ def _create_user_file(
|
||||
return uf
|
||||
|
||||
|
||||
def _fake_delete_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: delete the row so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id)))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _fake_sync_impl(
|
||||
user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
"""Mock side-effect: clear sync flags so the drain loop terminates."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == UUID(user_file_id))
|
||||
.values(needs_project_sync=False, needs_persona_sync=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
@@ -125,9 +153,9 @@ class TestRecoverDeletingFiles:
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf)
|
||||
# Row is deleted by _fake_delete_impl, so no cleanup needed.
|
||||
|
||||
mock_impl = MagicMock()
|
||||
mock_impl = MagicMock(side_effect=_fake_delete_impl)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
@@ -155,7 +183,7 @@ class TestRecoverSyncFiles:
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
@@ -179,7 +207,7 @@ class TestRecoverSyncFiles:
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
mock_impl = MagicMock(side_effect=_fake_sync_impl)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
@@ -217,3 +245,108 @@ class TestRecoveryMultipleFiles:
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestTransientFailures:
|
||||
"""Drain loops skip failed files, process the rest, and terminate."""
|
||||
|
||||
def test_processing_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_proc")
|
||||
uf_fail = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_delete_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_del")
|
||||
uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf_fail)
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_delete_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
def test_sync_failure_skips_and_continues(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "fail_sync")
|
||||
uf_fail = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
uf_ok = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.extend([uf_fail, uf_ok])
|
||||
|
||||
fail_id = str(uf_fail.id)
|
||||
|
||||
def side_effect(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
) -> None:
|
||||
if user_file_id == fail_id:
|
||||
raise RuntimeError("transient failure")
|
||||
_fake_sync_impl(user_file_id, tenant_id, redis_locking)
|
||||
|
||||
mock_impl = MagicMock(side_effect=side_effect)
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert fail_id in called_ids, "Failed file should have been attempted"
|
||||
assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
|
||||
assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
|
||||
assert called_ids.count(str(uf_ok.id)) == 1
|
||||
|
||||
57
backend/tests/external_dependency_unit/cache/conftest.py
vendored
Normal file
57
backend/tests/external_dependency_unit/cache/conftest.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Fixtures for cache backend tests.
|
||||
|
||||
Requires a running PostgreSQL instance (and Redis for parity tests).
|
||||
Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db() -> Generator[None, None, None]:
|
||||
"""Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tenant_context() -> Generator[None, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_cache() -> PostgresCacheBackend:
|
||||
return PostgresCacheBackend(TEST_TENANT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_cache() -> RedisCacheBackend:
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
|
||||
|
||||
|
||||
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
|
||||
def cache(
|
||||
request: pytest.FixtureRequest,
|
||||
pg_cache: PostgresCacheBackend,
|
||||
redis_cache: RedisCacheBackend,
|
||||
) -> CacheBackend:
|
||||
if request.param == "postgres":
|
||||
return pg_cache
|
||||
return redis_cache
|
||||
100
backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py
vendored
Normal file
100
backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py
vendored
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Parameterized tests that run the same CacheBackend operations against
|
||||
both Redis and PostgreSQL, asserting identical return values.
|
||||
|
||||
Each test runs twice (once per backend) via the ``cache`` fixture defined
|
||||
in conftest.py.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"parity_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
class TestKVParity:
|
||||
def test_get_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.get(_key()) is None
|
||||
|
||||
def test_get_set(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"value")
|
||||
assert cache.get(k) == b"value"
|
||||
|
||||
def test_overwrite(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"a")
|
||||
cache.set(k, b"b")
|
||||
assert cache.get(k) == b"b"
|
||||
|
||||
def test_set_string(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, "hello")
|
||||
assert cache.get(k) == b"hello"
|
||||
|
||||
def test_set_int(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, 42)
|
||||
assert cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
cache.delete(k)
|
||||
assert cache.get(k) is None
|
||||
|
||||
def test_exists(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not cache.exists(k)
|
||||
cache.set(k, b"x")
|
||||
assert cache.exists(k)
|
||||
|
||||
|
||||
class TestTTLParity:
|
||||
def test_ttl_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
assert cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_remaining(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=10)
|
||||
remaining = cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=1)
|
||||
assert cache.get(k) == b"x"
|
||||
time.sleep(1.5)
|
||||
assert cache.get(k) is None
|
||||
|
||||
|
||||
class TestLockParity:
|
||||
def test_acquire_release(self, cache: CacheBackend) -> None:
|
||||
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
|
||||
class TestListParity:
|
||||
def test_rpush_blpop(self, cache: CacheBackend) -> None:
|
||||
k = f"parity_list_{uuid4().hex[:8]}"
|
||||
cache.rpush(k, b"item")
|
||||
result = cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result[1] == b"item"
|
||||
|
||||
def test_blpop_timeout(self, cache: CacheBackend) -> None:
|
||||
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
129
backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py
vendored
Normal file
129
backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
|
||||
|
||||
Verifies that the KV store correctly uses the CacheBackend for caching
|
||||
in front of PostgreSQL: cache hits, cache misses falling through to PG,
|
||||
cache population after PG reads, cache invalidation on delete, and
|
||||
graceful degradation when the cache backend raises.
|
||||
|
||||
Requires running PostgreSQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import CacheStore
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from onyx.key_value_store.store import REDIS_KEY_PREFIX
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_kv() -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
|
||||
session.execute(delete(KVStore))
|
||||
session.execute(delete(CacheStore))
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
|
||||
return PgRedisKVStore(cache=pg_cache)
|
||||
|
||||
|
||||
class TestStoreAndLoad:
|
||||
def test_store_populates_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k1", {"hello": "world"})
|
||||
|
||||
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
|
||||
assert cached is not None
|
||||
assert json.loads(cached) == {"hello": "world"}
|
||||
|
||||
loaded = kv_store.load("k1")
|
||||
assert loaded == {"hello": "world"}
|
||||
|
||||
def test_load_returns_cached_value_without_pg_hit(
|
||||
self, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
"""If the cache already has the value, PG should not be queried."""
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
|
||||
kv = PgRedisKVStore(cache=pg_cache)
|
||||
assert kv.load("cached_only") == {"from": "cache"}
|
||||
|
||||
def test_load_falls_through_to_pg_on_cache_miss(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k2", [1, 2, 3])
|
||||
|
||||
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
|
||||
|
||||
loaded = kv_store.load("k2")
|
||||
assert loaded == [1, 2, 3]
|
||||
|
||||
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
|
||||
assert repopulated is not None
|
||||
assert json.loads(repopulated) == [1, 2, 3]
|
||||
|
||||
def test_load_with_refresh_cache_skips_cache(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("k3", "original")
|
||||
|
||||
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
|
||||
|
||||
loaded = kv_store.load("k3", refresh_cache=True)
|
||||
assert loaded == "original"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete_removes_from_cache_and_pg(
|
||||
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
|
||||
) -> None:
|
||||
kv_store.store("del_me", "bye")
|
||||
kv_store.delete("del_me")
|
||||
|
||||
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
|
||||
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.load("del_me")
|
||||
|
||||
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
|
||||
with pytest.raises(KvKeyNotFoundError):
|
||||
kv_store.delete("nonexistent")
|
||||
|
||||
|
||||
class TestCacheFailureGracefulDegradation:
|
||||
def test_store_succeeds_when_cache_set_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("resilient", {"data": True})
|
||||
|
||||
working_cache = MagicMock(spec=CacheBackend)
|
||||
working_cache.get.return_value = None
|
||||
kv_reader = PgRedisKVStore(cache=working_cache)
|
||||
loaded = kv_reader.load("resilient")
|
||||
assert loaded == {"data": True}
|
||||
|
||||
def test_load_falls_through_when_cache_get_raises(self) -> None:
|
||||
failing_cache = MagicMock(spec=CacheBackend)
|
||||
failing_cache.get.side_effect = ConnectionError("cache down")
|
||||
failing_cache.set.side_effect = ConnectionError("cache down")
|
||||
|
||||
kv = PgRedisKVStore(cache=failing_cache)
|
||||
kv.store("survive", 42)
|
||||
loaded = kv.load("survive")
|
||||
assert loaded == 42
|
||||
229
backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py
vendored
Normal file
229
backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py
vendored
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Tests for PostgresCacheBackend against real PostgreSQL.
|
||||
|
||||
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
|
||||
locks (acquire / release / contention), list operations (rpush / blpop),
|
||||
and the periodic cleanup function.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.cache.interface import TTL_KEY_NOT_FOUND
|
||||
from onyx.cache.interface import TTL_NO_EXPIRY
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"test_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Basic KV
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKV:
|
||||
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"hello")
|
||||
assert pg_cache.get(k) == b"hello"
|
||||
|
||||
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.get(_key()) is None
|
||||
|
||||
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"first")
|
||||
pg_cache.set(k, b"second")
|
||||
assert pg_cache.get(k) == b"second"
|
||||
|
||||
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, "string_val")
|
||||
assert pg_cache.get(k) == b"string_val"
|
||||
|
||||
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, 42)
|
||||
assert pg_cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"to_delete")
|
||||
pg_cache.delete(k)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
pg_cache.delete(_key())
|
||||
|
||||
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not pg_cache.exists(k)
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTL
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"ephemeral", ex=1)
|
||||
assert pg_cache.get(k) == b"ephemeral"
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"forever")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
|
||||
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=10)
|
||||
remaining = pg_cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND
|
||||
|
||||
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.ttl(k) == TTL_NO_EXPIRY
|
||||
pg_cache.expire(k, 10)
|
||||
assert 8 <= pg_cache.ttl(k) <= 10
|
||||
|
||||
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
assert pg_cache.exists(k)
|
||||
time.sleep(1.5)
|
||||
assert not pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Locks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"contention_{uuid4().hex[:8]}"
|
||||
lock1 = pg_cache.lock(name)
|
||||
lock2 = pg_cache.lock(name)
|
||||
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
|
||||
assert lock.owned()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"timeout_{uuid4().hex[:8]}"
|
||||
holder = pg_cache.lock(name)
|
||||
holder.acquire(blocking=False)
|
||||
|
||||
waiter = pg_cache.lock(name, timeout=0.3)
|
||||
start = time.monotonic()
|
||||
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 0.25
|
||||
|
||||
holder.release()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# List (rpush / blpop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestList:
|
||||
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"list_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"item1")
|
||||
result = pg_cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k.encode(), b"item1")
|
||||
|
||||
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
|
||||
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"fifo_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"first")
|
||||
time.sleep(0.01)
|
||||
pg_cache.rpush(k, b"second")
|
||||
|
||||
r1 = pg_cache.blpop([k], timeout=1)
|
||||
r2 = pg_cache.blpop([k], timeout=1)
|
||||
assert r1 is not None and r1[1] == b"first"
|
||||
assert r2 is not None and r2[1] == b"second"
|
||||
|
||||
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k1 = f"mk1_{uuid4().hex[:8]}"
|
||||
k2 = f"mk2_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k2, b"from_k2")
|
||||
|
||||
result = pg_cache.blpop([k1, k2], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k2.encode(), b"from_k2")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
k = _key()
|
||||
pg_cache.set(k, b"stale", ex=1)
|
||||
time.sleep(1.5)
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
stmt = select(CacheStore.key).where(CacheStore.key == k)
|
||||
with get_session_with_current_tenant() as session:
|
||||
row = session.execute(stmt).first()
|
||||
assert row is None, "expired row should be physically deleted"
|
||||
|
||||
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"fresh", ex=300)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"fresh"
|
||||
|
||||
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"permanent")
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"permanent"
|
||||
@@ -1027,6 +1027,13 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self._capture("GET", b"")
|
||||
if self.path == "/health":
|
||||
self._respond_json(200, {"status": "ok"})
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
|
||||
def do_DELETE(self) -> None:
|
||||
self._capture("DELETE", b"")
|
||||
self.send_response(200)
|
||||
@@ -1107,6 +1114,14 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _attach_python_tool_to_default_persona(db_session: Session) -> None:
|
||||
"""Ensure the default persona (id=0) has the PythonTool attached."""
|
||||
|
||||
@@ -386,6 +386,261 @@ def test_delete_llm_provider(
|
||||
assert provider_data is None
|
||||
|
||||
|
||||
def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG001
|
||||
"""Deleting the default LLM provider should return 400."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-default-delete",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Set this provider as the default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": created_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Attempt to delete the default provider — should be rejected
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]
|
||||
|
||||
# Verify provider still exists
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
|
||||
|
||||
def test_delete_non_default_llm_provider_with_default_set(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a non-default provider should succeed even when a default is set."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create two providers
|
||||
response_default = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "default-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response_default.status_code == 200
|
||||
default_provider = response_default.json()
|
||||
|
||||
response_other = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "other-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response_other.status_code == 200
|
||||
other_provider = response_other.json()
|
||||
|
||||
# Set the first provider as default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": default_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Delete the non-default provider — should succeed
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{other_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
# Verify the non-default provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, other_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
# Verify the default provider still exists
|
||||
default_data = _get_provider_by_id(admin_user, default_provider["id"])
|
||||
assert default_data is not None
|
||||
|
||||
|
||||
def test_force_delete_default_llm_provider(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Force-deleting the default LLM provider should succeed."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-force-delete",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Set this provider as the default
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"provider_id": created_provider["id"],
|
||||
"model_name": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert set_default_response.status_code == 200
|
||||
|
||||
# Attempt to delete without force — should be rejected
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 400
|
||||
|
||||
# Force delete — should succeed
|
||||
force_delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}?force=true",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert force_delete_response.status_code == 200
|
||||
|
||||
# Verify provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
|
||||
def test_delete_default_vision_provider_clears_vision_default(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting the default vision provider should succeed and clear the vision default."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a text provider and set it as default (so we have a default text provider)
|
||||
text_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "text-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert text_response.status_code == 200
|
||||
text_provider = text_response.json()
|
||||
_set_default_provider(admin_user, text_provider["id"], "gpt-4o-mini")
|
||||
|
||||
# Create a vision provider and set it as default vision
|
||||
vision_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "vision-provider",
|
||||
"provider": LlmProviderNames.OPENAI,
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000002",
|
||||
"model_configurations": [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
supports_image_input=True,
|
||||
).model_dump()
|
||||
],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert vision_response.status_code == 200
|
||||
vision_provider = vision_response.json()
|
||||
_set_default_vision_provider(admin_user, vision_provider["id"], "gpt-4o")
|
||||
|
||||
# Verify vision default is set
|
||||
data = _get_providers_admin(admin_user)
|
||||
assert data is not None
|
||||
_, _, vision_default = _unpack_data(data)
|
||||
assert vision_default is not None
|
||||
assert vision_default["provider_id"] == vision_provider["id"]
|
||||
|
||||
# Delete the vision provider — should succeed (only text default is protected)
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{vision_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
# Verify the vision provider is gone
|
||||
provider_data = _get_provider_by_id(admin_user, vision_provider["id"])
|
||||
assert provider_data is None
|
||||
|
||||
# Verify there is no default vision provider
|
||||
data = _get_providers_admin(admin_user)
|
||||
assert data is not None
|
||||
_, text_default, vision_default = _unpack_data(data)
|
||||
assert vision_default is None
|
||||
|
||||
# Verify the text default is still intact
|
||||
assert text_default is not None
|
||||
assert text_default["provider_id"] == text_provider["id"]
|
||||
|
||||
|
||||
def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
|
||||
"""Creating a provider with a name that already exists should return 400."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -389,19 +389,22 @@ def test_delete_user(scim_token: str, idp_style: str) -> None:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user_missing_external_id(scim_token: str) -> None:
|
||||
"""POST /Users without externalId returns 400."""
|
||||
def test_create_user_missing_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users without externalId succeeds (RFC 7643: externalId is optional)."""
|
||||
email = f"scim_no_extid_{idp_style}@example.com"
|
||||
resp = ScimClient.post(
|
||||
"/Users",
|
||||
scim_token,
|
||||
json={
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": "scim_no_extid@example.com",
|
||||
"userName": email,
|
||||
"active": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "externalId" in resp.json()["detail"]
|
||||
assert resp.status_code == 201
|
||||
body = resp.json()
|
||||
assert body["userName"] == email
|
||||
assert body.get("externalId") is None
|
||||
|
||||
|
||||
def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None:
|
||||
|
||||
166
backend/tests/unit/onyx/chat/test_stop_signal_checker.py
Normal file
166
backend/tests/unit/onyx/chat/test_stop_signal_checker.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Unit tests for stop_signal_checker and chat_processing_checker.
|
||||
|
||||
These modules are safety-critical — they control whether a chat stream
|
||||
continues or stops. The tests use a simple in-memory CacheBackend stub
|
||||
so no external services are needed.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.stop_signal_checker import FENCE_TTL
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
|
||||
|
||||
class _MemoryCacheBackend(CacheBackend):
|
||||
"""Minimal in-memory CacheBackend for unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, bytes] = {}
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None, # noqa: ARG002
|
||||
) -> None:
|
||||
if isinstance(value, bytes):
|
||||
self._store[key] = value
|
||||
else:
|
||||
self._store[key] = str(value).encode()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self._store
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
pass
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return -2 if key not in self._store else -1
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ── stop_signal_checker ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetFence:
|
||||
def test_set_fence_true_creates_key(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
assert not is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_false_removes_key(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
set_fence(sid, cache, False)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_false_noop_when_absent(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, False)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_set_fence_uses_ttl(self) -> None:
|
||||
"""Verify set_fence passes ex=FENCE_TTL to cache.set."""
|
||||
calls: list[dict[str, object]] = []
|
||||
cache = _MemoryCacheBackend()
|
||||
original_set = cache.set
|
||||
|
||||
def tracking_set(
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
calls.append({"key": key, "ex": ex})
|
||||
original_set(key, value, ex=ex)
|
||||
|
||||
cache.set = tracking_set # type: ignore[method-assign]
|
||||
|
||||
set_fence(uuid4(), cache, True)
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["ex"] == FENCE_TTL
|
||||
|
||||
|
||||
class TestIsConnected:
|
||||
def test_connected_when_no_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
assert is_connected(uuid4(), cache)
|
||||
|
||||
def test_disconnected_when_fence_set(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
assert not is_connected(sid, cache)
|
||||
|
||||
def test_sessions_are_isolated(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid1, sid2 = uuid4(), uuid4()
|
||||
set_fence(sid1, cache, True)
|
||||
assert not is_connected(sid1, cache)
|
||||
assert is_connected(sid2, cache)
|
||||
|
||||
|
||||
class TestResetCancelStatus:
|
||||
def test_clears_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_fence(sid, cache, True)
|
||||
reset_cancel_status(sid, cache)
|
||||
assert is_connected(sid, cache)
|
||||
|
||||
def test_noop_when_no_fence(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
reset_cancel_status(uuid4(), cache)
|
||||
|
||||
|
||||
# ── chat_processing_checker ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetProcessingStatus:
|
||||
def test_set_true_marks_processing(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_processing_status(sid, cache, True)
|
||||
assert is_chat_session_processing(sid, cache)
|
||||
|
||||
def test_set_false_clears_processing(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid = uuid4()
|
||||
set_processing_status(sid, cache, True)
|
||||
set_processing_status(sid, cache, False)
|
||||
assert not is_chat_session_processing(sid, cache)
|
||||
|
||||
|
||||
class TestIsChatSessionProcessing:
|
||||
def test_not_processing_by_default(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
assert not is_chat_session_processing(uuid4(), cache)
|
||||
|
||||
def test_sessions_are_isolated(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
sid1, sid2 = uuid4(), uuid4()
|
||||
set_processing_status(sid1, cache, True)
|
||||
assert is_chat_session_processing(sid1, cache)
|
||||
assert not is_chat_session_processing(sid2, cache)
|
||||
163
backend/tests/unit/onyx/federated_connectors/test_oauth_utils.py
Normal file
163
backend/tests/unit/onyx/federated_connectors/test_oauth_utils.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Unit tests for federated OAuth state generation and verification.
|
||||
|
||||
Uses unittest.mock to patch get_cache_backend so no external services
|
||||
are needed. Verifies the generate -> verify round-trip, one-time-use
|
||||
semantics, TTL propagation, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.federated_connectors.oauth_utils import generate_oauth_state
|
||||
from onyx.federated_connectors.oauth_utils import OAUTH_STATE_TTL
|
||||
from onyx.federated_connectors.oauth_utils import OAuthSession
|
||||
from onyx.federated_connectors.oauth_utils import verify_oauth_state
|
||||
|
||||
|
||||
class _MemoryCacheBackend(CacheBackend):
|
||||
"""Minimal in-memory CacheBackend for unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, bytes] = {}
|
||||
self.set_calls: list[dict[str, object]] = []
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._store.get(key)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self.set_calls.append({"key": key, "ex": ex})
|
||||
if isinstance(value, bytes):
|
||||
self._store[key] = value
|
||||
else:
|
||||
self._store[key] = str(value).encode()
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self._store
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
pass
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return -2 if key not in self._store else -1
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _patched(cache: _MemoryCacheBackend): # type: ignore[no-untyped-def]
|
||||
return patch(
|
||||
"onyx.federated_connectors.oauth_utils.get_cache_backend",
|
||||
return_value=cache,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateAndVerifyRoundTrip:
|
||||
def test_round_trip_basic(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(
|
||||
federated_connector_id=42,
|
||||
user_id="user-abc",
|
||||
)
|
||||
session = verify_oauth_state(state)
|
||||
|
||||
assert session.federated_connector_id == 42
|
||||
assert session.user_id == "user-abc"
|
||||
assert session.redirect_uri is None
|
||||
assert session.additional_data == {}
|
||||
|
||||
def test_round_trip_with_all_fields(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(
|
||||
federated_connector_id=7,
|
||||
user_id="user-xyz",
|
||||
redirect_uri="https://example.com/callback",
|
||||
additional_data={"scope": "read"},
|
||||
)
|
||||
session = verify_oauth_state(state)
|
||||
|
||||
assert session.federated_connector_id == 7
|
||||
assert session.user_id == "user-xyz"
|
||||
assert session.redirect_uri == "https://example.com/callback"
|
||||
assert session.additional_data == {"scope": "read"}
|
||||
|
||||
|
||||
class TestOneTimeUse:
|
||||
def test_verify_deletes_state(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
verify_oauth_state(state)
|
||||
|
||||
with pytest.raises(ValueError, match="OAuth state not found"):
|
||||
verify_oauth_state(state)
|
||||
|
||||
|
||||
class TestTTLPropagation:
|
||||
def test_default_ttl(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
|
||||
assert len(cache.set_calls) == 1
|
||||
assert cache.set_calls[0]["ex"] == OAUTH_STATE_TTL
|
||||
|
||||
def test_custom_ttl(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
generate_oauth_state(federated_connector_id=1, user_id="u", ttl=600)
|
||||
|
||||
assert cache.set_calls[0]["ex"] == 600
|
||||
|
||||
|
||||
class TestVerifyInvalidState:
|
||||
def test_missing_state_raises(self) -> None:
|
||||
cache = _MemoryCacheBackend()
|
||||
with _patched(cache):
|
||||
state = generate_oauth_state(federated_connector_id=1, user_id="u")
|
||||
# Manually clear the cache to simulate expiration
|
||||
cache._store.clear()
|
||||
|
||||
with pytest.raises(ValueError, match="OAuth state not found"):
|
||||
verify_oauth_state(state)
|
||||
|
||||
|
||||
class TestOAuthSessionSerialization:
|
||||
def test_to_dict_from_dict_round_trip(self) -> None:
|
||||
session = OAuthSession(
|
||||
federated_connector_id=5,
|
||||
user_id="u-123",
|
||||
redirect_uri="https://redir.example.com",
|
||||
additional_data={"key": "val"},
|
||||
)
|
||||
d = session.to_dict()
|
||||
restored = OAuthSession.from_dict(d)
|
||||
|
||||
assert restored.federated_connector_id == 5
|
||||
assert restored.user_id == "u-123"
|
||||
assert restored.redirect_uri == "https://redir.example.com"
|
||||
assert restored.additional_data == {"key": "val"}
|
||||
|
||||
def test_from_dict_defaults(self) -> None:
|
||||
minimal = {"federated_connector_id": 1, "user_id": "u"}
|
||||
session = OAuthSession.from_dict(minimal)
|
||||
assert session.redirect_uri is None
|
||||
assert session.additional_data == {}
|
||||
@@ -1,11 +1,11 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.scim.auth import _hash_scim_token
|
||||
from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.auth import SCIM_TOKEN_PREFIX
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class TestVerifyScimToken:
|
||||
def test_missing_header_raises_401(self) -> None:
|
||||
request = self._make_request(None)
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing" in str(exc_info.value.detail)
|
||||
@@ -68,7 +68,7 @@ class TestVerifyScimToken:
|
||||
def test_wrong_prefix_raises_401(self) -> None:
|
||||
request = self._make_request("Bearer on_some_api_key")
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestVerifyScimToken:
|
||||
raw, _, _ = generate_scim_token()
|
||||
request = self._make_request(f"Bearer {raw}")
|
||||
dal = self._make_dal(token=None)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid" in str(exc_info.value.detail)
|
||||
@@ -87,7 +87,7 @@ class TestVerifyScimToken:
|
||||
mock_token = MagicMock()
|
||||
mock_token.is_active = False
|
||||
dal = self._make_dal(token=mock_token)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestOktaProvider:
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Madonna", familyName=None, formatted="Madonna"
|
||||
givenName="Madonna", familyName="", formatted="Madonna"
|
||||
)
|
||||
|
||||
def test_build_user_resource_no_name(self) -> None:
|
||||
@@ -117,7 +117,10 @@ class TestOktaProvider:
|
||||
user = _make_mock_user(personal_name=None)
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name is None
|
||||
# Falls back to deriving name from email local part
|
||||
assert result.name == ScimName(
|
||||
givenName="test", familyName="", formatted="test"
|
||||
)
|
||||
assert result.displayName is None
|
||||
|
||||
def test_build_user_resource_scim_username_preserves_case(self) -> None:
|
||||
|
||||
@@ -214,13 +214,17 @@ class TestCreateUser:
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_missing_external_id_returns_400(
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_missing_external_id_still_creates_mapping(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Mapping is always created to mark user as SCIM-managed."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
result = create_user(
|
||||
@@ -230,10 +234,14 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
parsed = parse_scim_user(result, status=201)
|
||||
assert parsed.userName is not None
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_duplicate_email_returns_409(
|
||||
def test_duplicate_scim_managed_email_returns_409(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
@@ -241,7 +249,12 @@ class TestCreateUser:
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = make_db_user()
|
||||
"""409 only when the existing user already has a SCIM mapping."""
|
||||
existing = make_db_user()
|
||||
mock_dal.get_user_by_email.return_value = existing
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = make_user_mapping(
|
||||
user_id=existing.id
|
||||
)
|
||||
resource = make_scim_user()
|
||||
|
||||
result = create_user(
|
||||
@@ -253,6 +266,40 @@ class TestCreateUser:
|
||||
|
||||
assert_scim_error(result, 409)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_existing_user_without_mapping_gets_linked(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Pre-existing user without SCIM mapping gets adopted (linked)."""
|
||||
existing = make_db_user(email="admin@example.com", personal_name=None)
|
||||
mock_dal.get_user_by_email.return_value = existing
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = None
|
||||
resource = make_scim_user(userName="admin@example.com", externalId="ext-admin")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_user(result, status=201)
|
||||
assert parsed.userName == "admin@example.com"
|
||||
# Should NOT create a new user — reuse existing
|
||||
mock_dal.add_user.assert_not_called()
|
||||
# Should sync is_active and personal_name from the SCIM request
|
||||
mock_dal.update_user.assert_called_once_with(
|
||||
existing, is_active=True, personal_name="Test User"
|
||||
)
|
||||
# Should create a SCIM mapping for the existing user
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_integrity_error_returns_409(
|
||||
self,
|
||||
|
||||
@@ -1,25 +1,37 @@
|
||||
"""Tests for PythonTool availability based on server_enabled flag.
|
||||
"""Tests for PythonTool availability based on server_enabled flag and health check.
|
||||
|
||||
Verifies that PythonTool reports itself as unavailable when either:
|
||||
- CODE_INTERPRETER_BASE_URL is not set, or
|
||||
- CodeInterpreterServer.server_enabled is False in the database.
|
||||
- CodeInterpreterServer.server_enabled is False in the database, or
|
||||
- The Code Interpreter service health check fails.
|
||||
|
||||
Also verifies that the health check result is cached with a TTL.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
|
||||
CLIENT_MODULE = "onyx.tools.tool_implementations.python.code_interpreter_client"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_health_cache() -> None:
|
||||
"""Reset the health check cache before every test."""
|
||||
import onyx.tools.tool_implementations.python.code_interpreter_client as mod
|
||||
|
||||
mod._health_cache = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
None,
|
||||
)
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", None)
|
||||
def test_python_tool_unavailable_without_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -27,10 +39,7 @@ def test_python_tool_unavailable_without_base_url() -> None:
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"",
|
||||
)
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "")
|
||||
def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
@@ -43,13 +52,8 @@ def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
@@ -64,18 +68,15 @@ def test_python_tool_unavailable_when_server_disabled(
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Available when both conditions are met
|
||||
# Health check determines availability when URL + server are OK
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_available_when_health_check_passes(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
@@ -84,5 +85,120 @@ def test_python_tool_available_when_server_enabled(
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = True
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_python_tool_unavailable_when_health_check_fails(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = False
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client.health.assert_called_once_with(use_cache=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check is NOT reached when preconditions fail
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://localhost:8000")
|
||||
@patch(f"{TOOL_MODULE}.fetch_code_interpreter_server")
|
||||
@patch(f"{TOOL_MODULE}.CodeInterpreterClient")
|
||||
def test_health_check_not_called_when_server_disabled(
|
||||
mock_client_cls: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = False
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
mock_client_cls.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health check caching (tested at the client level)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_health_check_cached_on_second_call() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health(use_cache=True) is True
|
||||
assert client.health(use_cache=True) is True
|
||||
# Only one HTTP call — the second used the cache
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
@patch(f"{CLIENT_MODULE}.time")
|
||||
def test_health_check_refreshed_after_ttl_expires(mock_time: MagicMock) -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
_HEALTH_CACHE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
# First call at t=0 — cache miss
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Second call within TTL — cache hit
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS - 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Third call after TTL — cache miss, fresh request
|
||||
mock_time.monotonic.return_value = float(_HEALTH_CACHE_TTL_SECONDS + 1)
|
||||
assert client.health(use_cache=True) is True
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
def test_health_check_no_cache_by_default() -> None:
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
client = CodeInterpreterClient(base_url="http://fake:9000")
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch.object(client.session, "get", return_value=mock_response) as mock_get:
|
||||
assert client.health() is True
|
||||
assert client.health() is True
|
||||
# Both calls hit the network when use_cache=False (default)
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@@ -126,7 +126,9 @@ Resources:
|
||||
- Effect: Allow
|
||||
Action:
|
||||
- secretsmanager:GetSecretValue
|
||||
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
Resource:
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
|
||||
|
||||
Outputs:
|
||||
OutputEcsCluster:
|
||||
|
||||
@@ -167,10 +167,12 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: disabled
|
||||
Value: basic
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
|
||||
@@ -166,9 +166,11 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: disabled
|
||||
Value: basic
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
@@ -19,6 +19,6 @@ dependencies:
|
||||
version: 5.4.0
|
||||
- name: code-interpreter
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
version: 0.3.0
|
||||
digest: sha256:cf8f01906d46034962c6ce894770621ee183ac761e6942951118aeb48540eddd
|
||||
generated: "2026-02-24T10:59:38.78318-08:00"
|
||||
version: 0.3.1
|
||||
digest: sha256:4965b6ea3674c37163832a2192cd3bc8004f2228729fca170af0b9f457e8f987
|
||||
generated: "2026-03-02T15:29:39.632344-08:00"
|
||||
|
||||
@@ -45,6 +45,6 @@ dependencies:
|
||||
repository: https://charts.min.io/
|
||||
condition: minio.enabled
|
||||
- name: code-interpreter
|
||||
version: 0.3.0
|
||||
version: 0.3.1
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
condition: codeInterpreter.enabled
|
||||
|
||||
72
greptile.json
Normal file
72
greptile.json
Normal file
@@ -0,0 +1,72 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 2,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "greptile.json\n",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"customContext": {
|
||||
"other": [
|
||||
{
|
||||
"scope": [],
|
||||
"content": ""
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"scope": [],
|
||||
"rule": ""
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"scope": [],
|
||||
"path": "",
|
||||
"description": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.7.4",
|
||||
"pypdf==6.7.5",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -4678,7 +4678,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.4" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.5" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -5925,11 +5925,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.7.4"
|
||||
version = "6.7.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821, upload-time = "2026-02-27T10:44:39.395Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496, upload-time = "2026-02-27T10:44:37.527Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, upload-time = "2026-03-02T09:05:19.722Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
20
web/lib/opal/src/icons/bookmark.tsx
Normal file
20
web/lib/opal/src/icons/bookmark.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgBookmark = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12.6667 14L7.99999 10.6667L3.33333 14V3.33333C3.33333 2.97971 3.4738 2.64057 3.72385 2.39052C3.9739 2.14048 4.31304 2 4.66666 2H11.3333C11.6869 2 12.0261 2.14048 12.2761 2.39052C12.5262 2.64057 12.6667 2.97971 12.6667 3.33333V14Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgBookmark;
|
||||
@@ -25,6 +25,7 @@ export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
export { default as SvgBlocks } from "@opal/icons/blocks";
|
||||
export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBookmark } from "@opal/icons/bookmark";
|
||||
export { default as SvgBooksLineSmall } from "@opal/icons/books-line-small";
|
||||
export { default as SvgBooksStackSmall } from "@opal/icons/books-stack-small";
|
||||
export { default as SvgBracketCurly } from "@opal/icons/bracket-curly";
|
||||
|
||||
@@ -78,6 +78,16 @@ const nextConfig = {
|
||||
},
|
||||
async rewrites() {
|
||||
return [
|
||||
{
|
||||
source: "/ph_ingest/static/:path*",
|
||||
destination: "https://us-assets.i.posthog.com/static/:path*",
|
||||
},
|
||||
{
|
||||
source: "/ph_ingest/:path*",
|
||||
destination: `${
|
||||
process.env.NEXT_PUBLIC_POSTHOG_HOST || "https://us.i.posthog.com"
|
||||
}/:path*`,
|
||||
},
|
||||
{
|
||||
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
|
||||
destination: `${
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
"test:verbose": "jest --verbose",
|
||||
"test:ci": "jest --ci --maxWorkers=2 --silent --bail",
|
||||
"test:changed": "jest --onlyChanged",
|
||||
"test:diff": "jest --changedSince=main",
|
||||
"test:debug": "node --inspect-brk node_modules/.bin/jest --runInBand"
|
||||
},
|
||||
"dependencies": {
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { SvgMcp } from "@opal/icons";
|
||||
import MCPPageContent from "@/sections/actions/MCPPageContent";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.MCP_ACTIONS]!;
|
||||
|
||||
export default function Main() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgMcp}
|
||||
title="MCP Actions"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your agents."
|
||||
separator
|
||||
/>
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { SvgActions } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import OpenApiPageContent from "@/sections/actions/OpenApiPageContent";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.OPENAPI_ACTIONS]!;
|
||||
|
||||
export default function Main() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgActions}
|
||||
title="OpenAPI Actions"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Connect OpenAPI servers to add custom actions and tools for your agents."
|
||||
separator
|
||||
/>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"use client";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SourceCategory, SourceMetadata } from "@/lib/search/interfaces";
|
||||
import { listSourceMetadata } from "@/lib/sources";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
@@ -32,7 +32,7 @@ import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import SourceTile from "@/components/SourceTile";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgUploadCloud } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
function SourceTileTooltipWrapper({
|
||||
sourceMetadata,
|
||||
preSelect,
|
||||
@@ -124,6 +124,7 @@ function SourceTileTooltipWrapper({
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.ADD_CONNECTOR]!;
|
||||
const sources = useMemo(() => listSourceMetadata(), []);
|
||||
|
||||
const [rawSearchTerm, setSearchTerm] = useState("");
|
||||
@@ -248,61 +249,37 @@ export default function Page() {
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
icon={SvgUploadCloud}
|
||||
title="Add Connector"
|
||||
farRightElement={
|
||||
<SettingsLayouts.Root width="full">
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
rightChildren={
|
||||
<Button href="/admin/indexing/status" primary>
|
||||
See Connectors
|
||||
</Button>
|
||||
}
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<InputTypeIn
|
||||
type="text"
|
||||
placeholder="Search Connectors"
|
||||
ref={searchInputRef}
|
||||
value={rawSearchTerm} // keep the input bound to immediate state
|
||||
onChange={(event) => setSearchTerm(event.target.value)}
|
||||
onKeyDown={handleKeyPress}
|
||||
className="w-96 flex-none"
|
||||
/>
|
||||
|
||||
<InputTypeIn
|
||||
type="text"
|
||||
placeholder="Search Connectors"
|
||||
ref={searchInputRef}
|
||||
value={rawSearchTerm} // keep the input bound to immediate state
|
||||
onChange={(event) => setSearchTerm(event.target.value)}
|
||||
onKeyDown={handleKeyPress}
|
||||
className="w-96 flex-none"
|
||||
/>
|
||||
|
||||
{dedupedPopular.length > 0 && (
|
||||
<div className="pt-8">
|
||||
<Text as="p" headingH3>
|
||||
Popular
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-4 p-4">
|
||||
{dedupedPopular.map((source) => (
|
||||
<SourceTileTooltipWrapper
|
||||
preSelect={false}
|
||||
key={source.internalName}
|
||||
sourceMetadata={source}
|
||||
federatedConnectors={federatedConnectors}
|
||||
slackCredentials={slackCredentials}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{Object.entries(categorizedSources)
|
||||
.filter(([_, sources]) => sources.length > 0)
|
||||
.map(([category, sources], categoryInd) => (
|
||||
<div key={category} className="pt-8">
|
||||
{dedupedPopular.length > 0 && (
|
||||
<div className="pt-8">
|
||||
<Text as="p" headingH3>
|
||||
{category}
|
||||
Popular
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-4 p-4">
|
||||
{sources.map((source, sourceInd) => (
|
||||
{dedupedPopular.map((source) => (
|
||||
<SourceTileTooltipWrapper
|
||||
preSelect={
|
||||
(searchTerm?.length ?? 0) > 0 &&
|
||||
categoryInd == 0 &&
|
||||
sourceInd == 0
|
||||
}
|
||||
preSelect={false}
|
||||
key={source.internalName}
|
||||
sourceMetadata={source}
|
||||
federatedConnectors={federatedConnectors}
|
||||
@@ -311,7 +288,33 @@ export default function Page() {
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
|
||||
{Object.entries(categorizedSources)
|
||||
.filter(([_, sources]) => sources.length > 0)
|
||||
.map(([category, sources], categoryInd) => (
|
||||
<div key={category} className="pt-8">
|
||||
<Text as="p" headingH3>
|
||||
{category}
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-4 p-4">
|
||||
{sources.map((source, sourceInd) => (
|
||||
<SourceTileTooltipWrapper
|
||||
preSelect={
|
||||
(searchTerm?.length ?? 0) > 0 &&
|
||||
categoryInd == 0 &&
|
||||
sourceInd == 0
|
||||
}
|
||||
key={source.internalName}
|
||||
sourceMetadata={source}
|
||||
federatedConnectors={federatedConnectors}
|
||||
slackCredentials={slackCredentials}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,14 +4,14 @@ import { PersonasTable } from "./PersonaTable";
|
||||
import Text from "@/components/ui/text";
|
||||
import Title from "@/components/ui/title";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { SubLabel } from "@/components/Field";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { useAdminPersonas } from "@/hooks/useAdminPersonas";
|
||||
import { Persona } from "./interfaces";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { SvgOnyxOctagon } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import { useState, useEffect } from "react";
|
||||
import Pagination from "@/refresh-components/Pagination";
|
||||
|
||||
@@ -120,6 +120,7 @@ function MainContent({
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.AGENTS]!;
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const { personas, totalItems, isLoading, error, refresh } = useAdminPersonas({
|
||||
pageNum: currentPage - 1, // Backend uses 0-indexed pages
|
||||
@@ -127,31 +128,33 @@ export default function Page() {
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgOnyxOctagon} title="Agents" />
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
|
||||
{isLoading && <ThreeDotsLoader />}
|
||||
<SettingsLayouts.Body>
|
||||
{isLoading && <ThreeDotsLoader />}
|
||||
|
||||
{error && (
|
||||
<ErrorCallout
|
||||
errorTitle="Failed to load agents"
|
||||
errorMsg={
|
||||
error?.info?.message ||
|
||||
error?.info?.detail ||
|
||||
"An unknown error occurred"
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{error && (
|
||||
<ErrorCallout
|
||||
errorTitle="Failed to load agents"
|
||||
errorMsg={
|
||||
error?.info?.message ||
|
||||
error?.info?.detail ||
|
||||
"An unknown error occurred"
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isLoading && !error && (
|
||||
<MainContent
|
||||
personas={personas}
|
||||
totalItems={totalItems}
|
||||
currentPage={currentPage}
|
||||
onPageChange={setCurrentPage}
|
||||
refreshPersonas={refresh}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
{!isLoading && !error && (
|
||||
<MainContent
|
||||
personas={personas}
|
||||
totalItems={totalItems}
|
||||
currentPage={currentPage}
|
||||
onPageChange={setCurrentPage}
|
||||
refreshPersonas={refresh}
|
||||
/>
|
||||
)}
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
@@ -32,6 +32,9 @@ import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.API_KEYS]!;
|
||||
|
||||
function Main() {
|
||||
const {
|
||||
@@ -233,10 +236,11 @@ function Main() {
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="API Keys" icon={SvgKey} />
|
||||
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,9 +4,8 @@ import CardSection from "@/components/admin/CardSection";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState } from "react";
|
||||
import { SlackTokensForm } from "./SlackTokensForm";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
|
||||
export const NewSlackBotForm = () => {
|
||||
const [formValues] = useState({
|
||||
@@ -19,20 +18,19 @@ export const NewSlackBotForm = () => {
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
<div>
|
||||
<AdminPageTitle
|
||||
icon={<SourceIcon iconSize={36} sourceType={ValidSources.Slack} />}
|
||||
title="New Slack Bot"
|
||||
/>
|
||||
<CardSection>
|
||||
<div className="p-4">
|
||||
<SlackTokensForm
|
||||
isUpdate={false}
|
||||
initialValues={formValues}
|
||||
router={router}
|
||||
/>
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={SvgSlack} title="New Slack Bot" separator />
|
||||
<SettingsLayouts.Body>
|
||||
<CardSection>
|
||||
<div className="p-4">
|
||||
<SlackTokensForm
|
||||
isUpdate={false}
|
||||
initialValues={formValues}
|
||||
router={router}
|
||||
/>
|
||||
</div>
|
||||
</CardSection>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import {
|
||||
DocumentSetSummary,
|
||||
SlackChannelConfig,
|
||||
ValidSources,
|
||||
} from "@/lib/types";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { DocumentSetSummary, SlackChannelConfig } from "@/lib/types";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
import { FetchAgentsResponse, fetchAgentsSS } from "@/lib/agentsSS";
|
||||
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
|
||||
@@ -77,27 +72,28 @@ async function EditslackChannelConfigPage(props: {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="max-w-4xl container">
|
||||
<SettingsLayouts.Root>
|
||||
<InstantSSRAutoRefresh />
|
||||
|
||||
<BackButton />
|
||||
<AdminPageTitle
|
||||
icon={<SourceIcon sourceType={ValidSources.Slack} iconSize={32} />}
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSlack}
|
||||
title={
|
||||
slackChannelConfig.is_default
|
||||
? "Edit Default Slack Config"
|
||||
: "Edit Slack Channel Config"
|
||||
}
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slackChannelConfig.slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={assistants}
|
||||
standardAnswerCategoryResponse={eeStandardAnswerCategoryResponse}
|
||||
existingSlackChannelConfig={slackChannelConfig}
|
||||
/>
|
||||
</div>
|
||||
<SettingsLayouts.Body>
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slackChannelConfig.slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={assistants}
|
||||
standardAnswerCategoryResponse={eeStandardAnswerCategoryResponse}
|
||||
existingSlackChannelConfig={slackChannelConfig}
|
||||
/>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { DocumentSetSummary, ValidSources } from "@/lib/types";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
import { fetchAgentsSS } from "@/lib/agentsSS";
|
||||
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { redirect } from "next/navigation";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
|
||||
async function NewChannelConfigPage(props: {
|
||||
params: Promise<{ "bot-id": string }>;
|
||||
@@ -50,20 +49,22 @@ async function NewChannelConfigPage(props: {
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<BackButton />
|
||||
<AdminPageTitle
|
||||
icon={<SourceIcon iconSize={32} sourceType={ValidSources.Slack} />}
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSlack}
|
||||
title="Configure OnyxBot for Slack Channel"
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={agentsResponse[0]}
|
||||
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
|
||||
/>
|
||||
</>
|
||||
<SettingsLayouts.Body>
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={agentsResponse[0]}
|
||||
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
|
||||
/>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,15 +3,14 @@
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { SlackBotTable } from "./SlackBotTable";
|
||||
import { useSlackBots } from "./[bot-id]/hooks";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
|
||||
const Main = () => {
|
||||
function Main() {
|
||||
const {
|
||||
data: slackBots,
|
||||
isLoading: isSlackBotsLoading,
|
||||
@@ -73,20 +72,18 @@ const Main = () => {
|
||||
<SlackBotTable slackBots={slackBots} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.SLACK_BOTS]!;
|
||||
|
||||
const Page = () => {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
icon={<SourceIcon iconSize={36} sourceType={ValidSources.Slack} />}
|
||||
title="Slack Bots"
|
||||
/>
|
||||
<InstantSSRAutoRefresh />
|
||||
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<InstantSSRAutoRefresh />
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
}
|
||||
|
||||
@@ -4,13 +4,15 @@ import { useState } from "react";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import { DocumentIcon2 } from "@/components/icons/icons";
|
||||
import useSWR from "swr";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgLock } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_PROCESSING]!;
|
||||
|
||||
function Main() {
|
||||
const {
|
||||
@@ -149,12 +151,11 @@ function Main() {
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Document Processing"
|
||||
icon={<DocumentIcon2 size={32} className="my-auto" />}
|
||||
/>
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { SvgImage } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import ImageGenerationContent from "./ImageGenerationContent";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.IMAGE_GENERATION]!;
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgImage}
|
||||
title="Image Generation"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Settings for in-chat image generation."
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -187,6 +187,7 @@ export const fetchOllamaModels = async (
|
||||
api_base: apiBase,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/components/ui/text";
|
||||
import Title from "@/components/ui/title";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
@@ -19,7 +19,10 @@ import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { useToastFromQuery } from "@/hooks/useToast";
|
||||
import { SvgSearch } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.SEARCH_SETTINGS]!;
|
||||
|
||||
export interface EmbeddingDetails {
|
||||
api_key: string;
|
||||
custom_config: any;
|
||||
@@ -141,9 +144,11 @@ function Main() {
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="Search Settings" icon={SvgSearch} />
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -22,7 +22,10 @@ import {
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.WEB_SEARCH]!;
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
SEARCH_PROVIDER_DETAILS,
|
||||
@@ -403,8 +406,8 @@ export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
title="Web Search"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
@@ -426,8 +429,8 @@ export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
title="Web Search"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
@@ -832,8 +835,8 @@ export default function Page() {
|
||||
<>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
title="Web Search"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
|
||||
@@ -20,6 +20,7 @@ import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/
|
||||
import { Credential } from "@/lib/connectors/credentials";
|
||||
import { useFederatedConnectors } from "@/lib/hooks";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useToastFromQuery } from "@/hooks/useToast";
|
||||
|
||||
export default function ConnectorWrapper({
|
||||
connector,
|
||||
@@ -29,6 +30,13 @@ export default function ConnectorWrapper({
|
||||
const searchParams = useSearchParams();
|
||||
const mode = searchParams?.get("mode"); // 'federated' or 'regular'
|
||||
|
||||
useToastFromQuery({
|
||||
oauth_failed: {
|
||||
message: "OAuth authentication failed. Please try again.",
|
||||
type: "error",
|
||||
},
|
||||
});
|
||||
|
||||
// Check if the connector is valid
|
||||
if (!isValidSource(connector)) {
|
||||
return (
|
||||
|
||||
@@ -2,10 +2,6 @@ import { getDomain } from "@/lib/redirectSS";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
import { cookies } from "next/headers";
|
||||
import {
|
||||
GMAIL_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
} from "@/lib/constants";
|
||||
import {
|
||||
CRAFT_OAUTH_COOKIE_NAME,
|
||||
CRAFT_CONFIGURE_PATH,
|
||||
@@ -15,6 +11,7 @@ import { processCookies } from "@/lib/userSS";
|
||||
export const GET = async (request: NextRequest) => {
|
||||
const requestCookies = await cookies();
|
||||
const connector = request.url.includes("gmail") ? "gmail" : "google-drive";
|
||||
|
||||
const callbackEndpoint = `/manage/connector/${connector}/callback`;
|
||||
const url = new URL(buildUrl(callbackEndpoint));
|
||||
url.search = request.nextUrl.search;
|
||||
@@ -26,7 +23,12 @@ export const GET = async (request: NextRequest) => {
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
|
||||
return NextResponse.redirect(
|
||||
new URL(
|
||||
`/admin/connectors/${connector}?message=oauth_failed`,
|
||||
getDomain(request)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check for build mode OAuth flag (redirects to build admin panel)
|
||||
@@ -40,16 +42,7 @@ export const GET = async (request: NextRequest) => {
|
||||
return redirectResponse;
|
||||
}
|
||||
|
||||
const authCookieName =
|
||||
connector === "gmail"
|
||||
? GMAIL_AUTH_IS_ADMIN_COOKIE_NAME
|
||||
: GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME;
|
||||
|
||||
if (requestCookies.get(authCookieName)?.value?.toLowerCase() === "true") {
|
||||
return NextResponse.redirect(
|
||||
new URL(`/admin/connectors/${connector}`, getDomain(request))
|
||||
);
|
||||
}
|
||||
|
||||
return NextResponse.redirect(new URL("/user/connectors", getDomain(request)));
|
||||
return NextResponse.redirect(
|
||||
new URL(`/admin/connectors/${connector}`, getDomain(request))
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { getSourceMetadata, isValidSource } from "@/lib/sources";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
@@ -9,7 +9,6 @@ import CardSection from "@/components/admin/CardSection";
|
||||
import { handleOAuthAuthorizationResponse } from "@/lib/oauth_utils";
|
||||
import { SvgKey } from "@opal/icons";
|
||||
export default function OAuthCallbackPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
const [statusMessage, setStatusMessage] = useState("Processing...");
|
||||
|
||||
@@ -6,11 +6,7 @@ import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
|
||||
import {
|
||||
DOCS_ADMINS_PATH,
|
||||
GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
} from "@/lib/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
import { TextFormField, SectionHeader } from "@/components/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
@@ -592,11 +588,6 @@ export const DriveAuthSection = ({
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
// cookie used by callback to determine where to finally redirect to
|
||||
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
|
||||
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
|
||||
isAdmin: true,
|
||||
name: "OAuth (uploaded)",
|
||||
|
||||
@@ -7,10 +7,7 @@ import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGmailOAuth } from "@/lib/gmail";
|
||||
import {
|
||||
DOCS_ADMINS_PATH,
|
||||
GMAIL_AUTH_IS_ADMIN_COOKIE_NAME,
|
||||
} from "@/lib/constants";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
import { CRAFT_OAUTH_COOKIE_NAME } from "@/app/craft/v1/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import { TextFormField, SectionHeader } from "@/components/Field";
|
||||
@@ -602,9 +599,6 @@ export const GmailAuthSection = ({
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
Cookies.set(GMAIL_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
if (buildMode) {
|
||||
Cookies.set(CRAFT_OAUTH_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { FiDownload } from "react-icons/fi";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import {
|
||||
Table,
|
||||
@@ -17,6 +16,10 @@ import { Card } from "@/components/ui/card";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { SvgDownloadCloud } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DEBUG]!;
|
||||
|
||||
function Main() {
|
||||
const [categories, setCategories] = useState<string[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
@@ -114,13 +117,13 @@ function Main() {
|
||||
);
|
||||
}
|
||||
|
||||
const Page = () => {
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={<FiDownload size={32} />} title="Debug Logs" />
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import {
|
||||
import { createGuildConfig } from "@/app/admin/discord-bot/lib";
|
||||
import { DiscordGuildsTable } from "@/app/admin/discord-bot/DiscordGuildsTable";
|
||||
import { BotConfigCard } from "@/app/admin/discord-bot/BotConfigCard";
|
||||
import { SvgDiscordMono } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
function DiscordBotContent() {
|
||||
const { data: guilds, isLoading, error, refreshGuilds } = useDiscordGuilds();
|
||||
@@ -118,11 +118,13 @@ function DiscordBotContent() {
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DISCORD_BOTS]!;
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgDiscordMono}
|
||||
title="Discord Bots"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Connect Onyx to your Discord servers. Users can ask questions directly in Discord channels."
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
|
||||
import { useState } from "react";
|
||||
import useSWR from "swr";
|
||||
import { SvgArrowExchange } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.INDEX_MIGRATION]!;
|
||||
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Content, ContentAction } from "@opal/layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
@@ -213,8 +216,8 @@ export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgArrowExchange}
|
||||
title="Document Index Migration"
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Monitor the migration from Vespa to OpenSearch and control the active retrieval source."
|
||||
separator
|
||||
/>
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"use client";
|
||||
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import { Explorer } from "./Explorer";
|
||||
import { Connector } from "@/lib/connectors/connectors";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
|
||||
interface DocumentExplorerPageProps {
|
||||
initialSearchValue: string | undefined;
|
||||
connectors: Connector<any>[];
|
||||
documentSets: DocumentSetSummary[];
|
||||
}
|
||||
|
||||
export default function DocumentExplorerPage({
|
||||
initialSearchValue,
|
||||
connectors,
|
||||
documentSets,
|
||||
}: DocumentExplorerPageProps) {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_EXPLORER]!;
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<Explorer
|
||||
initialSearchValue={initialSearchValue}
|
||||
connectors={connectors}
|
||||
documentSets={documentSets}
|
||||
/>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { Explorer } from "./Explorer";
|
||||
import { fetchValidFilterInfo } from "@/lib/search/utilsSS";
|
||||
import { SvgZoomIn } from "@opal/icons";
|
||||
import DocumentExplorerPage from "./DocumentExplorerPage";
|
||||
|
||||
export default async function Page(props: {
|
||||
searchParams: Promise<{ [key: string]: string }>;
|
||||
}) {
|
||||
@@ -9,17 +8,10 @@ export default async function Page(props: {
|
||||
const { connectors, documentSets } = await fetchValidFilterInfo();
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
icon={<SvgZoomIn className="stroke-text-04 h-8 w-8" />}
|
||||
title="Document Explorer"
|
||||
/>
|
||||
|
||||
<Explorer
|
||||
initialSearchValue={searchParams.query}
|
||||
connectors={connectors}
|
||||
documentSets={documentSets}
|
||||
/>
|
||||
</>
|
||||
<DocumentExplorerPage
|
||||
initialSearchValue={searchParams.query}
|
||||
connectors={connectors}
|
||||
documentSets={documentSets}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,10 +4,11 @@ import { LoadingAnimation } from "@/components/Loading";
|
||||
import { useMostReactedToDocuments } from "@/lib/hooks";
|
||||
import { DocumentFeedbackTable } from "./DocumentFeedbackTable";
|
||||
import { numPages, numToDisplay } from "./constants";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import Title from "@/components/ui/title";
|
||||
import { SvgThumbsUp } from "@opal/icons";
|
||||
const Main = () => {
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
function Main() {
|
||||
const {
|
||||
data: mostLikedDocuments,
|
||||
isLoading: isMostLikedDocumentsLoading,
|
||||
@@ -57,16 +58,17 @@ const Main = () => {
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_FEEDBACK]!;
|
||||
|
||||
const Page = () => {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgThumbsUp} title="Document Feedback" />
|
||||
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
}
|
||||
|
||||
@@ -5,9 +5,8 @@ import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { refreshDocumentSets, useDocumentSets } from "../hooks";
|
||||
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { BookmarkIcon } from "@/components/icons/icons";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { DocumentSetCreationForm } from "../DocumentSetCreationForm";
|
||||
import { useRouter } from "next/navigation";
|
||||
@@ -69,24 +68,17 @@ function Main({ documentSetId }: { documentSetId: number }) {
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<AdminPageTitle
|
||||
icon={<BookmarkIcon size={32} />}
|
||||
title={documentSet.name}
|
||||
<CardSection>
|
||||
<DocumentSetCreationForm
|
||||
ccPairs={ccPairs}
|
||||
userGroups={userGroups}
|
||||
onClose={() => {
|
||||
refreshDocumentSets();
|
||||
router.push("/admin/documents/sets");
|
||||
}}
|
||||
existingDocumentSet={documentSet}
|
||||
/>
|
||||
|
||||
<CardSection>
|
||||
<DocumentSetCreationForm
|
||||
ccPairs={ccPairs}
|
||||
userGroups={userGroups}
|
||||
onClose={() => {
|
||||
refreshDocumentSets();
|
||||
router.push("/admin/documents/sets");
|
||||
}}
|
||||
existingDocumentSet={documentSet}
|
||||
/>
|
||||
</CardSection>
|
||||
</div>
|
||||
</CardSection>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -95,12 +87,19 @@ export default function Page(props: {
|
||||
}) {
|
||||
const params = use(props.params);
|
||||
const documentSetId = parseInt(params.documentSetId);
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_SETS]!;
|
||||
|
||||
return (
|
||||
<>
|
||||
<BackButton />
|
||||
|
||||
<Main documentSetId={documentSetId} />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title="Edit Document Set"
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<Main documentSetId={documentSetId} />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { BookmarkIcon } from "@/components/icons/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import { DocumentSetCreationForm } from "../DocumentSetCreationForm";
|
||||
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { refreshDocumentSets } from "../hooks";
|
||||
@@ -56,19 +55,20 @@ function Main() {
|
||||
);
|
||||
}
|
||||
|
||||
const Page = () => {
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_SETS]!;
|
||||
|
||||
return (
|
||||
<>
|
||||
<BackButton />
|
||||
|
||||
<AdminPageTitle
|
||||
icon={<BookmarkIcon size={32} />}
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title="New Document Set"
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
import { BookmarkIcon, InfoIcon } from "@/components/icons/icons";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import {
|
||||
Table,
|
||||
TableHead,
|
||||
@@ -19,7 +19,8 @@ import { useDocumentSets } from "./hooks";
|
||||
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
|
||||
import { deleteDocumentSet } from "./lib";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import {
|
||||
FiAlertTriangle,
|
||||
FiCheckCircle,
|
||||
@@ -358,7 +359,7 @@ const DocumentSetTable = ({
|
||||
);
|
||||
};
|
||||
|
||||
const Main = () => {
|
||||
function Main() {
|
||||
const {
|
||||
data: documentSets,
|
||||
isLoading: isDocumentSetsLoading,
|
||||
@@ -418,16 +419,17 @@ const Main = () => {
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_SETS]!;
|
||||
|
||||
const Page = () => {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={<BookmarkIcon size={32} />} title="Document Sets" />
|
||||
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ function ConnectorRow({
|
||||
onClick={handleRowClick}
|
||||
>
|
||||
<TableCell className="">
|
||||
<p className="lg:w-[200px] xl:w-[400px] inline-block ellipsis truncate">
|
||||
<p className="max-w-[200px] xl:max-w-[400px] inline-block ellipsis truncate">
|
||||
{ccPairsIndexingStatus.name}
|
||||
</p>
|
||||
</TableCell>
|
||||
@@ -246,7 +246,7 @@ function FederatedConnectorRow({
|
||||
onClick={handleRowClick}
|
||||
>
|
||||
<TableCell className="">
|
||||
<p className="lg:w-[200px] xl:w-[400px] inline-block ellipsis truncate">
|
||||
<p className="max-w-[200px] xl:max-w-[400px] inline-block ellipsis truncate">
|
||||
{federatedConnector.name}
|
||||
</p>
|
||||
</TableCell>
|
||||
@@ -293,7 +293,7 @@ export function CCPairIndexingStatusTable({
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
return (
|
||||
<Table className="-mt-8">
|
||||
<Table className="-mt-8 table-fixed">
|
||||
<TableHeader>
|
||||
<ConnectorRow
|
||||
invisible
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { NotebookIcon } from "@/components/icons/icons";
|
||||
import { CCPairIndexingStatusTable } from "./CCPairIndexingStatusTable";
|
||||
import { SearchAndFilterControls } from "./SearchAndFilterControls";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Link from "next/link";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
import Text from "@/components/ui/text";
|
||||
import { useConnectorIndexingStatusWithPagination } from "@/lib/hooks";
|
||||
import { useToastFromQuery } from "@/hooks/useToast";
|
||||
@@ -201,6 +201,8 @@ function Main() {
|
||||
}
|
||||
|
||||
export default function Status() {
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.INDEXING_STATUS]!;
|
||||
|
||||
useToastFromQuery({
|
||||
"connector-created": {
|
||||
message: "Connector created successfully",
|
||||
@@ -213,16 +215,18 @@ export default function Status() {
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
icon={<NotebookIcon size={32} />}
|
||||
title="Existing Connectors"
|
||||
farRightElement={
|
||||
<SettingsLayouts.Root width="full">
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
rightChildren={
|
||||
<Button href="/admin/add-connector">Add Connector</Button>
|
||||
}
|
||||
separator
|
||||
/>
|
||||
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import {
|
||||
DatePickerField,
|
||||
FieldLabel,
|
||||
TextArrayField,
|
||||
TextFormField,
|
||||
} from "@/components/Field";
|
||||
import { BrainIcon } from "@/components/icons/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import SwitchField from "@/refresh-components/form/SwitchField";
|
||||
@@ -31,6 +30,9 @@ import KGEntityTypes from "@/app/admin/kg/KGEntityTypes";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.KNOWLEDGE_GRAPH]!;
|
||||
|
||||
function createDomainField(
|
||||
name: string,
|
||||
@@ -324,12 +326,11 @@ export default function Page() {
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Knowledge Graph"
|
||||
icon={<BrainIcon size={32} className="my-auto" />}
|
||||
/>
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import SimpleTabs from "@/refresh-components/SimpleTabs";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/components/ui/text";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
@@ -16,8 +16,11 @@ import { toast } from "@/hooks/useToast";
|
||||
import CreateRateLimitModal from "./CreateRateLimitModal";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { SvgGlobe, SvgShield, SvgUser, SvgUsers } from "@opal/icons";
|
||||
import { SvgGlobe, SvgUser, SvgUsers } from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.TOKEN_RATE_LIMITS]!;
|
||||
const BASE_URL = "/api/admin/token-rate-limits";
|
||||
const GLOBAL_TOKEN_FETCH_URL = `${BASE_URL}/global`;
|
||||
const USER_TOKEN_FETCH_URL = `${BASE_URL}/users`;
|
||||
@@ -208,9 +211,11 @@ function Main() {
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="Token Rate Limits" icon={SvgShield} />
|
||||
<Main />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,11 +5,10 @@ import SimpleTabs from "@/refresh-components/SimpleTabs";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import InvitedUserTable from "@/components/admin/users/InvitedUserTable";
|
||||
import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable";
|
||||
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
@@ -22,7 +21,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { SvgDownloadCloud, SvgUser, SvgUserPlus } from "@opal/icons";
|
||||
import { SvgDownloadCloud, SvgUserPlus } from "@opal/icons";
|
||||
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.USERS]!;
|
||||
|
||||
interface CountDisplayProps {
|
||||
label: string;
|
||||
value: number | null;
|
||||
@@ -48,7 +51,7 @@ function CountDisplay({ label, value, isLoading }: CountDisplayProps) {
|
||||
);
|
||||
}
|
||||
|
||||
const UsersTables = ({
|
||||
function UsersTables({
|
||||
q,
|
||||
isDownloadingUsers,
|
||||
setIsDownloadingUsers,
|
||||
@@ -56,7 +59,7 @@ const UsersTables = ({
|
||||
q: string;
|
||||
isDownloadingUsers: boolean;
|
||||
setIsDownloadingUsers: (loading: boolean) => void;
|
||||
}) => {
|
||||
}) {
|
||||
const [currentUsersCount, setCurrentUsersCount] = useState<number | null>(
|
||||
null
|
||||
);
|
||||
@@ -236,9 +239,9 @@ const UsersTables = ({
|
||||
});
|
||||
|
||||
return <SimpleTabs tabs={tabs} defaultValue="current" />;
|
||||
};
|
||||
}
|
||||
|
||||
const SearchableTables = () => {
|
||||
function SearchableTables() {
|
||||
const [query, setQuery] = useState("");
|
||||
const [isDownloadingUsers, setIsDownloadingUsers] = useState(false);
|
||||
|
||||
@@ -262,7 +265,7 @@ const SearchableTables = () => {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
function AddUserButton() {
|
||||
const [bulkAddUsersModal, setBulkAddUsersModal] = useState(false);
|
||||
@@ -325,13 +328,13 @@ function AddUserButton() {
|
||||
);
|
||||
}
|
||||
|
||||
const Page = () => {
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="Manage Users" icon={SvgUser} />
|
||||
<SearchableTables />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<SearchableTables />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { FiDownload } from "react-icons/fi";
|
||||
import { memo, useState } from "react";
|
||||
import { SvgDownload } from "@opal/icons";
|
||||
import { ImageShape } from "@/app/app/services/streamingModels";
|
||||
import { FullImageModal } from "@/app/app/components/files/images/FullImageModal";
|
||||
import { buildImgUrl } from "@/app/app/components/files/images/utils";
|
||||
@@ -24,17 +24,22 @@ const SHAPE_CLASSES: Record<ImageShape, { container: string; image: string }> =
|
||||
},
|
||||
};
|
||||
|
||||
// Used to stop image flashing as images are loaded and response continues
|
||||
const loadedImages = new Set<string>();
|
||||
|
||||
interface InMessageImageProps {
|
||||
fileId: string;
|
||||
fileName?: string;
|
||||
shape?: ImageShape;
|
||||
}
|
||||
|
||||
export function InMessageImage({
|
||||
export const InMessageImage = memo(function InMessageImage({
|
||||
fileId,
|
||||
fileName,
|
||||
shape = DEFAULT_SHAPE,
|
||||
}: InMessageImageProps) {
|
||||
const [fullImageShowing, setFullImageShowing] = useState(false);
|
||||
const [imageLoaded, setImageLoaded] = useState(false);
|
||||
const [imageLoaded, setImageLoaded] = useState(loadedImages.has(fileId));
|
||||
|
||||
const normalizedShape = SHAPE_CLASSES[shape] ? shape : DEFAULT_SHAPE;
|
||||
const { container: shapeContainerClasses, image: shapeImageClasses } =
|
||||
@@ -45,11 +50,15 @@ export function InMessageImage({
|
||||
|
||||
try {
|
||||
const response = await fetch(buildImgUrl(fileId));
|
||||
if (!response.ok) {
|
||||
console.error("Failed to download image:", response.status);
|
||||
return;
|
||||
}
|
||||
const blob = await response.blob();
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `image-${fileId}.png`; // You can adjust the filename/extension as needed
|
||||
a.download = fileName || `image-${fileId}.png`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
window.URL.revokeObjectURL(url);
|
||||
@@ -76,7 +85,10 @@ export function InMessageImage({
|
||||
width={1200}
|
||||
height={1200}
|
||||
alt="Chat Message Image"
|
||||
onLoad={() => setImageLoaded(true)}
|
||||
onLoad={() => {
|
||||
loadedImages.add(fileId);
|
||||
setImageLoaded(true);
|
||||
}}
|
||||
className={cn(
|
||||
"object-contain object-left overflow-hidden rounded-lg w-full h-full transition-opacity duration-300 cursor-pointer",
|
||||
shapeImageClasses,
|
||||
@@ -94,7 +106,7 @@ export function InMessageImage({
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
icon={FiDownload}
|
||||
icon={SvgDownload}
|
||||
tooltip="Download"
|
||||
onClick={handleDownload}
|
||||
/>
|
||||
@@ -102,4 +114,4 @@ export function InMessageImage({
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,3 +1,21 @@
|
||||
const CHAT_FILE_URL_REGEX = /\/api\/chat\/file\/([^/?#]+)/;
|
||||
const IMAGE_EXTENSIONS = /\.(png|jpe?g|gif|webp|svg|bmp|ico|tiff?)$/i;
|
||||
|
||||
export function buildImgUrl(fileId: string) {
|
||||
return `/api/chat/file/${fileId}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* If `href` points to a chat file and `linkText` ends with an image extension,
|
||||
* returns the file ID. Otherwise returns null.
|
||||
*/
|
||||
export function extractChatImageFileId(
|
||||
href: string | undefined,
|
||||
linkText: string
|
||||
): string | null {
|
||||
if (!href) return null;
|
||||
const match = CHAT_FILE_URL_REGEX.exec(href);
|
||||
if (!match?.[1]) return null;
|
||||
if (!IMAGE_EXTENSIONS.test(linkText)) return null;
|
||||
return match[1];
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ const HumanMessage = React.memo(function HumanMessage({
|
||||
<div className="md:max-w-[37.5rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
|
||||
<div
|
||||
className={
|
||||
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
|
||||
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces break-anywhere rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
|
||||
}
|
||||
onCopy={(e) => {
|
||||
const selection = window.getSelection();
|
||||
|
||||
@@ -14,6 +14,8 @@ import {
|
||||
import { extractCodeText, preprocessLaTeX } from "@/app/app/message/codeUtils";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { transformLinkUri, cn } from "@/lib/utils";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
|
||||
/**
|
||||
* Processes content for markdown rendering by handling code blocks and LaTeX
|
||||
@@ -58,17 +60,31 @@ export const useMarkdownComponents = (
|
||||
);
|
||||
|
||||
const anchorCallback = useCallback(
|
||||
(props: any) => (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={state?.setPresentingDocument || (() => {})}
|
||||
docs={state?.docs || []}
|
||||
userFiles={state?.userFiles || []}
|
||||
citations={state?.citations}
|
||||
href={props.href}
|
||||
>
|
||||
{props.children}
|
||||
</MemoizedAnchor>
|
||||
),
|
||||
(props: any) => {
|
||||
const imageFileId = extractChatImageFileId(
|
||||
props.href,
|
||||
String(props.children ?? "")
|
||||
);
|
||||
if (imageFileId) {
|
||||
return (
|
||||
<InMessageImage
|
||||
fileId={imageFileId}
|
||||
fileName={String(props.children ?? "")}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={state?.setPresentingDocument || (() => {})}
|
||||
docs={state?.docs || []}
|
||||
userFiles={state?.userFiles || []}
|
||||
citations={state?.citations}
|
||||
href={props.href}
|
||||
>
|
||||
{props.children}
|
||||
</MemoizedAnchor>
|
||||
);
|
||||
},
|
||||
[
|
||||
state?.docs,
|
||||
state?.userFiles,
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
:root {
|
||||
--app-page-main-content-width: 52.5rem;
|
||||
--block-width-form-input-min: 10rem;
|
||||
|
||||
--container-sm: 42rem;
|
||||
--container-sm-md: 47rem;
|
||||
--container-md: 54.5rem;
|
||||
--container-lg: 62rem;
|
||||
--container-full: 100%;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import BillingInformationPage from "./BillingInformationPage";
|
||||
import { MdOutlineCreditCard } from "react-icons/md";
|
||||
import { SvgCreditCard } from "@opal/icons";
|
||||
|
||||
export interface BillingInformation {
|
||||
stripe_subscription_id: string;
|
||||
@@ -18,12 +18,15 @@ export interface BillingInformation {
|
||||
|
||||
export default function page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgCreditCard}
|
||||
title="Billing Information"
|
||||
icon={<MdOutlineCreditCard size={32} className="my-auto" />}
|
||||
separator
|
||||
/>
|
||||
<BillingInformationPage />
|
||||
</>
|
||||
<SettingsLayouts.Body>
|
||||
<BillingInformationPage />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user