Compare commits

..

1 Commits

Author SHA1 Message Date
Nik
48e7428069 chore(helm): remove broken code-interpreter dependency
The code-interpreter Helm chart repo at
https://onyx-dot-app.github.io/code-interpreter/ returns 404,
causing ct lint to fail in CI. Remove it from Chart.yaml
dependencies, Chart.lock, ct.yaml chart-repos, and the CI
workflow's helm repo add step.
2026-02-19 20:17:14 -08:00
121 changed files with 1515 additions and 3606 deletions

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -91,7 +91,6 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Install Redis operator

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,13 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -47,27 +43,14 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -1671,10 +1671,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -13,7 +13,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -373,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -637,14 +637,6 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

View File

@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -454,9 +443,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -83,11 +83,7 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
ASPX_EXTENSION = ".aspx"
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
GRAPH_API_BASE = f"{DEFAULT_GRAPH_API_HOST}/v1.0"
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
GRAPH_API_MAX_RETRIES = 5
GRAPH_API_RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504})
@@ -180,25 +176,6 @@ class CertificateData(BaseModel):
thumbprint: str
def _site_page_in_time_window(
page: dict[str, Any],
start: datetime | None,
end: datetime | None,
) -> bool:
"""Return True if the page's lastModifiedDateTime falls within [start, end]."""
if start is None and end is None:
return True
raw = page.get("lastModifiedDateTime")
if not raw:
return True
if not isinstance(raw, str):
raise ValueError(f"lastModifiedDateTime is not a string: {raw}")
last_modified = datetime.fromisoformat(raw.replace("Z", "+00:00"))
return (start is None or last_modified >= start) and (
end is None or last_modified <= end
)
def sleep_and_retry(
query_obj: ClientQuery, method_name: str, max_retries: int = 3
) -> Any:
@@ -289,12 +266,10 @@ def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData
def acquire_token_for_rest(
msal_app: msal.ConfidentialClientApplication,
sp_tenant_domain: str,
sharepoint_domain_suffix: str,
msal_app: msal.ConfidentialClientApplication, sp_tenant_domain: str
) -> TokenResponse:
token = msal_app.acquire_token_for_client(
scopes=[f"https://{sp_tenant_domain}.{sharepoint_domain_suffix}/.default"]
scopes=[f"https://{sp_tenant_domain}.sharepoint.com/.default"]
)
return TokenResponse.from_json(token)
@@ -409,13 +384,12 @@ def _download_via_graph_api(
drive_id: str,
item_id: str,
bytes_allowed: int,
graph_api_base: str,
) -> bytes:
"""Download a drive item via the Graph API /content endpoint with a byte cap.
Raises SizeCapExceeded if the cap is exceeded.
"""
url = f"{graph_api_base}/drives/{drive_id}/items/{item_id}/content"
url = f"{GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}/content"
headers = {"Authorization": f"Bearer {access_token}"}
with requests.get(
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
@@ -436,7 +410,6 @@ def _convert_driveitem_to_document_with_permissions(
drive_name: str,
ctx: ClientContext | None,
graph_client: GraphClient,
graph_api_base: str,
include_permissions: bool = False,
parent_hierarchy_raw_node_id: str | None = None,
access_token: str | None = None,
@@ -493,7 +466,6 @@ def _convert_driveitem_to_document_with_permissions(
driveitem.drive_id,
driveitem.id,
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
graph_api_base=graph_api_base,
)
except SizeCapExceeded:
logger.warning(
@@ -813,9 +785,6 @@ class SharepointConnector(
sites: list[str] = [],
include_site_pages: bool = True,
include_site_documents: bool = True,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
) -> None:
self.batch_size = batch_size
self.sites = list(sites)
@@ -831,10 +800,6 @@ class SharepointConnector(
self._cached_rest_ctx: ClientContext | None = None
self._cached_rest_ctx_url: str | None = None
self._cached_rest_ctx_created_at: float = 0.0
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
self.graph_api_base = f"{self.graph_api_host}/v1.0"
self.sharepoint_domain_suffix = sharepoint_domain_suffix
def validate_connector_settings(self) -> None:
# Validate that at least one content type is enabled
@@ -891,9 +856,8 @@ class SharepointConnector(
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
sp_domain_suffix = self.sharepoint_domain_suffix
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
self._cached_rest_ctx_url = site_url
self._cached_rest_ctx_created_at = time.monotonic()
@@ -1153,36 +1117,76 @@ class SharepointConnector(
site_descriptor: SiteDescriptor,
start: datetime | None = None,
end: datetime | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Yield SharePoint site pages (.aspx files) one at a time.
) -> list[dict[str, Any]]:
"""Fetch SharePoint site pages (.aspx files) using the SharePoint Pages API."""
Pages are fetched via the Graph Pages API and yielded lazily as each
API page arrives, so memory stays bounded regardless of total page count.
Time-window filtering is applied per-item before yielding.
"""
# Get the site to extract the site ID
site = self.graph_client.sites.get_by_url(site_descriptor.url)
site.execute_query()
site.execute_query() # Execute the query to actually fetch the data
site_id = site.id
page_url: str | None = (
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
# Get the token acquisition function from the GraphClient
token_data = self._acquire_token()
access_token = token_data.get("access_token")
if not access_token:
raise RuntimeError("Failed to acquire access token")
# Construct the SharePoint Pages API endpoint
# Using API directly, since the Graph Client doesn't support the Pages API
pages_endpoint = f"https://graph.microsoft.com/v1.0/sites/{site_id}/pages/microsoft.graph.sitePage"
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
# Add expand parameter to get canvas layout content
params = {"$expand": "canvasLayout"}
response = requests.get(
pages_endpoint,
headers=headers,
params=params,
timeout=REQUEST_TIMEOUT_SECONDS,
)
params: dict[str, str] | None = {"$expand": "canvasLayout"}
total_yielded = 0
response.raise_for_status()
pages_data = response.json()
all_pages = pages_data.get("value", [])
while page_url:
data = self._graph_api_get_json(page_url, params)
params = None # nextLink already embeds query params
# Handle pagination if there are more pages
# TODO: This accumulates all pages in memory and can be heavy on large tenants.
# We should process each page incrementally to avoid unbounded growth.
while "@odata.nextLink" in pages_data:
next_url = pages_data["@odata.nextLink"]
response = requests.get(
next_url, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS
)
response.raise_for_status()
pages_data = response.json()
all_pages.extend(pages_data.get("value", []))
for page in data.get("value", []):
if not _site_page_in_time_window(page, start, end):
continue
total_yielded += 1
yield page
logger.debug(f"Found {len(all_pages)} site pages in {site_descriptor.url}")
page_url = data.get("@odata.nextLink")
# Filter pages based on time window if specified
if start is not None or end is not None:
filtered_pages: list[dict[str, Any]] = []
for page in all_pages:
page_modified = page.get("lastModifiedDateTime")
if page_modified:
if isinstance(page_modified, str):
page_modified = datetime.fromisoformat(
page_modified.replace("Z", "+00:00")
)
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
if start is not None and page_modified < start:
continue
if end is not None and page_modified > end:
continue
filtered_pages.append(page)
all_pages = filtered_pages
return all_pages
def _acquire_token(self) -> dict[str, Any]:
"""
@@ -1192,7 +1196,7 @@ class SharepointConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
return token
@@ -1265,10 +1269,9 @@ class SharepointConnector(
Performs BFS folder traversal manually, fetching one page of children
at a time so that memory usage stays bounded regardless of drive size.
"""
base = f"{self.graph_api_base}/drives/{drive_id}"
base = f"{GRAPH_API_BASE}/drives/{drive_id}"
if folder_path:
encoded_path = quote(folder_path, safe="/")
start_url = f"{base}/root:/{encoded_path}:/children"
start_url = f"{base}/root:/{folder_path}:/children"
else:
start_url = f"{base}/root/children"
@@ -1326,7 +1329,7 @@ class SharepointConnector(
"""
use_timestamp_token = start is not None and start > _EPOCH
initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
initial_url = f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta"
if use_timestamp_token:
assert start is not None # mypy
token = quote(start.isoformat(timespec="seconds"))
@@ -1372,7 +1375,7 @@ class SharepointConnector(
drive_id,
)
yield from self._iter_delta_pages(
initial_url=f"{self.graph_api_base}/drives/{drive_id}/root/delta",
initial_url=f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta",
drive_id=drive_id,
start=start,
end=end,
@@ -1489,7 +1492,7 @@ class SharepointConnector(
sp_private_key = credentials.get("sp_private_key")
sp_certificate_password = credentials.get("sp_certificate_password")
authority_url = f"{self.authority_host}/{sp_directory_id}"
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
logger.info("Using certificate authentication")
@@ -1530,7 +1533,7 @@ class SharepointConnector(
raise ConnectorValidationError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if token is None:
raise ConnectorValidationError("Failed to acquire token for graph")
@@ -1959,7 +1962,6 @@ class SharepointConnector(
self.graph_client,
include_permissions=include_permissions,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
graph_api_base=self.graph_api_base,
access_token=access_token,
)

View File

@@ -50,15 +50,12 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -66,15 +63,11 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
# impls for BaseConnector
@@ -83,7 +76,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{self.authority_host}/{teams_directory_id}"
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -98,7 +91,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if not isinstance(token, dict):

View File

@@ -4940,7 +4940,6 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False

View File

@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -60,7 +59,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -46,7 +45,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -11,7 +9,6 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -64,6 +64,21 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -144,6 +159,11 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1300,6 +1320,11 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1340,6 +1365,16 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1470,6 +1505,26 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1650,6 +1705,16 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3161,6 +3226,15 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3175,6 +3249,16 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3666,6 +3750,16 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3785,6 +3879,20 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -70,8 +70,6 @@ GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
"""
MEMORY_GUIDANCE = """

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -1925,8 +1924,6 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
@@ -1934,7 +1931,6 @@ def get_basic_connector_indexing_status(
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -36,8 +36,6 @@ from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
@@ -52,7 +50,6 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.utils.logger import setup_logger
@@ -380,37 +377,6 @@ def create_memory_packets(
return packets
def create_python_tool_packets(
code: str,
stdout: str,
stderr: str,
file_ids: list[str],
turn_index: int,
tab_index: int = 0,
) -> list[Packet]:
"""Recreate PythonToolStart + PythonToolDelta + SectionEnd from the stored
tool call data so the frontend can display both the code and its output
on page reload."""
packets: list[Packet] = []
placement = Placement(turn_index=turn_index, tab_index=tab_index)
packets.append(Packet(placement=placement, obj=PythonToolStart(code=code)))
packets.append(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
),
)
)
packets.append(Packet(placement=placement, obj=SectionEnd()))
return packets
def create_search_packets(
search_queries: list[str],
search_docs: list[SavedSearchDoc],
@@ -620,41 +586,6 @@ def translate_assistant_message_to_packets(
)
)
elif tool.in_code_tool_id == PythonTool.__name__:
code = cast(
str,
tool_call.tool_call_arguments.get("code", ""),
)
stdout = ""
stderr = ""
file_ids: list[str] = []
if tool_call.tool_call_response:
try:
response_data = json.loads(tool_call.tool_call_response)
stdout = response_data.get("stdout", "")
stderr = response_data.get("stderr", "")
generated_files = response_data.get(
"generated_files", []
)
file_ids = [
f.get("file_link", "").split("/")[-1]
for f in generated_files
if f.get("file_link")
]
except (json.JSONDecodeError, KeyError):
# Fall back to raw response as stdout
stdout = tool_call.tool_call_response
turn_tool_packets.extend(
create_python_tool_packets(
code=code,
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
turn_index=turn_num,
tab_index=tool_call.tab_index,
)
)
else:
# Custom tool or unknown tool
turn_tool_packets.extend(

View File

@@ -24,7 +24,6 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SAML_CONF_DIR
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
@@ -124,12 +123,9 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
if request.client is None:
raise ValueError("Invalid request for SAML")
# Derive http_host and server_port from WEB_DOMAIN (a trusted env var)
# instead of X-Forwarded-* headers, which can be spoofed by an attacker
# to poison SAML redirect URLs (host header poisoning).
parsed_domain = urlparse(WEB_DOMAIN)
http_host = parsed_domain.hostname or request.client.host
server_port = parsed_domain.port or (443 if parsed_domain.scheme == "https" else 80)
# Use X-Forwarded headers if available
http_host = request.headers.get("X-Forwarded-Host") or request.client.host
server_port = request.headers.get("X-Forwarded-Port") or request.url.port
rv: dict[str, Any] = {
"http_host": http_host,

View File

@@ -199,12 +199,6 @@ class PythonToolOverrideKwargs(BaseModel):
chat_files: list[ChatFile] = []
class ImageGenerationToolOverrideKwargs(BaseModel):
"""Override kwargs for image generation tool calls."""
recent_generated_image_file_ids: list[str] = []
class SearchToolRunContext(BaseModel):
emitter: Emitter

View File

@@ -11,14 +11,11 @@ from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.app_configs import IMAGE_MODEL_PROVIDER
from onyx.db.image_generation import get_default_image_generation_config
from onyx.file_store.models import ChatFileType
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import load_chat_file_by_id
from onyx.file_store.utils import save_files
from onyx.image_gen.factory import get_image_generation_provider
from onyx.image_gen.factory import validate_credentials
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
@@ -26,7 +23,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
@@ -35,7 +31,6 @@ from onyx.tools.tool_implementations.images.models import (
)
from onyx.tools.tool_implementations.images.models import ImageGenerationResponse
from onyx.tools.tool_implementations.images.models import ImageShape
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -45,10 +40,10 @@ logger = setup_logger()
HEARTBEAT_INTERVAL = 5.0
PROMPT_FIELD = "prompt"
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
NAME = "generate_image"
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
DISPLAY_NAME = "Image Generation"
@@ -64,7 +59,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
) -> None:
super().__init__(emitter=emitter)
self.model = model
self.provider = provider
self.num_imgs = num_imgs
self.img_provider = get_image_generation_provider(
@@ -139,16 +133,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
),
"enum": [shape.value for shape in ImageShape],
},
REFERENCE_IMAGE_FILE_IDS_FIELD: {
"type": "array",
"description": (
"Optional image file IDs to use as reference context for edits/variations. "
"Use the file_id values returned by previous generate_image calls."
),
"items": {
"type": "string",
},
},
},
"required": [PROMPT_FIELD],
},
@@ -164,10 +148,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
)
def _generate_image(
self,
prompt: str,
shape: ImageShape,
reference_images: list[ReferenceImage] | None = None,
self, prompt: str, shape: ImageShape
) -> tuple[ImageGenerationResponse, Any]:
if shape == ImageShape.LANDSCAPE:
if "gpt-image-1" in self.model:
@@ -188,7 +169,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
model=self.model,
size=size,
n=1,
reference_images=reference_images,
# response_format parameter is not supported for gpt-image-1
response_format=None if "gpt-image-1" in self.model else "b64_json",
)
@@ -251,117 +231,10 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
emit_error_packet=True,
)
def _resolve_reference_image_file_ids(
self,
llm_kwargs: dict[str, Any],
override_kwargs: ImageGenerationToolOverrideKwargs | None,
) -> list[str]:
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
if raw_reference_ids is not None:
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, "
f"got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)
reference_image_file_ids = [
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
]
elif (
override_kwargs
and override_kwargs.recent_generated_image_file_ids
and self.img_provider.supports_reference_images
):
# If no explicit reference was provided, default to the most recently generated image.
reference_image_file_ids = [
override_kwargs.recent_generated_image_file_ids[-1]
]
else:
reference_image_file_ids = []
# Deduplicate while preserving order.
deduped_reference_image_ids: list[str] = []
seen_ids: set[str] = set()
for file_id in reference_image_file_ids:
if file_id in seen_ids:
continue
seen_ids.add(file_id)
deduped_reference_image_ids.append(file_id)
if not deduped_reference_image_ids:
return []
if not self.img_provider.supports_reference_images:
raise ToolCallException(
message=(
f"Reference images requested but provider '{self.provider}' "
"does not support image-editing context."
),
llm_facing_message=(
"This image provider does not support editing from previous image context. "
"Try text-only generation, or switch to a provider/model that supports image edits."
),
)
max_reference_images = self.img_provider.max_reference_images
if max_reference_images > 0:
return deduped_reference_image_ids[-max_reference_images:]
return deduped_reference_image_ids
def _load_reference_images(
self,
reference_image_file_ids: list[str],
) -> list[ReferenceImage]:
reference_images: list[ReferenceImage] = []
for file_id in reference_image_file_ids:
try:
loaded_file = load_chat_file_by_id(file_id)
except Exception as e:
raise ToolCallException(
message=f"Could not load reference image file '{file_id}': {e}",
llm_facing_message=(
f"Reference image file '{file_id}' could not be loaded. "
"Use file_id values returned by previous generate_image calls."
),
)
if loaded_file.file_type != ChatFileType.IMAGE:
raise ToolCallException(
message=f"Reference file '{file_id}' is not an image",
llm_facing_message=f"Reference file '{file_id}' is not an image.",
)
try:
mime_type = get_image_type_from_bytes(loaded_file.content)
except Exception as e:
raise ToolCallException(
message=f"Unsupported reference image format for '{file_id}': {e}",
llm_facing_message=(
f"Reference image '{file_id}' has an unsupported format. "
"Only PNG, JPEG, GIF, and WEBP are supported."
),
)
reference_images.append(
ReferenceImage(
data=loaded_file.content,
mime_type=mime_type,
)
)
return reference_images
def run(
self,
placement: Placement,
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
override_kwargs: None = None, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
if PROMPT_FIELD not in llm_kwargs:
@@ -374,11 +247,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
)
prompt = cast(str, llm_kwargs[PROMPT_FIELD])
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
reference_image_file_ids = self._resolve_reference_image_file_ids(
llm_kwargs=llm_kwargs,
override_kwargs=override_kwargs,
)
reference_images = self._load_reference_images(reference_image_file_ids)
# Use threading to generate images in parallel while emitting heartbeats
results: list[tuple[ImageGenerationResponse, Any] | None] = [
@@ -399,7 +267,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
(
prompt,
shape,
reference_images or None,
),
)
for _ in range(self.num_imgs)
@@ -480,7 +347,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
llm_facing_response = json.dumps(
[
{
"file_id": img.file_id,
"revised_prompt": img.revised_prompt,
}
for img in generated_images_metadata

View File

@@ -1,8 +1,5 @@
import json
from collections.abc import Generator
from typing import Literal
from typing import TypedDict
from typing import Union
import requests
from pydantic import BaseModel
@@ -39,37 +36,6 @@ class ExecuteResponse(BaseModel):
files: list[WorkspaceFile]
class StreamOutputEvent(BaseModel):
"""SSE 'output' event: a chunk of stdout or stderr"""
stream: Literal["stdout", "stderr"]
data: str
class StreamResultEvent(BaseModel):
"""SSE 'result' event: final execution result"""
exit_code: int | None
timed_out: bool
duration_ms: int
files: list[WorkspaceFile]
class StreamErrorEvent(BaseModel):
"""SSE 'error' event: execution-level error"""
message: str
StreamEvent = Union[StreamOutputEvent, StreamResultEvent, StreamErrorEvent]
_SSE_EVENT_MAP: dict[str, type[BaseModel]] = {
"output": StreamOutputEvent,
"result": StreamResultEvent,
"error": StreamErrorEvent,
}
class CodeInterpreterClient:
"""Client for Code Interpreter service"""
@@ -79,23 +45,6 @@ class CodeInterpreterClient:
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
def _build_payload(
self,
code: str,
stdin: str | None,
timeout_ms: int,
files: list[FileInput] | None,
) -> dict:
payload: dict = {
"code": code,
"timeout_ms": timeout_ms,
}
if stdin is not None:
payload["stdin"] = stdin
if files:
payload["files"] = files
return payload
def execute(
self,
code: str,
@@ -103,106 +52,25 @@ class CodeInterpreterClient:
timeout_ms: int = 30000,
files: list[FileInput] | None = None,
) -> ExecuteResponse:
"""Execute Python code (batch)"""
"""Execute Python code"""
url = f"{self.base_url}/v1/execute"
payload = self._build_payload(code, stdin, timeout_ms, files)
payload = {
"code": code,
"timeout_ms": timeout_ms,
}
if stdin is not None:
payload["stdin"] = stdin
if files:
payload["files"] = files
response = self.session.post(url, json=payload, timeout=timeout_ms / 1000 + 10)
response.raise_for_status()
return ExecuteResponse(**response.json())
def execute_streaming(
self,
code: str,
stdin: str | None = None,
timeout_ms: int = 30000,
files: list[FileInput] | None = None,
) -> Generator[StreamEvent, None, None]:
"""Execute Python code with streaming SSE output.
Yields StreamEvent objects (StreamOutputEvent, StreamResultEvent,
StreamErrorEvent) as execution progresses. Falls back to batch
execution if the streaming endpoint is not available (older
code-interpreter versions).
"""
url = f"{self.base_url}/v1/execute/stream"
payload = self._build_payload(code, stdin, timeout_ms, files)
response = self.session.post(
url,
json=payload,
stream=True,
timeout=timeout_ms / 1000 + 10,
)
if response.status_code == 404:
logger.info(
"Streaming endpoint not available, " "falling back to batch execution"
)
response.close()
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
response.raise_for_status()
yield from self._parse_sse(response)
def _parse_sse(
self, response: requests.Response
) -> Generator[StreamEvent, None, None]:
"""Parse SSE streaming response into StreamEvent objects.
Expected format per event:
event: <type>
data: <json>
<blank line>
"""
event_type: str | None = None
data_lines: list[str] = []
for line in response.iter_lines(decode_unicode=True):
if line is None:
continue
logger.critical(line)
if line == "":
# Blank line marks end of an SSE event
if event_type is not None and data_lines:
data = "\n".join(data_lines)
model_cls = _SSE_EVENT_MAP.get(event_type)
if model_cls is not None:
yield model_cls(**json.loads(data))
else:
logger.warning(f"Unknown SSE event type: {event_type}")
event_type = None
data_lines = []
elif line.startswith("event:"):
event_type = line[len("event:") :].strip()
elif line.startswith("data:"):
data_lines.append(line[len("data:") :].strip())
def _batch_as_stream(
self,
code: str,
stdin: str | None,
timeout_ms: int,
files: list[FileInput] | None,
) -> Generator[StreamEvent, None, None]:
"""Execute via batch endpoint and yield results as stream events."""
result = self.execute(code, stdin, timeout_ms, files)
if result.stdout:
yield StreamOutputEvent(stream="stdout", data=result.stdout)
if result.stderr:
yield StreamOutputEvent(stream="stderr", data=result.stderr)
yield StreamResultEvent(
exit_code=result.exit_code,
timed_out=result.timed_out,
duration_ms=result.duration_ms,
files=result.files,
)
def upload_file(self, file_content: bytes, filename: str) -> str:
"""Upload file to Code Interpreter and return file_id"""
url = f"{self.base_url}/v1/files"

View File

@@ -28,15 +28,6 @@ from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamErrorEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamOutputEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamResultEvent,
)
from onyx.utils.logger import setup_logger
@@ -190,50 +181,19 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
try:
logger.debug(f"Executing code: {code}")
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
for event in client.execute_streaming(
# Execute code with timeout
response = client.execute(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Handle generated files
@@ -242,7 +202,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in result_event.files:
for workspace_file in response.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
@@ -298,23 +258,26 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
)
# Emit delta with stdout/stderr and generated files
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=truncated_stdout,
stderr=truncated_stderr,
file_ids=generated_file_ids,
),
)
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
exit_code=response.exit_code,
timed_out=response.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
error=None if response.exit_code == 0 else truncated_stderr,
)
# Serialize result for LLM

View File

@@ -1,4 +1,3 @@
import json
import traceback
from collections import defaultdict
from typing import Any
@@ -14,7 +13,6 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import ParallelToolCallResponse
from onyx.tools.models import PythonToolOverrideKwargs
@@ -24,9 +22,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
@@ -110,63 +105,6 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
return merged_calls
def _extract_image_file_ids_from_tool_response_message(
message: str,
) -> list[str]:
try:
parsed_message = json.loads(message)
except json.JSONDecodeError:
return []
parsed_items: list[Any] = (
parsed_message if isinstance(parsed_message, list) else [parsed_message]
)
file_ids: list[str] = []
for item in parsed_items:
if not isinstance(item, dict):
continue
file_id = item.get("file_id")
if isinstance(file_id, str):
file_ids.append(file_id)
return file_ids
def _extract_recent_generated_image_file_ids(
message_history: list[ChatMessageSimple],
) -> list[str]:
tool_name_by_tool_call_id: dict[str, str] = {}
recent_image_file_ids: list[str] = []
seen_file_ids: set[str] = set()
for message in message_history:
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
for tool_call in message.tool_calls:
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
continue
if (
message.message_type != MessageType.TOOL_CALL_RESPONSE
or not message.tool_call_id
):
continue
tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
if tool_name != ImageGenerationTool.NAME:
continue
for file_id in _extract_image_file_ids_from_tool_response_message(
message.message
):
if file_id in seen_file_ids:
continue
seen_file_ids.add(file_id)
recent_image_file_ids.append(file_id)
return recent_image_file_ids
def _safe_run_single_tool(
tool: Tool,
tool_call: ToolCallKickoff,
@@ -386,9 +324,6 @@ def run_tool_calls(
url_to_citation: dict[str, int] = {
url: citation_num for citation_num, url in citation_mapping.items()
}
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
message_history
)
# Prepare all tool calls with their override_kwargs
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
@@ -405,7 +340,6 @@ def run_tool_calls(
| WebSearchToolOverrideKwargs
| OpenURLToolOverrideKwargs
| PythonToolOverrideKwargs
| ImageGenerationToolOverrideKwargs
| MemoryToolOverrideKwargs
| None
) = None
@@ -454,10 +388,6 @@ def run_tool_calls(
override_kwargs = PythonToolOverrideKwargs(
chat_files=chat_files or [],
)
elif isinstance(tool, ImageGenerationTool):
override_kwargs = ImageGenerationToolOverrideKwargs(
recent_generated_image_file_ids=recent_generated_image_file_ids
)
elif isinstance(tool, MemoryTool):
override_kwargs = MemoryToolOverrideKwargs(
user_name=(

View File

@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.0
onyx-devtools==0.5.7
# via onyx
openai==2.14.0
# via

View File

@@ -1,281 +0,0 @@
"""
External dependency unit tests for user file processing queue protections.
Verifies that the three mechanisms added to check_user_file_processing work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in PROCESSING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that process_single_user_file clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_user_file_processing,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file,
)
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
    """Persist and return a fresh UserFile row left in PROCESSING status."""
    new_user_file = UserFile(
        id=uuid4(),
        user_id=user_id,
        file_id=f"test_file_{uuid4().hex[:8]}",
        name=f"test_{uuid4().hex[:8]}.txt",
        file_type="text/plain",
        status=UserFileStatus.PROCESSING,
    )
    # Commit immediately so the task under test (using its own session)
    # can observe the row, then refresh to pick up DB-generated fields.
    db_session.add(new_user_file)
    db_session.commit()
    db_session.refresh(new_user_file)
    return new_user_file
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestQueueDepthBackpressure:
    """Protection 1: skip all enqueuing when the broker queue is too deep."""

    def test_no_tasks_enqueued_when_queue_over_limit(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """When the queue depth exceeds the limit the beat cycle is skipped."""
        user = create_test_user(db_session, "bp_user")
        _create_processing_user_file(db_session, user.id)

        celery_app_mock = MagicMock()
        # Report a queue depth just past the limit so the beat bails out
        # before considering any file.
        over_limit_depth = USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
        with _patch_task_app(check_user_file_processing, celery_app_mock):
            with patch(_PATCH_QUEUE_LEN, return_value=over_limit_depth):
                check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

        celery_app_mock.send_task.assert_not_called()
class TestPerFileGuardKey:
    """Protection 2: per-file Redis guard key prevents duplicate enqueue."""

    def test_guarded_file_not_re_enqueued(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file whose guard key is already set in Redis is skipped."""
        user = create_test_user(db_session, "guard_user")
        user_file = _create_processing_user_file(db_session, user.id)

        redis = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(user_file.id)
        # Pre-set the guard key to simulate a file that was already enqueued.
        redis.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)

        celery_app_mock = MagicMock()
        try:
            with (
                _patch_task_app(check_user_file_processing, celery_app_mock),
                patch(_PATCH_QUEUE_LEN, return_value=0),
            ):
                check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

            # send_task must not have been called with this specific file's ID
            for submitted in celery_app_mock.send_task.call_args_list:
                submitted_kwargs = submitted.kwargs.get("kwargs", {})
                assert submitted_kwargs.get("user_file_id") != str(
                    user_file.id
                ), f"File {user_file.id} should have been skipped because its guard key exists"
        finally:
            redis.delete(guard_key)

    def test_guard_key_exists_in_redis_after_enqueue(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """After a file is enqueued its guard key is present in Redis with a TTL."""
        user = create_test_user(db_session, "guard_set_user")
        user_file = _create_processing_user_file(db_session, user.id)

        redis = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(user_file.id)
        redis.delete(guard_key)  # clean slate

        celery_app_mock = MagicMock()
        try:
            with (
                _patch_task_app(check_user_file_processing, celery_app_mock),
                patch(_PATCH_QUEUE_LEN, return_value=0),
            ):
                check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

            assert redis.exists(
                guard_key
            ), "Guard key should be set in Redis after enqueue"
            # TTL must be positive and no longer than the task expiry window.
            ttl = int(redis.ttl(guard_key))  # type: ignore[arg-type]
            assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
                f"Guard key TTL {ttl}s is outside the expected range "
                f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
            )
        finally:
            redis.delete(guard_key)
class TestTaskExpiry:
    """Protection 3: every send_task call includes an expires value."""

    def test_send_task_called_with_expires(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """send_task is called with the correct queue, task name, and expires."""
        user = create_test_user(db_session, "expires_user")
        user_file = _create_processing_user_file(db_session, user.id)

        redis = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(user_file.id)
        redis.delete(guard_key)

        celery_app_mock = MagicMock()
        try:
            with (
                _patch_task_app(check_user_file_processing, celery_app_mock),
                patch(_PATCH_QUEUE_LEN, return_value=0),
            ):
                check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

            # At least one task should have been submitted (for our file)
            assert (
                celery_app_mock.send_task.call_count >= 1
            ), "Expected at least one task to be submitted"

            # Every submitted task must carry expires
            for submitted in celery_app_mock.send_task.call_args_list:
                assert submitted.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
                assert (
                    submitted.kwargs.get("queue")
                    == OnyxCeleryQueues.USER_FILE_PROCESSING
                )
                assert (
                    submitted.kwargs.get("expires")
                    == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
                ), (
                    "Task must be submitted with the correct expires value to prevent "
                    "stale task accumulation"
                )
        finally:
            redis.delete(guard_key)
class TestWorkerClearsGuardKey:
    """process_single_user_file removes the guard key when it picks up a task."""

    def test_guard_key_deleted_on_pickup(
        self,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """The guard key is deleted before the worker does any real work.

        We simulate an already-locked file so process_single_user_file returns
        early but crucially, after the guard key deletion.
        """
        target_file_id = str(uuid4())
        redis = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(target_file_id)

        # Simulate the guard key set when the beat enqueued the task
        redis.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
        assert redis.exists(guard_key), "Guard key must exist before pickup"

        # Hold the per-file processing lock so the worker exits early without
        # touching the database or file store.
        file_lock = redis.lock(_user_file_lock_key(target_file_id), timeout=10)
        lock_taken = file_lock.acquire(blocking=False)
        assert lock_taken, "Should be able to acquire the processing lock for this test"
        try:
            process_single_user_file.run(
                user_file_id=target_file_id,
                tenant_id=TEST_TENANT_ID,
            )
        finally:
            if file_lock.owned():
                file_lock.release()

        assert not redis.exists(
            guard_key
        ), "Guard key should be deleted when the worker picks up the task"

View File

@@ -217,8 +217,8 @@ class TestAutoModeSyncFeature:
),
additional_visible_models=[
SimpleKnownModel(
name="claude-haiku-4-5",
display_name="Claude Haiku 4.5",
name="claude-3-5-haiku-latest",
display_name="Claude 3.5 Haiku",
)
],
),
@@ -260,7 +260,7 @@ class TestAutoModeSyncFeature:
# Anthropic models should NOT be present
assert "claude-3-5-sonnet-latest" not in model_names
assert "claude-haiku-4-5" not in model_names
assert "claude-3-5-haiku-latest" not in model_names
finally:
db_session.rollback()
@@ -485,7 +485,7 @@ class TestAutoModeSyncFeature:
# Provider 2 (Anthropic) config
provider_2_default_model = "claude-3-5-sonnet-latest"
provider_2_additional_models = ["claude-haiku-4-5"]
provider_2_additional_models = ["claude-3-5-haiku-latest"]
# Create mock recommendations with both providers
mock_recommendations = LLMRecommendations(

View File

@@ -281,22 +281,15 @@ def test_anthropic_prompt_caching_reduces_costs(
Anthropic requires explicit cache_control parameters.
"""
# Prompt caching support is model/account specific.
# Allow override via env var and otherwise try a few non-retired candidates.
anthropic_prompt_cache_models_env = os.environ.get("ANTHROPIC_PROMPT_CACHE_MODELS")
if anthropic_prompt_cache_models_env:
candidate_models = [
model.strip()
for model in anthropic_prompt_cache_models_env.split(",")
if model.strip()
]
else:
candidate_models = [
"claude-haiku-4-5-20251001",
"claude-sonnet-4-5-20250929",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
]
# Create Anthropic LLM
# NOTE: prompt caching support is model-specific; `claude-3-haiku-20240307` is known
# to return cache_creation/cache_read usage metrics, while some newer aliases may not.
llm = LitellmLLM(
api_key=os.environ["ANTHROPIC_API_KEY"],
model_provider="anthropic",
model_name="claude-3-haiku-20240307",
max_input_tokens=200000,
)
import random
import string
@@ -322,107 +315,79 @@ def test_anthropic_prompt_caching_reduces_costs(
UserMessage(role="user", content=long_context)
]
unavailable_models: list[str] = []
non_caching_models: list[str] = []
# First call - creates cache
print("\n=== First call (cache creation) ===")
question1: list[ChatCompletionMessage] = [
UserMessage(role="user", content="What are the main topics discussed?")
]
for model_name in candidate_models:
llm = LitellmLLM(
api_key=os.environ["ANTHROPIC_API_KEY"],
model_provider="anthropic",
model_name=model_name,
max_input_tokens=200000,
)
# First call - creates cache
print(f"\n=== First call (cache creation) model={model_name} ===")
question1: list[ChatCompletionMessage] = [
UserMessage(
role="user",
content="Reply with exactly one lowercase word: topics",
)
]
processed_messages1, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question1,
continuation=False,
)
try:
response1 = llm.invoke(prompt=processed_messages1, max_tokens=8)
except Exception as e:
error_str = str(e).lower()
if (
"not_found_error" in error_str
or "model_not_found" in error_str
or ('"type":"not_found_error"' in error_str and "model:" in error_str)
):
unavailable_models.append(model_name)
continue
raise
cost1 = completion_cost(
completion_response=response1.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage1 = response1.usage
print(f"Response 1 usage: {usage1}")
print(f"Cost 1: ${cost1:.10f}")
# Wait to ensure cache is available
time.sleep(2)
# Second call with same context - should use cache
print(f"\n=== Second call (cache read) model={model_name} ===")
question2: list[ChatCompletionMessage] = [
UserMessage(
role="user",
content="Reply with exactly one lowercase word: neural",
)
]
processed_messages2, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question2,
continuation=False,
)
response2 = llm.invoke(prompt=processed_messages2, max_tokens=8)
cost2 = completion_cost(
completion_response=response2.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage2 = response2.usage
print(f"Response 2 usage: {usage2}")
print(f"Cost 2: ${cost2:.10f}")
cache_creation_tokens = _get_usage_value(usage1, "cache_creation_input_tokens")
cache_read_tokens = _get_usage_value(usage2, "cache_read_input_tokens")
print(f"\nCache creation tokens (call 1): {cache_creation_tokens}")
print(f"Cache read tokens (call 2): {cache_read_tokens}")
print(f"Cost reduction: ${cost1 - cost2:.10f}")
# Model is available but does not expose Anthropic cache usage metrics
if cache_creation_tokens <= 0 or cache_read_tokens <= 0:
non_caching_models.append(model_name)
continue
# Cost should be lower on second call
assert (
cost2 < cost1
), f"Expected lower cost on cached call. Cost 1: ${cost1:.10f}, Cost 2: ${cost2:.10f}"
return
pytest.skip(
"No Anthropic model available with observable prompt-cache metrics. "
f"Tried models={candidate_models}, unavailable={unavailable_models}, non_caching={non_caching_models}"
# Apply prompt caching
processed_messages1, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question1,
continuation=False,
)
response1 = llm.invoke(prompt=processed_messages1)
cost1 = completion_cost(
completion_response=response1.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage1 = response1.usage
print(f"Response 1 usage: {usage1}")
print(f"Cost 1: ${cost1:.10f}")
# Wait to ensure cache is available
time.sleep(2)
# Second call with same context - should use cache
print("\n=== Second call (cache read) ===")
question2: list[ChatCompletionMessage] = [
UserMessage(role="user", content="Can you elaborate on neural networks?")
]
# Apply prompt caching (same cacheable prefix)
processed_messages2, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question2,
continuation=False,
)
response2 = llm.invoke(prompt=processed_messages2)
cost2 = completion_cost(
completion_response=response2.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage2 = response2.usage
print(f"Response 2 usage: {usage2}")
print(f"Cost 2: ${cost2:.10f}")
# Verify caching occurred
cache_creation_tokens = _get_usage_value(usage1, "cache_creation_input_tokens")
cache_read_tokens = _get_usage_value(usage2, "cache_read_input_tokens")
print(f"\nCache creation tokens (call 1): {cache_creation_tokens}")
print(f"Cache read tokens (call 2): {cache_read_tokens}")
print(f"Cost reduction: ${cost1 - cost2:.10f}")
# For Anthropic, we should see cache creation on first call and cache reads on second
assert (
cache_creation_tokens > 0
), f"Expected cache creation tokens on first call. Got: {cache_creation_tokens}"
assert (
cache_read_tokens > 0
), f"Expected cache read tokens on second call. Got: {cache_read_tokens}"
# Cost should be lower on second call
assert (
cost2 < cost1
), f"Expected lower cost on cached call. Cost 1: ${cost1:.10f}, Cost 2: ${cost2:.10f}"
@pytest.mark.skipif(
not os.environ.get(VERTEX_CREDENTIALS_ENV),

View File

@@ -13,7 +13,6 @@ from litellm.types.utils import ImageResponse
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
from onyx.llm.interfaces import LLMConfig
@@ -63,7 +62,6 @@ class MockImageGenerationProvider(
size: str, # noqa: ARG002
n: int, # noqa: ARG002
quality: str | None = None, # noqa: ARG002
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any, # noqa: ARG002
) -> ImageResponse:
image_data = self._images.pop(0)

View File

@@ -943,18 +943,10 @@ from onyx.db.tools import get_builtin_tool
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.server.features.projects.api import upload_user_files
from onyx.server.query_and_chat.chat_backend import get_chat_session
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
from tests.external_dependency_unit.answer.stream_test_utils import create_placement
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.mock_llm import LLMAnswerResponse
from tests.external_dependency_unit.mock_llm import LLMToolCallResponse
from tests.external_dependency_unit.mock_llm import use_mock_llm
@@ -990,27 +982,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
self._respond_json(
200, {"file_id": f"mock-ci-file-{self.server._file_counter}"}
)
elif self.path == "/v1/execute/streaming":
if self.server.streaming_enabled:
self._respond_sse(
[
(
"output",
{"stream": "stdout", "data": "mock output\n"},
),
(
"result",
{
"exit_code": 0,
"timed_out": False,
"duration_ms": 50,
"files": [],
},
),
]
)
else:
self._respond_json(404, {"error": "not found"})
elif self.path == "/v1/execute":
self._respond_json(
200,
@@ -1048,17 +1019,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(payload)
def _respond_sse(self, events: list[tuple[str, dict[str, Any]]]) -> None:
frames = []
for event_type, data in events:
frames.append(f"event: {event_type}\ndata: {json.dumps(data)}\n\n")
payload = "".join(frames).encode()
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Content-Length", str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
pass
@@ -1070,7 +1030,6 @@ class MockCodeInterpreterServer(HTTPServer):
super().__init__(("localhost", 0), _MockCIHandler)
self.captured_requests: list[CapturedRequest] = []
self._file_counter = 0
self.streaming_enabled: bool = True
@property
def url(self) -> str:
@@ -1201,226 +1160,17 @@ def test_code_interpreter_receives_chat_files(
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Verify: file uploaded, code executed via streaming, staged file cleaned up
# Verify: file uploaded, code executed, staged file cleaned up
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/streaming"))
== 1
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
delete_requests = mock_ci_server.get_requests(method="DELETE")
assert len(delete_requests) == 1
assert delete_requests[0].path.startswith("/v1/files/")
execute_body = mock_ci_server.get_requests(
method="POST", path="/v1/execute/streaming"
)[0].json_body()
execute_body = mock_ci_server.get_requests(method="POST", path="/v1/execute")[
0
].json_body()
assert execute_body["code"] == code
assert len(execute_body["files"]) == 1
assert execute_body["files"][0]["path"] == "data.csv"
def test_code_interpreter_replay_packets_include_code_and_output(
db_session: Session,
mock_ci_server: MockCodeInterpreterServer,
_attach_python_tool_to_default_persona: None,
initialize_file_store: None, # noqa: ARG001
) -> None:
"""After a code interpreter message completes, retrieving the message
via translate_assistant_message_to_packets should emit PythonToolStart
(containing the executed code) and PythonToolDelta (containing
stdout/stderr), not generic CustomTool packets."""
mock_ci_server.captured_requests.clear()
mock_ci_server._file_counter = 0
mock_url = mock_ci_server.url
user = create_test_user(db_session, "ci_replay_test")
chat_session = create_chat_session(db_session=db_session, user=user)
code = 'x = 2 + 2\nprint(f"Result: {x}")'
msg_req = SendMessageRequest(
message="Calculate 2 + 2",
chat_session_id=chat_session.id,
stream=True,
)
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
with (
use_mock_llm() as mock_llm,
patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
mock_url,
),
patch(
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
mock_url,
),
):
answer_tokens = ["The ", "result ", "is ", "4."]
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
try:
handler = StreamTestBuilder(llm_controller=mock_llm)
stream = handle_stream_message_objects(
new_msg_req=msg_req, user=user, db_session=db_session
)
# First packet is always MessageResponseIDInfo
next(stream)
# Phase 1: LLM requests python tool execution.
handler.add_response(
LLMToolCallResponse(
tool_name="python",
tool_call_id="call_replay_test",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolStart(code=code),
),
forward=2,
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolDelta(stdout="mock output\n", stderr="", file_ids=[]),
),
forward=False,
).expect(
Packet(
placement=create_placement(0),
obj=SectionEnd(),
),
forward=False,
).run_and_validate(
stream=stream
)
# Phase 2: LLM produces a final answer after tool execution.
handler.add_response(
LLMAnswerResponse(answer_tokens=answer_tokens)
).expect_agent_response(
answer_tokens=answer_tokens,
turn_index=1,
).run_and_validate(
stream=stream
)
with pytest.raises(StopIteration):
next(stream)
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Retrieve the chat session through the same endpoint the frontend uses
chat_detail = get_chat_session(
session_id=chat_session.id,
user=user,
db_session=db_session,
)
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/streaming"))
== 1
)
# The response contains `packets` — a list of packet-lists, one per
# assistant message. We should have exactly one assistant message.
assert (
len(chat_detail.packets) == 1
), f"Expected 1 assistant packet list, got {len(chat_detail.packets)}"
packets = chat_detail.packets[0]
# Extract PythonToolStart packets these must contain the code
start_packets = [p for p in packets if isinstance(p.obj, PythonToolStart)]
assert len(start_packets) == 1, (
f"Expected 1 PythonToolStart packet, got {len(start_packets)}. "
f"Packet types: {[type(p.obj).__name__ for p in packets]}"
)
start_obj = start_packets[0].obj
assert isinstance(start_obj, PythonToolStart)
assert start_obj.code == code
# Extract PythonToolDelta packets these must contain stdout/stderr
delta_packets = [p for p in packets if isinstance(p.obj, PythonToolDelta)]
assert len(delta_packets) >= 1, (
f"Expected at least 1 PythonToolDelta packet, got {len(delta_packets)}. "
f"Packet types: {[type(p.obj).__name__ for p in packets]}"
)
# The mock CI server returns "mock output\n" as stdout
delta_obj = delta_packets[0].obj
assert isinstance(delta_obj, PythonToolDelta)
assert "mock output" in delta_obj.stdout
def test_code_interpreter_streaming_fallback_to_batch(
db_session: Session,
mock_ci_server: MockCodeInterpreterServer,
_attach_python_tool_to_default_persona: None,
initialize_file_store: None, # noqa: ARG001
) -> None:
"""When the streaming endpoint is not available (older code-interpreter),
execute_streaming should fall back to the batch /v1/execute endpoint."""
mock_ci_server.captured_requests.clear()
mock_ci_server._file_counter = 0
mock_ci_server.streaming_enabled = False
mock_url = mock_ci_server.url
user = create_test_user(db_session, "ci_fallback_test")
chat_session = create_chat_session(db_session=db_session, user=user)
code = 'print("fallback test")'
msg_req = SendMessageRequest(
message="Print fallback test",
chat_session_id=chat_session.id,
stream=True,
)
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
with (
use_mock_llm() as mock_llm,
patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
mock_url,
),
patch(
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
mock_url,
),
):
mock_llm.add_response(
LLMToolCallResponse(
tool_name="python",
tool_call_id="call_fallback",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
)
mock_llm.forward_till_end()
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
try:
packets = list(
handle_stream_message_objects(
new_msg_req=msg_req, user=user, db_session=db_session
)
)
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
mock_ci_server.streaming_enabled = True
# Streaming was attempted first (returned 404), then fell back to batch
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/streaming"))
== 1
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
# Verify output still made it through
delta_packets = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, PythonToolDelta)
]
assert len(delta_packets) >= 1
assert "mock output" in delta_packets[0].obj.stdout

View File

@@ -13,9 +13,9 @@ from tests.integration.common_utils.test_models import DATestUser
class APIKeyManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
api_key_role: UserRole = UserRole.ADMIN,
user_performing_action: DATestUser | None = None,
) -> DATestAPIKey:
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
api_key_request = APIKeyArgs(
@@ -25,7 +25,11 @@ class APIKeyManager:
api_key_response = requests.post(
f"{API_SERVER_URL}/admin/api-key",
json=api_key_request.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
api_key_response.raise_for_status()
api_key = api_key_response.json()
@@ -44,21 +48,29 @@ class APIKeyManager:
@staticmethod
def delete(
api_key: DATestAPIKey,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
api_key_response = requests.delete(
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
api_key_response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestAPIKey]:
api_key_response = requests.get(
f"{API_SERVER_URL}/admin/api-key",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
api_key_response.raise_for_status()
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
@@ -66,8 +78,8 @@ class APIKeyManager:
@staticmethod
def verify(
api_key: DATestAPIKey,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_keys = APIKeyManager.get_all(
user_performing_action=user_performing_action

View File

@@ -17,6 +17,7 @@ from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import DocumentSyncStatus
from tests.integration.common_utils.config import api_config
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
@@ -27,10 +28,10 @@ from tests.integration.common_utils.test_models import DATestUser
def _cc_pair_creator(
connector_id: int,
credential_id: int,
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
@@ -39,12 +40,17 @@ def _cc_pair_creator(
connector_credential_pair_metadata = api.ConnectorCredentialPairMetadata(
name=name, access_type=access_type, groups=groups or []
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
api_response: api.StatusResponseInt = (
api_instance.associate_credential_to_connector(
connector_id,
credential_id,
connector_credential_pair_metadata,
_headers=user_performing_action.headers,
_headers=headers,
)
)
@@ -61,7 +67,6 @@ def _cc_pair_creator(
class CCPairManager:
@staticmethod
def create_from_scratch(
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
@@ -69,25 +74,26 @@ class CCPairManager:
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
credential_json: dict[str, Any] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestCCPair:
connector = ConnectorManager.create(
user_performing_action=user_performing_action,
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config,
access_type=access_type,
groups=groups,
user_performing_action=user_performing_action,
refresh_freq=refresh_freq,
)
credential = CredentialManager.create(
user_performing_action=user_performing_action,
credential_json=credential_json,
name=name,
source=source,
curator_public=(access_type == AccessType.PUBLIC),
groups=groups,
user_performing_action=user_performing_action,
)
cc_pair = _cc_pair_creator(
connector_id=connector.id,
@@ -103,10 +109,10 @@ class CCPairManager:
def create(
connector_id: int,
credential_id: int,
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
cc_pair = _cc_pair_creator(
connector_id=connector_id,
@@ -121,31 +127,39 @@ class CCPairManager:
@staticmethod
def pause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "PAUSED"},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def unpause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "ACTIVE"},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def delete(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
@@ -154,18 +168,26 @@ class CCPairManager:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
json=cc_pair_identifier.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def get_single(
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> CCPairFullInfo | None:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
cc_pair_json = response.json()
@@ -174,11 +196,15 @@ class CCPairManager:
@staticmethod
def get_indexing_status_by_id(
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> ConnectorIndexingStatusLite | None:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
json={"get_all_connectors": True},
)
response.raise_for_status()
@@ -193,11 +219,15 @@ class CCPairManager:
@staticmethod
def get_indexing_statuses(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[ConnectorIndexingStatusLite]:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
json={"get_all_connectors": True},
)
response.raise_for_status()
@@ -211,11 +241,15 @@ class CCPairManager:
@staticmethod
def get_connector_statuses(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[ConnectorStatus]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/connector/status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [ConnectorStatus(**status) for status in response.json()]
@@ -223,8 +257,8 @@ class CCPairManager:
@staticmethod
def verify(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_cc_pairs = CCPairManager.get_connector_statuses(user_performing_action)
for retrieved_cc_pair in all_cc_pairs:
@@ -251,7 +285,7 @@ class CCPairManager:
def run_once(
cc_pair: DATestCCPair,
from_beginning: bool,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
body = {
"connector_id": cc_pair.connector_id,
@@ -261,15 +295,19 @@ class CCPairManager:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
json=body,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def wait_for_indexing_inactive(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
@@ -304,9 +342,9 @@ class CCPairManager:
@staticmethod
def wait_for_indexing_in_progress(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
num_docs: int = 16,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
@@ -355,8 +393,8 @@ class CCPairManager:
def wait_for_indexing_completion(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: Wait for an indexing success time after this time"""
start = time.monotonic()
@@ -392,22 +430,30 @@ class CCPairManager:
@staticmethod
def prune(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def last_pruned(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
response_str = response.json()
@@ -425,8 +471,8 @@ class CCPairManager:
def wait_for_prune(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
@@ -450,7 +496,7 @@ class CCPairManager:
@staticmethod
def sync(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
"""This function triggers a permission sync.
Naming / intent of this function probably could use improvement, but currently it's letting
@@ -458,14 +504,22 @@ class CCPairManager:
"""
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if result.status_code != 409:
result.raise_for_status()
group_sync_result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if group_sync_result.status_code != 409:
group_sync_result.raise_for_status()
@@ -474,11 +528,15 @@ class CCPairManager:
@staticmethod
def get_doc_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
doc_sync_response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
doc_sync_response.raise_for_status()
doc_sync_response_str = doc_sync_response.json()
@@ -495,11 +553,15 @@ class CCPairManager:
@staticmethod
def get_group_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
group_sync_response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
group_sync_response.raise_for_status()
group_sync_response_str = group_sync_response.json()
@@ -516,11 +578,15 @@ class CCPairManager:
@staticmethod
def get_doc_sync_statuses(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DocumentSyncStatus]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
doc_sync_statuses: list[DocumentSyncStatus] = []
@@ -547,9 +613,9 @@ class CCPairManager:
def wait_for_sync(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
number_of_updated_docs: int = 0,
user_performing_action: DATestUser | None = None,
# Sometimes waiting for a group sync is not necessary
should_wait_for_group_sync: bool = True,
# Sometimes waiting for a vespa sync is not necessary
@@ -637,8 +703,8 @@ class CCPairManager:
@staticmethod
def wait_for_deletion_completion(
user_performing_action: DATestUser,
cc_pair_id: int | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
"""if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.

View File

@@ -17,6 +17,7 @@ from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import StreamingType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestChatMessage
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
@@ -73,9 +74,9 @@ class StreamPacketData(TypedDict, total=False):
class ChatSessionManager:
@staticmethod
def create(
user_performing_action: DATestUser,
persona_id: int = 0,
description: str = "Test chat session",
user_performing_action: DATestUser | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
@@ -83,7 +84,11 @@ class ChatSessionManager:
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session",
json=chat_session_creation_req.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
chat_session_id = response.json()["chat_session_id"]
@@ -95,8 +100,8 @@ class ChatSessionManager:
def send_message(
chat_session_id: UUID,
message: str,
user_performing_action: DATestUser,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
@@ -121,12 +126,19 @@ class ChatSessionManager:
llm_override=llm_override,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
response = requests.post(
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=headers,
stream=True,
cookies=user_performing_action.cookies,
cookies=cookies,
)
streamed_response = ChatSessionManager.analyze_response(response)
@@ -155,9 +167,9 @@ class ChatSessionManager:
def send_message_with_disconnect(
chat_session_id: UUID,
message: str,
user_performing_action: DATestUser,
disconnect_after_packets: int = 0,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
@@ -196,14 +208,21 @@ class ChatSessionManager:
llm_override=llm_override,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
packets_received = 0
with requests.post(
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=headers,
stream=True,
cookies=user_performing_action.cookies,
cookies=cookies,
) as response:
for line in response.iter_lines():
if not line:
@@ -340,11 +359,15 @@ class ChatSessionManager:
@staticmethod
def get_chat_history(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestChatMessage]:
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -364,7 +387,7 @@ class ChatSessionManager:
def create_chat_message_feedback(
message_id: int,
is_positive: bool,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
feedback_text: str | None = None,
predefined_feedback: str | None = None,
) -> None:
@@ -376,14 +399,18 @@ class ChatSessionManager:
"feedback_text": feedback_text,
"predefined_feedback": predefined_feedback,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Delete a chat session and all its related records (messages, agent data, etc.)
@@ -393,14 +420,18 @@ class ChatSessionManager:
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def soft_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Soft delete a chat session (marks as deleted but keeps in database).
@@ -411,14 +442,18 @@ class ChatSessionManager:
# or make a direct call with hard_delete=False parameter via a new endpoint
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def hard_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Hard delete a chat session (completely removes from database).
@@ -427,14 +462,18 @@ class ChatSessionManager:
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def verify_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been deleted by attempting to retrieve it.
@@ -443,7 +482,11 @@ class ChatSessionManager:
"""
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
# Chat session should return 404 if it doesn't exist or is deleted
return response.status_code == 404
@@ -451,7 +494,7 @@ class ChatSessionManager:
@staticmethod
def verify_soft_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
@@ -461,7 +504,11 @@ class ChatSessionManager:
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 200:
@@ -473,7 +520,7 @@ class ChatSessionManager:
@staticmethod
def verify_hard_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been hard deleted (completely removed from DB).
@@ -483,7 +530,11 @@ class ChatSessionManager:
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
# For hard delete, even with include_deleted=true, the record should not exist

View File

@@ -8,6 +8,7 @@ from onyx.db.enums import AccessType
from onyx.server.documents.models import ConnectorUpdateRequest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestUser
@@ -15,13 +16,13 @@ from tests.integration.common_utils.test_models import DATestUser
class ConnectorManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
source: DocumentSource = DocumentSource.FILE,
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestConnector:
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
@@ -50,7 +51,11 @@ class ConnectorManager:
response = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector",
json=connector_update_request.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -68,33 +73,45 @@ class ConnectorManager:
@staticmethod
def edit(
connector: DATestConnector,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
json=connector.model_dump(exclude={"id"}),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
connector: DATestConnector,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestConnector]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [
@@ -110,12 +127,15 @@ class ConnectorManager:
@staticmethod
def get(
connector_id: int,
user_performing_action: DATestUser,
connector_id: int, user_performing_action: DATestUser | None = None
) -> DATestConnector:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
conn = response.json()

View File

@@ -6,6 +6,7 @@ import requests
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
@@ -13,13 +14,13 @@ from tests.integration.common_utils.test_models import DATestUser
class CredentialManager:
@staticmethod
def create(
user_performing_action: DATestUser,
credential_json: dict[str, Any] | None = None,
admin_public: bool = True,
name: str | None = None,
source: DocumentSource = DocumentSource.FILE,
curator_public: bool = True,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCredential:
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
@@ -35,7 +36,11 @@ class CredentialManager:
response = requests.post(
url=f"{API_SERVER_URL}/manage/credential",
json=credential_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -52,46 +57,61 @@ class CredentialManager:
@staticmethod
def edit(
credential: DATestCredential,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
request = credential.model_dump(include={"name", "credential_json"})
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}",
json=request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
credential: DATestCredential,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def get(
credential_id: int,
user_performing_action: DATestUser,
credential_id: int, user_performing_action: DATestUser | None = None
) -> CredentialSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return CredentialSnapshot(**response.json())
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[CredentialSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/manage/credential",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [CredentialSnapshot(**cred) for cred in response.json()]
@@ -99,8 +119,8 @@ class CredentialManager:
@staticmethod
def verify(
credential: DATestCredential,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_credentials = CredentialManager.get_all(user_performing_action)
for fetched_credential in all_credentials:

View File

@@ -10,6 +10,7 @@ from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import DATestAPIKey
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
@@ -21,9 +22,9 @@ from tests.integration.common_utils.vespa import vespa_fixture
def _verify_document_permissions(
retrieved_doc: dict,
cc_pair: DATestCCPair,
doc_creating_user: DATestUser,
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: DATestUser | None = None,
) -> None:
acl_keys = set(retrieved_doc.get("access_control_list", {}).keys())
print(f"ACL keys: {acl_keys}")
@@ -35,11 +36,12 @@ def _verify_document_permissions(
" does not have the PUBLIC ACL key"
)
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if doc_creating_user is not None:
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if group_names is not None:
expected_group_keys = {f"group:{group_name}" for group_name in group_names}
@@ -99,9 +101,9 @@ class DocumentManager:
@staticmethod
def seed_dummy_docs(
cc_pair: DATestCCPair,
api_key: DATestAPIKey,
num_docs: int = NUM_DOCS,
document_ids: list[str] | None = None,
api_key: DATestAPIKey | None = None,
) -> list[SimpleTestDocument]:
# Use provided document_ids if available, otherwise generate random UUIDs
if document_ids is None:
@@ -116,13 +118,12 @@ class DocumentManager:
response = requests.post(
f"{API_SERVER_URL}/onyx-api/ingestion",
json=document,
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
print(
f"Seeding docs for api_key_id={api_key.api_key_id} completed successfully."
)
api_key_id = api_key.api_key_id if api_key else ""
print(f"Seeding docs for api_key_id={api_key_id} completed successfully.")
return [
SimpleTestDocument(
id=document["document"]["id"],
@@ -135,8 +136,8 @@ class DocumentManager:
def seed_doc_with_content(
cc_pair: DATestCCPair,
content: str,
api_key: DATestAPIKey,
document_id: str | None = None,
api_key: DATestAPIKey | None = None,
metadata: dict | None = None,
) -> SimpleTestDocument:
# Use provided document_ids if available, otherwise generate random UUIDs
@@ -152,13 +153,12 @@ class DocumentManager:
response = requests.post(
f"{API_SERVER_URL}/onyx-api/ingestion",
json=document,
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
print(
f"Seeding doc for api_key_id={api_key.api_key_id} completed successfully."
)
api_key_id = api_key.api_key_id if api_key else ""
print(f"Seeding doc for api_key_id={api_key_id} completed successfully.")
return SimpleTestDocument(
id=document["document"]["id"],
@@ -169,11 +169,11 @@ class DocumentManager:
def verify(
vespa_client: vespa_fixture,
cc_pair: DATestCCPair,
doc_creating_user: DATestUser,
# If None, will not check doc sets or groups
# If empty list, will check for empty doc sets or groups
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: DATestUser | None = None,
verify_deleted: bool = False,
) -> None:
doc_ids = [document.id for document in cc_pair.documents]
@@ -212,9 +212,9 @@ class DocumentManager:
_verify_document_permissions(
retrieved_doc,
cc_pair,
doc_creating_user,
doc_set_names,
group_names,
doc_creating_user,
)
@staticmethod
@@ -268,11 +268,11 @@ class IngestionManager(DocumentManager):
@staticmethod
def list_all_ingestion_docs(
api_key: DATestAPIKey,
api_key: DATestAPIKey | None = None,
) -> list[dict]:
response = requests.get(
f"{API_SERVER_URL}/onyx-api/ingestion",
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
return response.json()
@@ -280,11 +280,11 @@ class IngestionManager(DocumentManager):
@staticmethod
def delete(
document_id: str,
api_key: DATestAPIKey,
api_key: DATestAPIKey | None = None,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/onyx-api/ingestion/{document_id}",
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
print(f"Deleted document {document_id} successfully.")

View File

@@ -3,6 +3,7 @@ import requests
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -10,7 +11,7 @@ class DocumentSearchManager:
@staticmethod
def search_documents(
query: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[str]:
"""
Search for documents using the EE search API.
@@ -30,7 +31,11 @@ class DocumentSearchManager:
result = requests.post(
url=f"{API_SERVER_URL}/search/send-search-message",
json=search_request.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
result_json = result.json()

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestDocumentSet
from tests.integration.common_utils.test_models import DATestUser
@@ -14,7 +15,6 @@ from tests.integration.common_utils.test_models import DATestUser
class DocumentSetManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
cc_pair_ids: list[int] | None = None,
@@ -22,6 +22,7 @@ class DocumentSetManager:
users: list[str] | None = None,
groups: list[int] | None = None,
federated_connectors: list[dict[str, Any]] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestDocumentSet:
if name is None:
name = f"test_doc_set_{str(uuid4())}"
@@ -39,7 +40,11 @@ class DocumentSetManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_creation_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -58,7 +63,7 @@ class DocumentSetManager:
@staticmethod
def edit(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
doc_set_update_request = {
"id": document_set.id,
@@ -72,7 +77,11 @@ class DocumentSetManager:
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_update_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return True
@@ -80,22 +89,30 @@ class DocumentSetManager:
@staticmethod
def delete(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestDocumentSet]:
response = requests.get(
f"{API_SERVER_URL}/manage/document-set",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [
@@ -115,8 +132,8 @@ class DocumentSetManager:
@staticmethod
def wait_for_sync(
user_performing_action: DATestUser,
document_sets_to_check: list[DATestDocumentSet] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
# wait for document sets to be synced
start = time.time()
@@ -158,8 +175,8 @@ class DocumentSetManager:
@staticmethod
def verify(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
doc_sets = DocumentSetManager.get_all(user_performing_action)
for doc_set in doc_sets:

View File

@@ -10,6 +10,7 @@ import requests
from onyx.file_store.models import FileDescriptor
from onyx.server.documents.models import FileUploadResponse
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -17,9 +18,13 @@ class FileManager:
@staticmethod
def upload_files(
files: List[Tuple[str, IO]],
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> Tuple[List[FileDescriptor], str]:
headers = user_performing_action.headers
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
files_param = []
@@ -62,11 +67,15 @@ class FileManager:
@staticmethod
def fetch_uploaded_file(
file_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bytes:
response = requests.get(
f"{API_SERVER_URL}/chat/file/{file_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return response.content

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestImageGenerationConfig
from tests.integration.common_utils.test_models import DATestUser
@@ -25,7 +26,6 @@ def _serialize_custom_config(
class ImageGenerationConfigManager:
@staticmethod
def create(
user_performing_action: DATestUser,
image_provider_id: str | None = None,
model_name: str = "gpt-image-1",
provider: str = "openai",
@@ -35,6 +35,7 @@ class ImageGenerationConfigManager:
deployment_name: str | None = None,
custom_config: dict[str, Any] | None = None,
is_default: bool = False,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Create a new image generation config with new credentials."""
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
@@ -52,7 +53,11 @@ class ImageGenerationConfigManager:
"custom_config": _serialize_custom_config(custom_config),
"is_default": is_default,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -69,13 +74,13 @@ class ImageGenerationConfigManager:
@staticmethod
def create_from_provider(
source_llm_provider_id: int,
user_performing_action: DATestUser,
image_provider_id: str | None = None,
model_name: str = "gpt-image-1",
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
is_default: bool = False,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Create a new image generation config by cloning from an existing LLM provider."""
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
@@ -91,7 +96,11 @@ class ImageGenerationConfigManager:
"deployment_name": deployment_name,
"is_default": is_default,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -107,12 +116,16 @@ class ImageGenerationConfigManager:
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestImageGenerationConfig]:
"""Get all image generation configs."""
response = requests.get(
f"{API_SERVER_URL}/admin/image-generation/config",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [DATestImageGenerationConfig(**config) for config in response.json()]
@@ -120,12 +133,16 @@ class ImageGenerationConfigManager:
@staticmethod
def get_credentials(
image_provider_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> dict:
"""Get credentials for an image generation config."""
response = requests.get(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/credentials",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return response.json()
@@ -134,13 +151,13 @@ class ImageGenerationConfigManager:
def update(
image_provider_id: str,
model_name: str,
user_performing_action: DATestUser,
provider: str | None = None,
api_key: str | None = None,
source_llm_provider_id: int | None = None,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Update an existing image generation config."""
payload: dict = {
@@ -161,10 +178,14 @@ class ImageGenerationConfigManager:
f"Got: source_llm_provider_id={source_llm_provider_id}, provider={provider}, api_key={'***' if api_key else None}"
)
headers = {**GENERAL_HEADERS}
if user_performing_action:
headers.update(user_performing_action.headers)
response = requests.put(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
json=payload,
headers=user_performing_action.headers,
headers=headers,
)
if not response.ok:
print(f"Update failed with status {response.status_code}: {response.text}")
@@ -183,32 +204,40 @@ class ImageGenerationConfigManager:
@staticmethod
def delete(
image_provider_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
"""Delete an image generation config."""
response = requests.delete(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def set_default(
image_provider_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
"""Set an image generation config as the default."""
response = requests.post(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/default",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def verify(
config: DATestImageGenerationConfig,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
"""Verify that a config exists (or doesn't exist if verify_deleted=True)."""
all_configs = ImageGenerationConfigManager.get_all(user_performing_action)

View File

@@ -14,6 +14,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
@@ -85,9 +86,9 @@ class IndexAttemptManager:
@staticmethod
def get_index_attempt_page(
cc_pair_id: int,
user_performing_action: DATestUser,
page: int = 0,
page_size: int = 10,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[IndexAttemptSnapshot]:
query_params: dict[str, str | int] = {
"page_num": page,
@@ -100,7 +101,11 @@ class IndexAttemptManager:
)
response = requests.get(
url=url,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -112,7 +117,7 @@ class IndexAttemptManager:
@staticmethod
def get_latest_index_attempt_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot | None:
"""Get an IndexAttempt by ID"""
index_attempts = IndexAttemptManager.get_index_attempt_page(
@@ -129,9 +134,9 @@ class IndexAttemptManager:
@staticmethod
def wait_for_index_attempt_start(
cc_pair_id: int,
user_performing_action: DATestUser,
index_attempts_to_ignore: list[int] | None = None,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
"""Wait for an IndexAttempt to start"""
start = datetime.now()
@@ -159,7 +164,7 @@ class IndexAttemptManager:
def get_index_attempt_by_id(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
page_num = 0
page_size = 10
@@ -185,8 +190,8 @@ class IndexAttemptManager:
def wait_for_index_attempt_completion(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = time.monotonic()
@@ -218,15 +223,19 @@ class IndexAttemptManager:
@staticmethod
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser,
include_resolved: bool = True,
user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
if include_resolved:
url += "&include_resolved=true"
response = requests.get(
url=url,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()

View File

@@ -8,6 +8,7 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -15,7 +16,6 @@ from tests.integration.common_utils.test_models import DATestUser
class LLMProviderManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
provider: str | None = None,
api_key: str | None = None,
@@ -26,8 +26,13 @@ class LLMProviderManager:
personas: list[int] | None = None,
is_public: bool | None = None,
set_as_default: bool = True,
user_performing_action: DATestUser | None = None,
) -> DATestLLMProvider:
print(f"Seeding LLM Providers for {user_performing_action.email}...")
email = "Unknown"
if user_performing_action:
email = user_performing_action.email
print(f"Seeding LLM Providers for {email}...")
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
@@ -55,7 +60,11 @@ class LLMProviderManager:
llm_response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
json=llm_provider.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
llm_response.raise_for_status()
response_data = llm_response.json()
@@ -77,7 +86,11 @@ class LLMProviderManager:
if set_as_default:
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
set_default_response.raise_for_status()
@@ -86,22 +99,30 @@ class LLMProviderManager:
@staticmethod
def delete(
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[LLMProviderView]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
@@ -109,8 +130,8 @@ class LLMProviderManager:
@staticmethod
def verify(
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
for fetched_llm_provider in all_llm_providers:

View File

@@ -7,6 +7,7 @@ from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestPersonaLabel
from tests.integration.common_utils.test_models import DATestUser
@@ -15,7 +16,6 @@ from tests.integration.common_utils.test_models import DATestUser
class PersonaManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
@@ -34,6 +34,7 @@ class PersonaManager:
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[str] | None = None,
user_performing_action: DATestUser | None = None,
display_priority: int | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
@@ -66,7 +67,11 @@ class PersonaManager:
response = requests.post(
f"{API_SERVER_URL}/persona",
json=persona_creation_request.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
persona_data = response.json()
@@ -95,7 +100,6 @@ class PersonaManager:
@staticmethod
def edit(
persona: DATestPersona,
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
@@ -113,6 +117,7 @@ class PersonaManager:
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
@@ -146,7 +151,11 @@ class PersonaManager:
response = requests.patch(
f"{API_SERVER_URL}/persona/{persona.id}",
json=persona_update_request.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
updated_persona_data = response.json()
@@ -178,11 +187,15 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [FullPersonaSnapshot(**persona) for persona in response.json()]
@@ -190,11 +203,15 @@ class PersonaManager:
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [FullPersonaSnapshot(**response.json())]
@@ -202,7 +219,7 @@ class PersonaManager:
@staticmethod
def verify(
persona: DATestPersona,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
all_personas = PersonaManager.get_one(
persona_id=persona.id,
@@ -371,11 +388,15 @@ class PersonaManager:
@staticmethod
def delete(
persona: DATestPersona,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/persona/{persona.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@@ -384,14 +405,18 @@ class PersonaLabelManager:
@staticmethod
def create(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaLabel:
response = requests.post(
f"{API_SERVER_URL}/persona/labels",
json={
"name": label.name,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
response_data = response.json()
@@ -400,11 +425,15 @@ class PersonaLabelManager:
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestPersonaLabel]:
response = requests.get(
f"{API_SERVER_URL}/persona/labels",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [DATestPersonaLabel(**label) for label in response.json()]
@@ -412,14 +441,18 @@ class PersonaLabelManager:
@staticmethod
def update(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaLabel:
response = requests.patch(
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
json={
"label_name": label.name,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return label
@@ -427,18 +460,22 @@ class PersonaLabelManager:
@staticmethod
def delete(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def verify(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
all_labels = PersonaLabelManager.get_all(user_performing_action)
for fetched_label in all_labels:

View File

@@ -6,6 +6,7 @@ from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.features.projects.models import UserProjectSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -19,7 +20,7 @@ class ProjectManager:
response = requests.post(
f"{API_SERVER_URL}/user/projects/create",
params={"name": name},
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
return UserProjectSnapshot.model_validate(response.json())
@@ -31,7 +32,7 @@ class ProjectManager:
"""Get all projects for a user via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
return [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
@@ -44,7 +45,7 @@ class ProjectManager:
"""Delete a project via API."""
response = requests.delete(
f"{API_SERVER_URL}/user/projects/{project_id}",
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
return response.status_code == 204
@@ -56,7 +57,7 @@ class ProjectManager:
"""Verify that a project has been deleted by ensuring it's not in list."""
response = requests.get(
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
projects = [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
@@ -65,12 +66,16 @@ class ProjectManager:
@staticmethod
def verify_files_unlinked(
project_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""Verify that all files have been unlinked from the project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/files/{project_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 404:
return True
@@ -82,12 +87,16 @@ class ProjectManager:
@staticmethod
def verify_chat_sessions_unlinked(
project_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""Verify that all chat sessions have been unlinked from the project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/{project_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 404:
return True
@@ -135,12 +144,16 @@ class ProjectManager:
@staticmethod
def get_project_files(
project_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> List[UserFileSnapshot]:
"""Get all files associated with a project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/files/{project_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 404:
return []
@@ -157,7 +170,7 @@ class ProjectManager:
response = requests.post(
f"{API_SERVER_URL}/user/projects/{project_id}/instructions",
json={"instructions": instructions},
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
return (response.json() or {}).get("instructions") or ""

View File

@@ -10,18 +10,19 @@ from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
class QueryHistoryManager:
@staticmethod
def get_query_history_page(
user_performing_action: DATestUser,
page_num: int = 0,
page_size: int = 10,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[ChatSessionMinimal]:
query_params: dict[str, str | int] = {
"page_num": page_num,
@@ -36,7 +37,11 @@ class QueryHistoryManager:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -48,20 +53,24 @@ class QueryHistoryManager:
@staticmethod
def get_chat_session_admin(
chat_session_id: UUID | str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> ChatSessionSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return ChatSessionSnapshot(**response.json())
@staticmethod
def get_query_history_as_csv(
user_performing_action: DATestUser,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> tuple[CaseInsensitiveDict[str], str]:
query_params: dict[str, str | int] = {}
if start_time:
@@ -71,7 +80,11 @@ class QueryHistoryManager:
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return response.headers, response.content.decode()

View File

@@ -5,6 +5,7 @@ from typing import Optional
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
@@ -12,9 +13,13 @@ from tests.integration.common_utils.test_models import DATestUser
class SettingsManager:
@staticmethod
def get_settings(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> tuple[Dict[str, Any], str]:
headers = user_performing_action.headers
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
response = requests.get(
@@ -33,9 +38,13 @@ class SettingsManager:
@staticmethod
def update_settings(
settings: DATestSettings,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> tuple[Dict[str, Any], str]:
headers = user_performing_action.headers
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
payload = settings.model_dump()
@@ -56,7 +65,7 @@ class SettingsManager:
@staticmethod
def get_setting(
key: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> Optional[Any]:
settings, error = SettingsManager.get_settings(user_performing_action)
if error:

View File

@@ -8,6 +8,7 @@ from onyx.server.manage.models import AllUsersResponse
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -25,11 +26,15 @@ def generate_auth_token() -> str:
class TenantManager:
@staticmethod
def get_all_users(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> AllUsersResponse:
response = requests.get(
url=f"{API_SERVER_URL}/manage/users",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -45,8 +50,7 @@ class TenantManager:
@staticmethod
def verify_user_in_tenant(
user: DATestUser,
user_performing_action: DATestUser,
user: DATestUser, user_performing_action: DATestUser | None = None
) -> None:
all_users = TenantManager.get_all_users(user_performing_action)
for accepted_user in all_users.accepted:

View File

@@ -1,6 +1,7 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestTool
from tests.integration.common_utils.test_models import DATestUser
@@ -8,11 +9,15 @@ from tests.integration.common_utils.test_models import DATestUser
class ToolManager:
@staticmethod
def list_tools(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestTool]:
response = requests.get(
url=f"{API_SERVER_URL}/tool",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [

View File

@@ -7,8 +7,6 @@ import requests
from requests import HTTPError
from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import ANONYMOUS_USER_UUID
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import UserInfo
@@ -27,23 +25,6 @@ def build_email(name: str) -> str:
class UserManager:
@staticmethod
def get_anonymous_user() -> DATestUser:
"""Get a DATestUser representing the anonymous user.
Anonymous users are real users in the database with LIMITED role.
They don't have login cookies - requests are made with GENERAL_HEADERS.
The anonymous_user_enabled setting must be True for these requests to work.
"""
return DATestUser(
id=ANONYMOUS_USER_UUID,
email=ANONYMOUS_USER_EMAIL,
password="",
headers=GENERAL_HEADERS,
role=UserRole.LIMITED,
is_active=True,
)
@staticmethod
def create(
name: str | None = None,
@@ -246,12 +227,12 @@ class UserManager:
@staticmethod
def get_user_page(
user_performing_action: DATestUser,
page_num: int = 0,
page_size: int = 10,
search_query: str | None = None,
role_filter: list[UserRole] | None = None,
is_active_filter: bool | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[FullUserSnapshot]:
query_params: dict[str, str | list[str] | int] = {
"page_num": page_num,
@@ -266,7 +247,11 @@ class UserManager:
response = requests.get(
url=f"{API_SERVER_URL}/manage/users/accepted?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()

View File

@@ -5,6 +5,7 @@ import requests
from ee.onyx.server.user_group.models import UserGroup
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import DATestUserGroup
@@ -13,10 +14,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
class UserGroupManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
user_ids: list[str] | None = None,
cc_pair_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestUserGroup:
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
@@ -28,7 +29,11 @@ class UserGroupManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group",
json=request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
test_user_group = DATestUserGroup(
@@ -42,23 +47,31 @@ class UserGroupManager:
@staticmethod
def edit(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
json=user_group.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -66,7 +79,7 @@ class UserGroupManager:
def add_users(
user_group: DATestUserGroup,
user_ids: list[str],
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> DATestUserGroup:
request = {
"user_ids": user_ids,
@@ -75,7 +88,11 @@ class UserGroupManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
json=request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -90,8 +107,8 @@ class UserGroupManager:
def set_curator_status(
test_user_group: DATestUserGroup,
user_to_set_as_curator: DATestUser,
user_performing_action: DATestUser,
is_curator: bool = True,
user_performing_action: DATestUser | None = None,
) -> None:
set_curator_request = {
"user_id": user_to_set_as_curator.id,
@@ -100,17 +117,25 @@ class UserGroupManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator",
json=set_curator_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[UserGroup]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]
@@ -118,8 +143,8 @@ class UserGroupManager:
@staticmethod
def verify(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_user_groups = UserGroupManager.get_all(user_performing_action)
for fetched_user_group in all_user_groups:
@@ -142,8 +167,8 @@ class UserGroupManager:
@staticmethod
def wait_for_sync(
user_performing_action: DATestUser,
user_groups_to_check: list[DATestUserGroup] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
while True:
@@ -173,7 +198,7 @@ class UserGroupManager:
@staticmethod
def wait_for_deletion_completion(
user_groups_to_check: list[DATestUserGroup],
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}

View File

@@ -88,8 +88,11 @@ def reset() -> None:
@pytest.fixture
def new_admin_user(reset: None) -> DATestUser: # noqa: ARG001
return UserManager.create(name=ADMIN_USER_NAME)
def new_admin_user(reset: None) -> DATestUser | None: # noqa: ARG001
try:
return UserManager.create(name=ADMIN_USER_NAME)
except Exception:
return None
@pytest.fixture
@@ -179,18 +182,18 @@ def reset_multitenant() -> None:
@pytest.fixture
def llm_provider(admin_user: DATestUser) -> DATestLLMProvider:
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
@pytest.fixture
def image_generation_config(
admin_user: DATestUser,
admin_user: DATestUser | None,
) -> DATestImageGenerationConfig:
"""Create a default image generation config for tests."""
return ImageGenerationConfigManager.create(
user_performing_action=admin_user,
is_default=True,
user_performing_action=admin_user,
)

View File

@@ -60,44 +60,3 @@ def test_me_endpoint_returns_authenticated_user_info(
assert data.get("is_anonymous_user") is not True
assert data["email"] == admin_user.email
assert data["role"] == "admin"
def test_anonymous_user_can_access_persona_when_enabled(
reset: None, # noqa: ARG001
) -> None:
"""Verify that anonymous users can access limited endpoints when enabled."""
admin_user: DATestUser = UserManager.create(name="admin_user")
SettingsManager.update_settings(
DATestSettings(anonymous_user_enabled=True),
user_performing_action=admin_user,
)
anon_user = UserManager.get_anonymous_user()
response = requests.get(
f"{API_SERVER_URL}/persona",
headers=anon_user.headers,
)
assert response.status_code == 200
def test_anonymous_user_denied_persona_when_disabled(
reset: None, # noqa: ARG001
) -> None:
"""Verify that anonymous users cannot access endpoints when disabled."""
admin_user: DATestUser = UserManager.create(name="admin_user")
SettingsManager.update_settings(
DATestSettings(anonymous_user_enabled=False),
user_performing_action=admin_user,
)
anon_user = UserManager.get_anonymous_user()
response = requests.get(
f"{API_SERVER_URL}/persona",
headers=anon_user.headers,
)
# 403 is returned - BasicAuthenticationError uses HTTP 403 for all auth failures
assert response.status_code == 403

View File

@@ -11,8 +11,8 @@ from tests.integration.common_utils.test_models import DATestUser
def _verify_index_attempt_pagination(
cc_pair_id: int,
index_attempt_ids: list[int],
user_performing_action: DATestUser,
page_size: int = 5,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_attempts: list[int] = []
last_time_started = None # Track the last time_started seen

View File

@@ -207,9 +207,7 @@ def test_mcp_search_respects_acl_filters(
cc_pair_ids=[restricted_cc_pair.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_performing_action=admin_user, user_groups_to_check=[user_group]
)
UserGroupManager.wait_for_sync([user_group], user_performing_action=admin_user)
restricted_doc_content = "MCP restricted knowledge base document"
_seed_document_and_wait_for_indexing(

View File

@@ -14,11 +14,11 @@ from tests.integration.tests.query_history.utils import (
def _verify_query_history_pagination(
chat_sessions: list[DAQueryHistoryEntry],
user_performing_action: DATestUser,
page_size: int = 5,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_sessions: list[str] = []

View File

@@ -5,6 +5,7 @@ import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@@ -58,7 +59,7 @@ def test_add_users_to_group_invalid_user(reset: None) -> None: # noqa: ARG001
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
json={"user_ids": [invalid_user_id]},
headers=admin_user.headers,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -9,11 +9,11 @@ from tests.integration.common_utils.test_models import DATestUser
# to verify that the pagination and filtering works as expected.
def _verify_user_pagination(
users: list[DATestUser],
user_performing_action: DATestUser,
page_size: int = 5,
search_query: str | None = None,
role_filter: list[UserRole] | None = None,
is_active_filter: bool | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_users: list[FullUserSnapshot] = []

View File

@@ -1,268 +0,0 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_enumerate_ad_groups_paginated,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_iter_graph_collection,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_normalize_email,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
AD_GROUP_ENUMERATION_THRESHOLD,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import GroupsResult
MODULE = "ee.onyx.external_permissions.sharepoint.permission_utils"
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fake_token() -> str:
return "fake-token"
def _make_graph_page(
items: list[dict[str, Any]],
next_link: str | None = None,
) -> dict[str, Any]:
page: dict[str, Any] = {"value": items}
if next_link:
page["@odata.nextLink"] = next_link
return page
# ---------------------------------------------------------------------------
# _normalize_email
# ---------------------------------------------------------------------------
def test_normalize_email_strips_onmicrosoft() -> None:
    # The tenant-default ".onmicrosoft" segment is removed from the domain.
    assert _normalize_email("user@contoso.onmicrosoft.com") == "user@contoso.com"
def test_normalize_email_noop_for_normal_domain() -> None:
    # Addresses on a regular vanity domain pass through unchanged.
    assert _normalize_email("user@contoso.com") == "user@contoso.com"
# ---------------------------------------------------------------------------
# _iter_graph_collection
# ---------------------------------------------------------------------------
@patch(f"{MODULE}._graph_api_get")
def test_iter_graph_collection_single_page(mock_get: MagicMock) -> None:
mock_get.return_value = _make_graph_page([{"id": "1"}, {"id": "2"}])
items = list(_iter_graph_collection("https://graph/items", _fake_token))
assert items == [{"id": "1"}, {"id": "2"}]
mock_get.assert_called_once()
@patch(f"{MODULE}._graph_api_get")
def test_iter_graph_collection_multi_page(mock_get: MagicMock) -> None:
mock_get.side_effect = [
_make_graph_page([{"id": "1"}], next_link="https://graph/items?page=2"),
_make_graph_page([{"id": "2"}]),
]
items = list(_iter_graph_collection("https://graph/items", _fake_token))
assert items == [{"id": "1"}, {"id": "2"}]
assert mock_get.call_count == 2
@patch(f"{MODULE}._graph_api_get")
def test_iter_graph_collection_empty(mock_get: MagicMock) -> None:
mock_get.return_value = _make_graph_page([])
assert list(_iter_graph_collection("https://graph/items", _fake_token)) == []
# ---------------------------------------------------------------------------
# _enumerate_ad_groups_paginated
# ---------------------------------------------------------------------------
def _mock_graph_get_for_enumeration(
    groups: list[dict[str, Any]],
    members_by_group: dict[str, list[dict[str, Any]]],
) -> Callable[..., dict[str, Any]]:
    """Return a side_effect function for _graph_api_get that serves
    groups on the /groups URL and members on /groups/{id}/members URLs.

    The return type is the callable itself — the previous ``Generator``
    annotation was wrong and had to be suppressed with
    ``type: ignore[return-value]``; annotating the real type removes the
    need for the ignore.
    """

    def side_effect(
        url: str,
        get_access_token: Any,  # noqa: ARG001 — signature must mirror _graph_api_get
        params: dict[str, str] | None = None,  # noqa: ARG001
    ) -> dict[str, Any]:
        # Member listing: /groups/{id}/members → serve that group's members.
        if "/members" in url:
            group_id = url.split("/groups/")[1].split("/members")[0]
            return _make_graph_page(members_by_group.get(group_id, []))
        # Anything else is treated as the /groups collection itself.
        return _make_graph_page(groups)

    return side_effect
@patch(f"{MODULE}._graph_api_get")
def test_enumerate_ad_groups_yields_groups(mock_get: MagicMock) -> None:
groups = [
{"id": "g1", "displayName": "Engineering"},
{"id": "g2", "displayName": "Marketing"},
]
members = {
"g1": [{"userPrincipalName": "alice@contoso.com"}],
"g2": [{"mail": "bob@contoso.onmicrosoft.com"}],
}
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, members)
results = list(
_enumerate_ad_groups_paginated(
_fake_token, already_resolved=set(), graph_api_base=GRAPH_API_BASE
)
)
assert len(results) == 2
eng = next(r for r in results if r.id == "Engineering_g1")
assert eng.user_emails == ["alice@contoso.com"]
mkt = next(r for r in results if r.id == "Marketing_g2")
assert mkt.user_emails == ["bob@contoso.com"]
@patch(f"{MODULE}._graph_api_get")
def test_enumerate_ad_groups_skips_already_resolved(mock_get: MagicMock) -> None:
groups = [{"id": "g1", "displayName": "Engineering"}]
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
results = list(
_enumerate_ad_groups_paginated(
_fake_token,
already_resolved={"Engineering_g1"},
graph_api_base=GRAPH_API_BASE,
)
)
assert results == []
@patch(f"{MODULE}._graph_api_get")
def test_enumerate_ad_groups_circuit_breaker(mock_get: MagicMock) -> None:
"""Enumeration stops after AD_GROUP_ENUMERATION_THRESHOLD groups."""
over_limit = AD_GROUP_ENUMERATION_THRESHOLD + 5
groups = [{"id": f"g{i}", "displayName": f"Group{i}"} for i in range(over_limit)]
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
results = list(
_enumerate_ad_groups_paginated(
_fake_token, already_resolved=set(), graph_api_base=GRAPH_API_BASE
)
)
assert len(results) <= AD_GROUP_ENUMERATION_THRESHOLD
# ---------------------------------------------------------------------------
# get_sharepoint_external_groups
# ---------------------------------------------------------------------------
def _stub_role_assignment_resolution(
    groups_to_emails: dict[str, set[str]],
) -> tuple[MagicMock, MagicMock]:
    """Return (mock_sleep_and_retry, mock_recursive) pre-configured to
    simulate role-assignment group resolution."""
    resolution = GroupsResult(
        groups_to_emails=groups_to_emails,
        found_public_group=False,
    )
    return MagicMock(), MagicMock(return_value=resolution)
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_default_skips_ad_enumeration(
mock_sleep: MagicMock, mock_recursive: MagicMock # noqa: ARG001
) -> None:
mock_recursive.return_value = GroupsResult(
groups_to_emails={"SiteGroup_abc": {"alice@contoso.com"}},
found_public_group=False,
)
results = get_sharepoint_external_groups(
client_context=MagicMock(),
graph_client=MagicMock(),
graph_api_base=GRAPH_API_BASE,
)
assert len(results) == 1
assert results[0].id == "SiteGroup_abc"
assert results[0].user_emails == ["alice@contoso.com"]
@patch(f"{MODULE}._enumerate_ad_groups_paginated")
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_enumerate_all_includes_ad_groups(
mock_sleep: MagicMock, # noqa: ARG001
mock_recursive: MagicMock,
mock_enum: MagicMock,
) -> None:
from ee.onyx.db.external_perm import ExternalUserGroup
mock_recursive.return_value = GroupsResult(
groups_to_emails={"SiteGroup_abc": {"alice@contoso.com"}},
found_public_group=False,
)
mock_enum.return_value = [
ExternalUserGroup(id="ADGroup_xyz", user_emails=["bob@contoso.com"]),
]
results = get_sharepoint_external_groups(
client_context=MagicMock(),
graph_client=MagicMock(),
get_access_token=_fake_token,
enumerate_all_ad_groups=True,
graph_api_base=GRAPH_API_BASE,
)
assert len(results) == 2
ids = {r.id for r in results}
assert ids == {"SiteGroup_abc", "ADGroup_xyz"}
mock_enum.assert_called_once()
@patch(f"{MODULE}._enumerate_ad_groups_paginated")
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_enumerate_all_without_token_skips(
mock_sleep: MagicMock, # noqa: ARG001
mock_recursive: MagicMock,
mock_enum: MagicMock,
) -> None:
"""Even if enumerate_all_ad_groups=True, no token means skip."""
mock_recursive.return_value = GroupsResult(
groups_to_emails={},
found_public_group=False,
)
results = get_sharepoint_external_groups(
client_context=MagicMock(),
graph_client=MagicMock(),
get_access_token=None,
enumerate_all_ad_groups=True,
graph_api_base=GRAPH_API_BASE,
)
assert results == []
mock_enum.assert_not_called()

View File

@@ -2,10 +2,8 @@
from unittest.mock import MagicMock
from onyx.chat.chat_utils import _build_tool_call_response_history_message
from onyx.chat.chat_utils import get_custom_agent_prompt
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
class TestGetCustomAgentPrompt:
@@ -152,21 +150,3 @@ class TestGetCustomAgentPrompt:
# Should return None because replace_base_system_prompt=True
assert result is None
class TestBuildToolCallResponseHistoryMessage:
def test_image_tool_uses_generated_images(self) -> None:
message = _build_tool_call_response_history_message(
tool_name="generate_image",
generated_images=[{"file_id": "img-1", "revised_prompt": "p1"}],
tool_call_response=None,
)
assert message == '[{"file_id": "img-1", "revised_prompt": "p1"}]'
def test_non_image_tool_uses_placeholder(self) -> None:
message = _build_tool_call_response_history_message(
tool_name="web_search",
generated_images=None,
tool_call_response='{"raw":"value"}',
)
assert message == TOOL_CALL_RESPONSE_CROSS_MESSAGE

View File

@@ -206,7 +206,6 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
drive_name: str,
ctx: Any, # noqa: ARG001
graph_client: Any, # noqa: ARG001
graph_api_base: str, # noqa: ARG001
include_permissions: bool, # noqa: ARG001
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
access_token: str | None = None, # noqa: ARG001

View File

@@ -1,10 +1,5 @@
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
from onyx.tools.tool_runner import _merge_tool_calls
@@ -312,65 +307,3 @@ class TestMergeToolCalls:
assert len(result) == 1
# String should be converted to list item
assert result[0].tool_args["queries"] == ["single_query", "q2"]
class TestImageHistoryExtraction:
def test_extracts_image_file_ids_from_json_response(self) -> None:
msg = (
'[{"file_id":"img-1","revised_prompt":"v1"},'
'{"file_id":"img-2","revised_prompt":"v2"}]'
)
assert _extract_image_file_ids_from_tool_response_message(msg) == [
"img-1",
"img-2",
]
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
history = [
ChatMessageSimple(
message="",
token_count=1,
message_type=MessageType.ASSISTANT,
tool_calls=[
ToolCallSimple(
tool_call_id="call_1",
tool_name="generate_image",
tool_arguments={"prompt": "test"},
token_count=1,
)
],
),
ChatMessageSimple(
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
token_count=1,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id="call_1",
),
]
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
def test_ignores_non_image_tool_responses(self) -> None:
history = [
ChatMessageSimple(
message="",
token_count=1,
message_type=MessageType.ASSISTANT,
tool_calls=[
ToolCallSimple(
tool_call_id="call_1",
tool_name="web_search",
tool_arguments={"queries": ["q"]},
token_count=1,
)
],
),
ChatMessageSimple(
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
token_count=1,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id="call_1",
),
]
assert _extract_recent_generated_image_file_ids(history) == []

View File

@@ -10,7 +10,7 @@ from onyx.tools.utils import explicit_tool_calling_supported
(LlmProviderNames.ANTHROPIC, "claude-4-sonnet-20250514", True),
(
"another-provider",
"claude-haiku-4-5-20251001",
"claude-3-haiku-20240307",
True,
),
(

View File

@@ -1,172 +0,0 @@
"""Unit tests for CodeInterpreterClient streaming-to-batch fallback.
When the streaming endpoint (/v1/execute/stream) returns 404 — e.g. because the
code-interpreter service is an older version that doesn't support streaming — the
client should transparently fall back to the batch endpoint (/v1/execute) and
convert the batch response into the same stream-event interface.
"""
from __future__ import annotations
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamOutputEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamResultEvent,
)
def _make_batch_response(
stdout: str = "",
stderr: str = "",
exit_code: int = 0,
timed_out: bool = False,
duration_ms: int = 50,
) -> MagicMock:
"""Build a mock ``requests.Response`` for the batch /v1/execute endpoint."""
resp = MagicMock()
resp.status_code = 200
resp.raise_for_status = MagicMock()
resp.json.return_value = {
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
"timed_out": timed_out,
"duration_ms": duration_ms,
"files": [],
}
return resp
def _make_404_response() -> MagicMock:
"""Build a mock ``requests.Response`` that returns 404 (streaming not found)."""
resp = MagicMock()
resp.status_code = 404
return resp
def test_execute_streaming_fallback_to_batch_on_404() -> None:
    """When /v1/execute/stream returns 404, the client should fall back to
    /v1/execute and yield equivalent StreamEvent objects."""
    client = CodeInterpreterClient(base_url="http://fake:9000")
    stream_resp = _make_404_response()
    batch_resp = _make_batch_response(
        stdout="hello world\n",
        stderr="a warning\n",
    )
    urls_called: list[str] = []
    # Route POSTs by URL suffix: 404 for the stream endpoint, 200 for batch.
    def mock_post(url: str, **_kwargs: object) -> MagicMock:
        urls_called.append(url)
        if url.endswith("/v1/execute/stream"):
            return stream_resp
        if url.endswith("/v1/execute"):
            return batch_resp
        raise AssertionError(f"Unexpected URL: {url}")
    with patch.object(client.session, "post", side_effect=mock_post):
        events = list(client.execute_streaming(code="print('hello world')"))
    # Streaming endpoint was attempted first, then batch
    assert len(urls_called) == 2
    assert urls_called[0].endswith("/v1/execute/stream")
    assert urls_called[1].endswith("/v1/execute")
    # The 404 response must be closed before making the batch call
    stream_resp.close.assert_called_once()
    # _batch_as_stream yields: stdout event, stderr event, result event
    assert len(events) == 3
    assert isinstance(events[0], StreamOutputEvent)
    assert events[0].stream == "stdout"
    assert events[0].data == "hello world\n"
    assert isinstance(events[1], StreamOutputEvent)
    assert events[1].stream == "stderr"
    assert events[1].data == "a warning\n"
    assert isinstance(events[2], StreamResultEvent)
    # Result event mirrors the batch payload built by _make_batch_response.
    assert events[2].exit_code == 0
    assert not events[2].timed_out
    assert events[2].duration_ms == 50
    assert events[2].files == []
def test_execute_streaming_fallback_stdout_only() -> None:
    """Fallback with only stdout (no stderr) should yield two events:
    one StreamOutputEvent for stdout and one StreamResultEvent."""
    client = CodeInterpreterClient(base_url="http://fake:9000")
    stream_resp = _make_404_response()
    batch_resp = _make_batch_response(stdout="result: 42\n")
    # Route POSTs by URL suffix: 404 for the stream endpoint, 200 for batch.
    def mock_post(url: str, **_kwargs: object) -> MagicMock:
        if url.endswith("/v1/execute/stream"):
            return stream_resp
        if url.endswith("/v1/execute"):
            return batch_resp
        raise AssertionError(f"Unexpected URL: {url}")
    with patch.object(client.session, "post", side_effect=mock_post):
        events = list(client.execute_streaming(code="print(42)"))
    # No stderr → only stdout + result
    assert len(events) == 2
    assert isinstance(events[0], StreamOutputEvent)
    assert events[0].stream == "stdout"
    assert events[0].data == "result: 42\n"
    assert isinstance(events[1], StreamResultEvent)
    assert events[1].exit_code == 0
def test_execute_streaming_fallback_preserves_files_param() -> None:
    """When falling back, the files parameter must be forwarded to the
    batch endpoint so staged files are still available for execution."""
    client = CodeInterpreterClient(base_url="http://fake:9000")
    stream_resp = _make_404_response()
    batch_resp = _make_batch_response(stdout="ok\n")
    captured_payloads: list[dict] = []
    # Capture every JSON body posted so both requests can be inspected.
    def mock_post(url: str, **kwargs: object) -> MagicMock:
        if "json" in kwargs:
            captured_payloads.append(kwargs["json"])  # type: ignore[arg-type]
        if url.endswith("/v1/execute/stream"):
            return stream_resp
        if url.endswith("/v1/execute"):
            return batch_resp
        raise AssertionError(f"Unexpected URL: {url}")
    files_input = [{"path": "data.csv", "file_id": "file-abc123"}]
    with patch.object(client.session, "post", side_effect=mock_post):
        events = list(
            client.execute_streaming(
                code="import pandas",
                files=files_input,
            )
        )
    # Both the streaming attempt and the batch fallback should include files
    assert len(captured_payloads) == 2
    for payload in captured_payloads:
        assert payload["files"] == files_input
        assert payload["code"] == "import pandas"
    # Should still yield valid events
    assert any(isinstance(e, StreamResultEvent) for e in events)

View File

@@ -12,8 +12,7 @@ chart-repos:
- postgresql=https://cloudnative-pg.github.io/charts
- redis=https://ot-container-kit.github.io/helm-charts
- minio=https://charts.min.io/
- code-interpreter=https://onyx-dot-app.github.io/python-sandbox/
# have seen postgres take 10 min to pull ... so 15 min seems like a good timeout?
helm-extra-args: --debug --timeout 900s

View File

@@ -4,6 +4,24 @@ log_format custom_main '$remote_addr - $remote_user [$time_local] "$request" '
'"$http_user_agent" "$http_x_forwarded_for" '
'rt=$request_time';
# Map X-Forwarded-Proto or fallback to $scheme
map $http_x_forwarded_proto $forwarded_proto {
default $http_x_forwarded_proto;
"" $scheme;
}
# Map X-Forwarded-Host or fallback to $host
map $http_x_forwarded_host $forwarded_host {
default $http_x_forwarded_host;
"" $host;
}
# Map X-Forwarded-Port or fallback to server port
map $http_x_forwarded_port $forwarded_port {
default $http_x_forwarded_port;
"" $server_port;
}
upstream api_server {
# fail_timeout=0 means we always retry an upstream even if it failed
# to return a good HTTP response
@@ -41,10 +59,9 @@ server {
# misc headers
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header X-Forwarded-Proto $forwarded_proto;
proxy_set_header X-Forwarded-Host $forwarded_host;
proxy_set_header X-Forwarded-Port $forwarded_port;
proxy_set_header Host $host;
# need to use 1.1 to support chunked transfers
@@ -61,10 +78,9 @@ server {
# misc headers
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header X-Forwarded-Proto $forwarded_proto;
proxy_set_header X-Forwarded-Host $forwarded_host;
proxy_set_header X-Forwarded-Port $forwarded_port;
proxy_set_header Host $host;
proxy_http_version 1.1;

View File

@@ -4,6 +4,24 @@ log_format custom_main '$remote_addr - $remote_user [$time_local] "$request" '
'"$http_user_agent" "$http_x_forwarded_for" '
'rt=$request_time';
# Map X-Forwarded-Proto or fallback to $scheme
map $http_x_forwarded_proto $forwarded_proto {
default $http_x_forwarded_proto;
"" $scheme;
}
# Map X-Forwarded-Host or fallback to $host
map $http_x_forwarded_host $forwarded_host {
default $http_x_forwarded_host;
"" $host;
}
# Map X-Forwarded-Port or fallback to server port
map $http_x_forwarded_port $forwarded_port {
default $http_x_forwarded_port;
"" $server_port;
}
upstream api_server {
# fail_timeout=0 means we always retry an upstream even if it failed
# to return a good HTTP response
@@ -41,10 +59,9 @@ server {
# misc headers
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header X-Forwarded-Proto $forwarded_proto;
proxy_set_header X-Forwarded-Host $forwarded_host;
proxy_set_header X-Forwarded-Port $forwarded_port;
proxy_set_header Host $host;
# need to use 1.1 to support chunked transfers
@@ -66,10 +83,9 @@ server {
# misc headers
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header X-Forwarded-Proto $forwarded_proto;
proxy_set_header X-Forwarded-Host $forwarded_host;
proxy_set_header X-Forwarded-Port $forwarded_port;
proxy_set_header Host $host;
proxy_http_version 1.1;

View File

@@ -17,8 +17,5 @@ dependencies:
- name: minio
repository: https://charts.min.io/
version: 5.4.0
- name: code-interpreter
repository: https://onyx-dot-app.github.io/python-sandbox/
version: 0.2.1
digest: sha256:aedc211d9732c934be8b79735b62f8caa9bcd235e03fd0dd10b49e0a13ed15b7
generated: "2026-02-20T11:19:47.957449-08:00"
digest: sha256:e3e3df4464d00165d63e5aa150b768c0957e5eab2b310414ce5d7381d00dbd2e
generated: "2026-02-19T20:17:06.462195-08:00"

View File

@@ -44,7 +44,3 @@ dependencies:
version: 5.4.0
repository: https://charts.min.io/
condition: minio.enabled
- name: code-interpreter
version: 0.2.1
repository: https://onyx-dot-app.github.io/python-sandbox/
condition: codeInterpreter.enabled

View File

@@ -144,7 +144,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.6.0",
"onyx-devtools==0.5.7",
"openapi-generator-cli==7.17.0",
"pandas-stubs~=2.3.3",
"pre-commit==3.2.2",

18
uv.lock generated
View File

@@ -4711,7 +4711,7 @@ requires-dist = [
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.5.9" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.0" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.5.7" },
{ name = "openai", specifier = "==2.14.0" },
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
@@ -4816,20 +4816,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
[[package]]
name = "onyx-devtools"
version = "0.6.0"
version = "0.5.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "fastapi" },
{ name = "openapi-generator-cli" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/f9/79d66c1f06e4d1dca0a9df30afcd65ec1a69219fdf17c45349396d1ec668/onyx_devtools-0.6.0-py3-none-any.whl", hash = "sha256:26049075a6d3eb794f44c1bbe55a7cfc0c5427de681ed29319064e2deb956a15", size = 3777572, upload-time = "2026-02-19T23:05:51.823Z" },
{ url = "https://files.pythonhosted.org/packages/40/37/0abff5ab8d79c90f9d57eeaf4998f668145b01e81da0307df56c3b15d16c/onyx_devtools-0.6.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a7c00f2f1924c231b2480edcd3b6aa83398e13e4587c213fe1c97e0f6d3cfce1", size = 3822965, upload-time = "2026-02-19T23:06:02.992Z" },
{ url = "https://files.pythonhosted.org/packages/59/79/a8c23e456b7f1bb4cb741875af6c323fba11d5ef1ba121ea8b44587c236f/onyx_devtools-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0e67fc47dfffb510826a6487dd5029a65b4a5b3f8a42e0e1208b6faee353518c", size = 3570391, upload-time = "2026-02-19T23:05:48.853Z" },
{ url = "https://files.pythonhosted.org/packages/c5/c5/d166bf2c98b80fd83d76abe88e57d63a8cb55880ba40a3d34c831361e3cf/onyx_devtools-0.6.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0fdbd085f82788b900620424798d04dc1b10c3b1baf9be821ac178adc41c6858", size = 3432611, upload-time = "2026-02-19T23:05:51.924Z" },
{ url = "https://files.pythonhosted.org/packages/18/8e/c53fb7f7781acbf37ca80ebcee5d1274d54c6d853606adefc517df715f9a/onyx_devtools-0.6.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:3915ad5ea245e597a8ad91bd2ba5efc2b6a336ca59c7f3670bd89530cc9ab00f", size = 3777586, upload-time = "2026-02-19T23:05:51.877Z" },
{ url = "https://files.pythonhosted.org/packages/e5/57/194ded4aa5151d96911b021829e015370b4f1fc7493ac584d445fd96f97b/onyx_devtools-0.6.0-py3-none-win_amd64.whl", hash = "sha256:478cdae03ae2e797345396397318446622c7472df0a7d9dbd58d3e96489198b2", size = 3871835, upload-time = "2026-02-19T23:05:51.209Z" },
{ url = "https://files.pythonhosted.org/packages/3c/e9/cc7d204b9b1103b2f33f8f62d29076083f40f44697b398e83b3d44daca23/onyx_devtools-0.6.0-py3-none-win_arm64.whl", hash = "sha256:4bff060fd5f017ddceaf753252e0bc16699922d9a0a88506a56505aad4580824", size = 3492854, upload-time = "2026-02-19T23:05:51.856Z" },
{ url = "https://files.pythonhosted.org/packages/23/7d/a9135044e220b6ef6a0752be826c6c758a1fc8b59d545306938aa43e8976/onyx_devtools-0.5.7-py3-none-any.whl", hash = "sha256:47c5cdefb525523a9860ed134366f30a0d2ad30e055b2350c1da577d1059654b", size = 3769892, upload-time = "2026-02-12T20:06:02.937Z" },
{ url = "https://files.pythonhosted.org/packages/e7/63/26dbfc35f62d0617e4c46b508e106f155990c37c851d8eb44bc331b2e933/onyx_devtools-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c7ce707d9e27733e7300b2be3686e3fd76d62b9b1c20c9bd02dac707f4eac1d5", size = 3815888, upload-time = "2026-02-12T20:06:07.024Z" },
{ url = "https://files.pythonhosted.org/packages/82/55/4498e74af5f115355127c966e326f9ae430460170d1f1d50c2f150f53a00/onyx_devtools-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0d02a0c1c48a33bd85b251a2288d94a00effc2139b6e2b7018362cba8cf717e1", size = 3562190, upload-time = "2026-02-12T20:06:00.998Z" },
{ url = "https://files.pythonhosted.org/packages/18/70/fc1490420bd690bc6b3ebc3a6da68347636cb1a31afa07801fba9f77def4/onyx_devtools-0.5.7-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:fe3ae04f06e1b421f1297e70d2c14013d85941afa85210bfd96db30abb391989", size = 3425118, upload-time = "2026-02-12T20:05:59.192Z" },
{ url = "https://files.pythonhosted.org/packages/b3/46/76b44234d7cd4cf5c73b897f6dd1864c867c63cc871fd73f8901592c9248/onyx_devtools-0.5.7-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:ecf5f525c773d8db0b58bef3a02b00df31e7a9ade16213b4220eb2baffffd8e2", size = 3769913, upload-time = "2026-02-12T20:06:03.405Z" },
{ url = "https://files.pythonhosted.org/packages/7b/e5/9ef8d3265dfc82dbd9d27653d981ccb67c779882807ef1bd7fcecbe1c68a/onyx_devtools-0.5.7-py3-none-win_amd64.whl", hash = "sha256:f84368da19311acc246d511c5b2874b14ca1c9e53675198ba6ccabefbe57d648", size = 3863558, upload-time = "2026-02-12T20:06:01.995Z" },
{ url = "https://files.pythonhosted.org/packages/30/ad/f23ace3e049017e9cfcc06302005fd476b44357b6f4ade521febd8393599/onyx_devtools-0.5.7-py3-none-win_arm64.whl", hash = "sha256:edb1dcd3901f7532114d40fbc903ba60c528bdad397425c174dc5841b5b8de43", size = 3486869, upload-time = "2026-02-12T20:06:03.719Z" },
]
[[package]]

4
web/package-lock.json generated
View File

@@ -36,7 +36,7 @@
"@radix-ui/react-slider": "^1.2.2",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.1",
"@radix-ui/react-tooltip": "^1.2.8",
"@radix-ui/react-tooltip": "^1.1.3",
"@sentry/nextjs": "^10.27.0",
"@sentry/tracing": "^7.120.3",
"@stripe/stripe-js": "^4.6.0",
@@ -3788,8 +3788,6 @@
},
"node_modules/@radix-ui/react-tooltip": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.8.tgz",
"integrity": "sha512-tY7sVt1yL9ozIxvmbtN5qtmH2krXcBCfjEiCgKGLqunJHvgvZG2Pcl2oQ3kbcZARb1BGEHdkLzcYGO8ynVlieg==",
"license": "MIT",
"dependencies": {
"@radix-ui/primitive": "1.1.3",

View File

@@ -51,7 +51,7 @@
"@radix-ui/react-slider": "^1.2.2",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.1",
"@radix-ui/react-tooltip": "^1.2.8",
"@radix-ui/react-tooltip": "^1.1.3",
"@sentry/nextjs": "^10.27.0",
"@sentry/tracing": "^7.120.3",
"@stripe/stripe-js": "^4.6.0",

View File

@@ -115,7 +115,9 @@ function DefaultAssistantConfig() {
const enabledToolsMap: { [key: number]: boolean } = {};
tools.forEach((tool) => {
enabledToolsMap[tool.id] = config.tool_ids.includes(tool.id);
// Enable tool if it's in the current config OR if it's marked as default_enabled
enabledToolsMap[tool.id] =
config.tool_ids.includes(tool.id) || tool.default_enabled;
});
return (

View File

@@ -1,12 +0,0 @@
import { cn } from "@/lib/utils";
export function BlinkingBar({ addMargin = false }: { addMargin?: boolean }) {
  // Thin pulsing bar (caret-like) shown while content is still streaming in.
  const classes = cn(
    "animate-pulse flex-none bg-theme-primary-05 relative top-[0.25rem] inline-block w-[0.5em] h-[1.25em]",
    addMargin && "ml-1"
  );
  return <span className={classes} />;
}

View File

@@ -0,0 +1,11 @@
import React from "react";
export function BlinkingDot({ addMargin = false }: { addMargin?: boolean }) {
  // Small pulsing dot used as a streaming/loading placeholder.
  // NOTE: the trailing space in `base` is preserved from the original
  // template literal so the rendered className is byte-identical.
  const base =
    "animate-pulse flex-none bg-theme-primary-05 inline-block rounded-full h-3 w-3 ";
  const marginClass = addMargin ? "ml-2" : "";
  return <span className={base + marginClass} />;
}

View File

@@ -13,7 +13,7 @@ import { WebResultIcon } from "@/components/WebResultIcon";
import { SubQuestionDetail, CitationMap } from "../interfaces";
import { ValidSources } from "@/lib/types";
import { ProjectFile } from "../projects/projectsService";
import { BlinkingBar } from "./BlinkingBar";
import { BlinkingDot } from "./BlinkingDot";
import Text from "@/refresh-components/texts/Text";
import SourceTag from "@/refresh-components/buttons/source-tag/SourceTag";
import {
@@ -157,7 +157,7 @@ export const MemoizedLink = memo(
}, [document, updatePresentingDocument, question, openQuestion]);
if (value?.toString().startsWith("*")) {
return <BlinkingBar addMargin />;
return <BlinkingDot addMargin />;
} else if (value?.toString().startsWith("[")) {
const sourceInfo = documentSourceInfo || questionSourceInfo;
if (!sourceInfo) {

View File

@@ -9,7 +9,7 @@ import {
import { MessageRenderer, FullChatState } from "../interfaces";
import { isFinalAnswerComplete } from "../../../services/packetUtils";
import { useMarkdownRenderer } from "../markdownUtils";
import { BlinkingBar } from "../../BlinkingBar";
import { BlinkingDot } from "../../BlinkingDot";
// Control the rate of packet streaming (packets per second)
const PACKET_DELAY_MS = 10;
@@ -138,7 +138,7 @@ export const MessageTextRenderer: MessageRenderer<
)}
</>
) : (
<BlinkingBar addMargin />
<BlinkingDot addMargin />
),
},
]);

View File

@@ -3,7 +3,7 @@ import {
MessageRenderer,
RenderType,
} from "@/app/app/message/messageComponents/interfaces";
import { BlinkingBar } from "@/app/app/message/BlinkingBar";
import { BlinkingDot } from "@/app/app/message/BlinkingDot";
import { OnyxDocument } from "@/lib/search/interfaces";
import { ValidSources } from "@/lib/types";
import { SearchChipList, SourceInfo } from "../search/SearchChipList";
@@ -97,7 +97,7 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
onClick={(doc: OnyxDocument) => {
if (doc.link) window.open(doc.link, "_blank");
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
) : displayUrls ? (
<SearchChipList
@@ -107,10 +107,10 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
getKey={(url: string) => url}
toSourceInfo={urlToSourceInfo}
onClick={(url: string) => window.open(url, "_blank")}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
) : (
!stopPacketSeen && <BlinkingBar />
!stopPacketSeen && <BlinkingDot />
)}
</div>
),
@@ -136,7 +136,7 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
onClick={(doc: OnyxDocument) => {
if (doc.link) window.open(doc.link, "_blank");
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
) : displayUrls ? (
<SearchChipList
@@ -146,11 +146,11 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
getKey={(url: string) => url}
toSourceInfo={urlToSourceInfo}
onClick={(url: string) => window.open(url, "_blank")}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
) : (
<div className="flex flex-wrap gap-x-2 gap-y-2 ml-1">
{!stopPacketSeen && <BlinkingBar />}
{!stopPacketSeen && <BlinkingDot />}
</div>
)}
</div>

View File

@@ -9,7 +9,7 @@ import {
MessageRenderer,
RenderType,
} from "@/app/app/message/messageComponents/interfaces";
import { BlinkingBar } from "@/app/app/message/BlinkingBar";
import { BlinkingDot } from "@/app/app/message/BlinkingDot";
import { Section } from "@/layouts/general-layouts";
import Card from "@/refresh-components/cards/Card";
import Text from "@/refresh-components/texts/Text";
@@ -142,7 +142,7 @@ export const FileReaderToolRenderer: MessageRenderer<
)}
</>
) : (
!stopPacketSeen && <BlinkingBar />
!stopPacketSeen && <BlinkingDot />
)}
</Section>
),

View File

@@ -5,7 +5,7 @@ import {
MessageRenderer,
RenderType,
} from "@/app/app/message/messageComponents/interfaces";
import { BlinkingBar } from "@/app/app/message/BlinkingBar";
import { BlinkingDot } from "@/app/app/message/BlinkingDot";
import { constructCurrentMemoryState } from "./memoryStateUtils";
import Text from "@/refresh-components/texts/Text";
import { SvgEditBig, SvgMaximize2 } from "@opal/icons";
@@ -18,7 +18,7 @@ import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
* MemoryToolRenderer - Renders memory tool execution steps
*
* States:
* - Loading (start, no delta): "Saving memory..." with BlinkingBar
* - Loading (start, no delta): "Saving memory..." with BlinkingDot
* - Delta received: operation label + memory text
* - Complete (SectionEnd): "Memory saved" / "Memory updated" + memory text
* - No Access: "Memory tool disabled"
@@ -126,7 +126,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
</div>
</div>
) : (
!stopPacketSeen && <BlinkingBar />
!stopPacketSeen && <BlinkingDot />
)}
</div>
);

View File

@@ -4,7 +4,7 @@ import {
MessageRenderer,
RenderType,
} from "@/app/app/message/messageComponents/interfaces";
import { BlinkingBar } from "@/app/app/message/BlinkingBar";
import { BlinkingDot } from "@/app/app/message/BlinkingDot";
import { OnyxDocument } from "@/lib/search/interfaces";
import { ValidSources } from "@/lib/types";
import { SearchChipList, SourceInfo } from "./SearchChipList";
@@ -109,7 +109,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
</div>
),
@@ -134,7 +134,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
expansionCount={QUERIES_PER_EXPANSION}
getKey={(_, index) => index}
toSourceInfo={queryToSourceInfo}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
showDetailsCard={false}
isQuery={true}
/>
@@ -164,7 +164,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
),
},
@@ -187,7 +187,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
expansionCount={QUERIES_PER_EXPANSION}
getKey={(_, index) => index}
toSourceInfo={queryToSourceInfo}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
showDetailsCard={false}
isQuery={true}
/>
@@ -213,7 +213,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
/>
</>
)}

View File

@@ -5,7 +5,7 @@ import {
MessageRenderer,
RenderType,
} from "@/app/app/message/messageComponents/interfaces";
import { BlinkingBar } from "@/app/app/message/BlinkingBar";
import { BlinkingDot } from "@/app/app/message/BlinkingDot";
import { ValidSources } from "@/lib/types";
import { SearchChipList, SourceInfo } from "./SearchChipList";
import {
@@ -80,7 +80,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
expansionCount={QUERIES_PER_EXPANSION}
getKey={(_, index) => index}
toSourceInfo={queryToSourceInfo}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
showDetailsCard={false}
isQuery={true}
/>
@@ -105,7 +105,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
expansionCount={QUERIES_PER_EXPANSION}
getKey={(_, index) => index}
toSourceInfo={queryToSourceInfo}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
showDetailsCard={false}
isQuery={true}
/>
@@ -126,7 +126,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
expansionCount={QUERIES_PER_EXPANSION}
getKey={(_, index) => index}
toSourceInfo={queryToSourceInfo}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={!stopPacketSeen ? <BlinkingDot /> : undefined}
showDetailsCard={false}
isQuery={true}
/>

View File

@@ -17,7 +17,16 @@ async function handleSamlCallback(
const fetchOptions: RequestInit = {
method,
headers: {},
headers: {
"X-Forwarded-Host":
request.headers.get("X-Forwarded-Host") ||
request.headers.get("host") ||
"",
"X-Forwarded-Port":
request.headers.get("X-Forwarded-Port") ||
new URL(request.url).port ||
"",
},
};
let relayState: string | null = null;

View File

@@ -186,12 +186,7 @@ export const DefaultDropdown = forwardRef<HTMLDivElement, DefaultDropdownProps>(
<FiChevronDown className="my-auto ml-auto" />
</div>
</Popover.Trigger>
<Popover.Content
align="start"
side={side}
sideOffset={5}
width="trigger"
>
<Popover.Content align="start" side={side} sideOffset={5}>
<div
ref={ref}
className={`

View File

@@ -27,8 +27,7 @@ import { useState, useEffect, memo, JSX } from "react";
import remarkGfm from "remark-gfm";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import { Section } from "@/layouts/general-layouts";
import { cn, transformLinkUri } from "@/lib/utils";
import { transformLinkUri } from "@/lib/utils";
import FileInput from "@/app/admin/connectors/[connector]/pages/ConnectorInput/FileInput";
import InputDatePicker from "@/refresh-components/inputs/InputDatePicker";
import { RichTextSubtext } from "./RichTextSubtext";
@@ -694,7 +693,6 @@ interface BooleanFormFieldProps {
optional?: boolean;
tooltip?: string;
disabledTooltip?: string;
disabledTooltipSide?: "top" | "bottom" | "left" | "right";
onChange?: (checked: boolean) => void;
}
@@ -709,7 +707,6 @@ export const BooleanFormField = memo(function BooleanFormField({
disabled,
tooltip,
disabledTooltip,
disabledTooltipSide,
onChange,
}: BooleanFormFieldProps) {
// Generate a stable, valid id from the field name for label association
@@ -717,69 +714,48 @@ export const BooleanFormField = memo(function BooleanFormField({
return (
<div>
<FastField
name={name}
type="checkbox"
disabled={disabled}
shouldUpdate={(next: any, prev: any) =>
next.disabled !== prev.disabled ||
next.formik.values !== prev.formik.values
}
>
{({ field, form }: any) => {
const toggle = () => {
if (!disabled) {
const newValue = !field.value;
form.setFieldValue(name, newValue);
if (onChange) onChange(newValue);
}
};
return (
<div className="flex items-center text-sm">
<FastField name={name} type="checkbox">
{({ field, form }: any) => (
<SimpleTooltip
// This may seem confusing, but we only want to show the `disabledTooltip` if and only if the `BooleanFormField` is disabled.
// If it is disabled, then we "enable" the showing of the tooltip. Thus, `disabled={!disabled}` is not a mistake.
disabled={!disabled}
tooltip={disabledTooltip}
side={disabledTooltipSide}
>
<Section flexDirection="row" width="fit" height="fit" gap={0}>
<Checkbox
aria-label={`${label
.toLowerCase()
.replace(" ", "-")}-checkbox`}
id={checkboxId}
className={cn(
disabled && "opacity-50",
removeIndent ? "mr-2" : "mx-3"
)}
checked={Boolean(field.value)}
onCheckedChange={(checked) => {
if (!disabled) {
form.setFieldValue(name, checked === true);
if (onChange) onChange(checked === true);
}
}}
/>
{!noLabel && (
<div
className={disabled ? "" : "cursor-pointer"}
onClick={toggle}
>
<div className="flex items-center gap-x-2">
<Label small={small}>{`${label}${
optional ? " (Optional)" : ""
}`}</Label>
{tooltip && <ToolTipDetails>{tooltip}</ToolTipDetails>}
</div>
{subtext && <SubLabel>{subtext}</SubLabel>}
</div>
)}
</Section>
<Checkbox
aria-label={`${label.toLowerCase().replace(" ", "-")}-checkbox`}
id={checkboxId}
className={`
${disabled ? "opacity-50" : ""}
${removeIndent ? "mr-2" : "mx-3"}`}
checked={Boolean(field.value)}
onCheckedChange={(checked) => {
if (!disabled) form.setFieldValue(name, checked === true);
if (onChange) onChange(checked === true);
}}
/>
</SimpleTooltip>
);
}}
</FastField>
)}
</FastField>
{!noLabel && (
<div>
<div className="flex items-center gap-x-2">
<Label
htmlFor={checkboxId}
small={small}
className="cursor-pointer"
>{`${label}${optional ? " (Optional)" : ""}`}</Label>
{tooltip && <ToolTipDetails>{tooltip}</ToolTipDetails>}
</div>
{subtext && (
<label htmlFor={checkboxId} className="cursor-pointer">
<SubLabel>{subtext}</SubLabel>
</label>
)}
</div>
)}
</div>
<ErrorMessage
name={name}

View File

@@ -183,8 +183,36 @@ export function ToolSelector({
valid base URL.
</div>
<div>
<span className="font-semibold">Open URL:</span> Open and read
the content of URLs provided in the conversation.
<div>
<span className="font-semibold">Open URL:</span> Open and read
the content of URLs provided in the conversation.
</div>
{openUrlTool && setFieldValue && (
<label className="flex items-center gap-2 cursor-pointer mt-1.5 ml-1">
<input
type="checkbox"
checked={enabledToolsMap[openUrlTool.id] || false}
onChange={(e) => {
if (!isOpenUrlForced) {
setFieldValue(
`enabled_tools_map.${openUrlTool.id}`,
e.target.checked
);
}
}}
disabled={isOpenUrlForced}
className="h-3.5 w-3.5 rounded border-border-medium disabled:opacity-50 disabled:cursor-not-allowed"
/>
<span className="text-xs">
Enable Open URL
{isOpenUrlForced && (
<span className="text-text-500 ml-1">
(required for Web Search)
</span>
)}
</span>
</label>
)}
</div>
</div>
}
@@ -198,7 +226,6 @@ export function ToolSelector({
subtext="Search through your organization's knowledge base and documents"
disabled={searchToolDisabled}
disabledTooltip={searchToolDisabledTooltip}
disabledTooltipSide="bottom"
/>
)}
@@ -216,17 +243,6 @@ export function ToolSelector({
/>
)}
{openUrlTool && setFieldValue && (
<BooleanFormField
name={`enabled_tools_map.${openUrlTool.id}`}
label="Open URL"
subtext="Open and read the content of URLs provided in the conversation"
disabled={isOpenUrlForced}
disabledTooltip="Required for Web Search"
disabledTooltipSide="bottom"
/>
)}
{imageGenerationTool && (
<BooleanFormField
name={`enabled_tools_map.${imageGenerationTool.id}`}
@@ -234,7 +250,6 @@ export function ToolSelector({
subtext="Generate and manipulate images using AI-powered tools."
disabled={imageGenerationDisabled}
disabledTooltip={imageGenerationDisabledTooltip}
disabledTooltipSide="bottom"
/>
)}

View File

@@ -1,8 +1,5 @@
"use client";
import {
WellKnownLLMProviderDescriptor,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import React, {
createContext,
useContext,
@@ -11,101 +8,57 @@ import React, {
useCallback,
} from "react";
import { useUser } from "@/providers/UserProvider";
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llm/svc";
import { useRouter } from "next/navigation";
import { checkLlmProvider } from "../initialSetup/welcome/lib";
interface ProviderContextType {
shouldShowConfigurationNeeded: boolean;
providerOptions: WellKnownLLMProviderDescriptor[];
refreshProviderInfo: () => Promise<void>;
// Expose configured provider instances for components that need it (e.g., onboarding)
llmProviders: LLMProviderDescriptor[] | undefined;
isLoadingProviders: boolean;
hasProviders: boolean;
refreshProviderInfo: () => Promise<void>; // Add this line
}
const ProviderContext = createContext<ProviderContextType | undefined>(
undefined
);
const DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY = "defaultLlmProviderTestComplete";
function checkDefaultLLMProviderTestComplete() {
if (typeof window === "undefined") return true;
return (
localStorage.getItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY) === "true"
);
}
function setDefaultLLMProviderTestComplete() {
if (typeof window === "undefined") return;
localStorage.setItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY, "true");
}
export function ProviderContextProvider({
children,
}: {
children: React.ReactNode;
}) {
const { user } = useUser();
const router = useRouter();
// Use SWR hooks instead of raw fetch
const {
llmProviders,
isLoading: isLoadingProviders,
refetch: refetchProviders,
} = useLLMProviders();
const { llmProviderOptions: providerOptions, refetch: refetchOptions } =
useLLMProviderOptions();
const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
const [providerOptions, setProviderOptions] = useState<
WellKnownLLMProviderDescriptor[]
>([]);
const [defaultCheckSuccessful, setDefaultCheckSuccessful] =
useState<boolean>(true);
const fetchProviderInfo = useCallback(async () => {
const { providers, options, defaultCheckSuccessful } =
await checkLlmProvider(user);
// Test the default provider - only runs if test hasn't passed yet
const testDefaultProvider = useCallback(async () => {
const shouldCheck =
!checkDefaultLLMProviderTestComplete() &&
(!user || user.role === "admin");
setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
setProviderOptions(options);
}, [user, setValidProviderExists, setProviderOptions]);
if (shouldCheck) {
const success = await testDefaultProviderSvc();
setDefaultCheckSuccessful(success);
if (success) {
setDefaultLLMProviderTestComplete();
}
}
}, [user]);
// Test default provider on mount
useEffect(() => {
testDefaultProvider();
}, [testDefaultProvider]);
const hasProviders = (llmProviders?.length ?? 0) > 0;
const validProviderExists = hasProviders && defaultCheckSuccessful;
fetchProviderInfo();
}, [router, user, fetchProviderInfo]);
const shouldShowConfigurationNeeded =
!validProviderExists && (providerOptions?.length ?? 0) > 0;
!validProviderExists && providerOptions.length > 0;
const refreshProviderInfo = useCallback(async () => {
// Refetch provider lists and re-test default provider if needed
await Promise.all([
refetchProviders(),
refetchOptions(),
testDefaultProvider(),
]);
}, [refetchProviders, refetchOptions, testDefaultProvider]);
const refreshProviderInfo = async () => {
await fetchProviderInfo();
};
return (
<ProviderContext.Provider
value={{
shouldShowConfigurationNeeded,
providerOptions: providerOptions ?? [],
refreshProviderInfo,
llmProviders,
isLoadingProviders,
hasProviders,
providerOptions,
refreshProviderInfo, // Add this line
}}
>
{children}

View File

@@ -0,0 +1 @@
// Name of the cookie used to mark that the welcome flow has been completed
// (presumably read/set by the welcome-flow UI — confirm against callers).
export const COMPLETED_WELCOME_FLOW_COOKIE = "completed_welcome_flow";

View File

@@ -0,0 +1,56 @@
import {
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { User } from "@/lib/types";
const DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY = "defaultLlmProviderTestComplete";
function checkDefaultLLMProviderTestComplete() {
  // The flag is stored as the string "true" once the default-provider test
  // has passed; anything else (including null) means it has not run/passed.
  const stored = localStorage.getItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY);
  return stored === "true";
}
// Persist the "default provider test passed" flag so that future page loads
// skip re-running the test (read back by checkDefaultLLMProviderTestComplete).
function setDefaultLLMProviderTestComplete() {
  localStorage.setItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY, "true");
}
function shouldCheckDefaultLLMProvider(user: User | null) {
  // Only run the test until it has passed once, and only for admins or
  // anonymous sessions (user is null, e.g. when auth is disabled).
  if (checkDefaultLLMProviderTestComplete()) {
    return false;
  }
  const isAdminOrAnonymous = !user || user.role === "admin";
  return isAdminOrAnonymous;
}
/**
 * Fetch the configured LLM providers and the built-in provider options, and
 * (for admins / anonymous sessions, at most once per browser) run a
 * connectivity test against the default provider.
 *
 * NOTE: must only be called on the client side, after the initial render,
 * since it reads localStorage and hits relative API routes.
 *
 * @param user - the current user, or null when auth is disabled.
 * @returns configured providers, well-known provider options, and whether the
 *          default-provider test passed (true when the test was not needed).
 */
export async function checkLlmProvider(user: User | null) {
  const checkDefault = shouldCheckDefaultLLMProvider(user);

  // Run all requests in parallel; the default-provider test slot resolves to
  // null when the test is not needed. Using an inline tuple lets TypeScript
  // keep per-element types (Response vs Response | null).
  const [providerResponse, optionsResponse, defaultCheckResponse] =
    await Promise.all([
      fetch("/api/llm/provider"),
      fetch("/api/admin/llm/built-in/options"),
      checkDefault
        ? fetch("/api/admin/llm/test/default", { method: "POST" })
        : Promise.resolve<Response | null>(null),
    ]);

  let providers: LLMProviderView[] = [];
  if (providerResponse.ok) {
    providers = await providerResponse.json();
  }

  let options: WellKnownLLMProviderDescriptor[] = [];
  if (optionsResponse.ok) {
    options = await optionsResponse.json();
  }

  const defaultCheckSuccessful =
    !checkDefault || (defaultCheckResponse?.ok ?? false);

  // Only persist the "test passed" flag when the test actually ran and
  // succeeded. Previously the flag was also set when the test was skipped
  // (e.g. a non-admin load), which would permanently suppress the test for
  // an admin later using the same browser.
  if (checkDefault && defaultCheckSuccessful) {
    setDefaultLLMProviderTestComplete();
  }

  return { providers, options, defaultCheckSuccessful };
}

View File

@@ -220,7 +220,7 @@ function SettingsHeader({
className={cn(
"w-full bg-background-tint-01",
isSticky && "sticky top-0 z-settings-header",
backButton && "md:pt-4"
backButton ? "md:pt-4" : "md:pt-10"
)}
>
{backButton && (

View File

@@ -1,23 +0,0 @@
// Base path for the chat-file retrieval endpoint.
const CHAT_FILE_PREFIX = "/api/chat/file";

/**
 * Fetch a chat file by its ID, returning the raw Response.
 *
 * The caller is responsible for consuming the body (e.g. `.blob()`,
 * `.text()`) since different consumers need different formats.
 */
export async function fetchChatFile(fileId: string): Promise<Response> {
  // File IDs may contain characters that are not URL-safe.
  const url = `${CHAT_FILE_PREFIX}/${encodeURIComponent(fileId)}`;

  const response = await fetch(url, {
    method: "GET",
    // Files are immutable, so the browser cache can always be used.
    cache: "force-cache",
  });

  if (!response.ok) {
    throw new Error("Failed to load document.");
  }
  return response;
}

View File

@@ -839,42 +839,6 @@ export const connectorConfigs: Record<
description:
"Index aspx-pages of all SharePoint sites defined above, even if a library or folder is specified.",
},
{
type: "text",
query: "Microsoft Authority Host:",
label: "Authority Host",
name: "authority_host",
optional: true,
default: "https://login.microsoftonline.com",
description:
"The Microsoft identity authority host used for authentication. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://login.microsoftonline.us",
},
{
type: "text",
query: "Microsoft Graph API Host:",
label: "Graph API Host",
name: "graph_api_host",
optional: true,
default: "https://graph.microsoft.com",
description:
"The Microsoft Graph API host. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://graph.microsoft.us",
},
{
type: "text",
query: "SharePoint Domain Suffix:",
label: "SharePoint Domain Suffix",
name: "sharepoint_domain_suffix",
optional: true,
default: "sharepoint.com",
description:
"The domain suffix for SharePoint sites (e.g. sharepoint.com). " +
"For most deployments, leave as default. " +
"For GCC High, use sharepoint.us",
},
],
},
teams: {
@@ -889,32 +853,7 @@ export const connectorConfigs: Record<
description: `Specify 0 or more Teams to index. For example, specifying the Team 'Support' for the 'onyxai' Org will cause us to only index messages sent in channels belonging to the 'Support' Team. If no Teams are specified, all Teams in your organization will be indexed.`,
},
],
advanced_values: [
{
type: "text",
query: "Microsoft Authority Host:",
label: "Authority Host",
name: "authority_host",
optional: true,
default: "https://login.microsoftonline.com",
description:
"The Microsoft identity authority host used for authentication. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://login.microsoftonline.us",
},
{
type: "text",
query: "Microsoft Graph API Host:",
label: "Graph API Host",
name: "graph_api_host",
optional: true,
default: "https://graph.microsoft.com",
description:
"The Microsoft Graph API host. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://graph.microsoft.us",
},
],
advanced_values: [],
},
discourse: {
description: "Configure Discourse connector",
@@ -1942,15 +1881,10 @@ export interface SharepointConfig {
sites?: string[];
include_site_pages?: boolean;
include_site_documents?: boolean;
authority_host?: string;
graph_api_host?: string;
sharepoint_domain_suffix?: string;
}
export interface TeamsConfig {
teams?: string[];
authority_host?: string;
graph_api_host?: string;
}
export interface DiscourseConfig {
@@ -1969,6 +1903,10 @@ export interface DrupalWikiConfig {
include_attachments?: boolean;
}
export interface TeamsConfig {
teams?: string[];
}
export interface ProductboardConfig {}
export interface SlackConfig {

View File

@@ -1,19 +0,0 @@
import useSWR from "swr";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
/**
 * SWR hook for the built-in LLM provider descriptors.
 *
 * Focus revalidation is disabled and requests are deduped for one minute,
 * since the set of built-in options changes rarely.
 */
export function useLLMProviderOptions() {
  const swr = useSWR<WellKnownLLMProviderDescriptor[] | undefined>(
    "/api/admin/llm/built-in/options",
    errorHandlingFetcher,
    {
      revalidateOnFocus: false,
      dedupingInterval: 60000, // Dedupe requests within 1 minute
    }
  );

  return {
    llmProviderOptions: swr.data,
    // SWR convention: loading until either data or an error arrives.
    isLoading: !swr.error && !swr.data,
    error: swr.error,
    refetch: swr.mutate,
  };
}

View File

@@ -1,40 +0,0 @@
import * as languages from "linguist-languages";
// Minimal shape of an entry from the `linguist-languages` dataset; only the
// fields the lookup tables below need.
interface LinguistLanguage {
  name: string;
  // Linguist category, e.g. "programming", "data", "markup", "prose".
  type: string;
  // File extensions including the leading dot, e.g. ".ts".
  extensions?: string[];
  // Exact filenames with no extension-based mapping, e.g. "Dockerfile".
  filenames?: string[];
}
// Lookup tables built once at module load:
//   extension (".ts") → language name, and exact filename ("dockerfile") →
//   language name. Filename keys are stored lowercased.
const extensionMap = new Map<string, string>();
const filenameMap = new Map<string, string>();

for (const lang of Object.values(languages) as LinguistLanguage[]) {
  // Only programming languages are indexed (skips data/markup/prose types).
  if (lang.type !== "programming") continue;

  const name = lang.name.toLowerCase();

  for (const ext of lang.extensions ?? []) {
    // First language to claim an extension wins
    if (!extensionMap.has(ext)) {
      extensionMap.set(ext, name);
    }
  }

  for (const filename of lang.filenames ?? []) {
    const key = filename.toLowerCase();
    if (!filenameMap.has(key)) {
      filenameMap.set(key, name);
    }
  }
}
/**
 * Returns the language name for a given file name, or null if it's not a
 * recognised code file. Looks up by extension first, then by exact filename
 * (e.g. "Dockerfile", "Makefile"). Runs in O(1).
 */
export function getCodeLanguage(name: string): string | null {
  const lower = name.toLowerCase();

  // Extension lookup takes precedence over whole-filename lookup.
  const extMatch = lower.match(/\.[^.]+$/);
  if (extMatch) {
    const byExtension = extensionMap.get(extMatch[0]);
    if (byExtension !== undefined) {
      return byExtension;
    }
  }

  return filenameMap.get(lower) ?? null;
}

View File

@@ -1,23 +0,0 @@
/**
 * LLM action functions for mutations.
 *
 * These are async functions for one-off actions that don't need SWR caching.
 *
 * Endpoints:
 * - /api/admin/llm/test/default - Test the default LLM provider connection
 */

/**
 * Test the default LLM provider.
 * Returns true if the default provider is configured and working, false otherwise.
 * Network failures are treated as an unsuccessful test rather than thrown.
 */
export async function testDefaultProvider(): Promise<boolean> {
  try {
    const response = await fetch("/api/admin/llm/test/default", {
      method: "POST",
    });
    // A fulfilled fetch always yields a Response, so its status alone decides.
    return response.ok;
  } catch {
    return false;
  }
}

Some files were not shown because too many files have changed in this diff. Show More