Compare commits

..

2 Commits

89 changed files with 1575 additions and 4235 deletions

View File

@@ -9,8 +9,7 @@ inputs:
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: false
default: ""
required: true
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
@@ -18,14 +17,6 @@ inputs:
description: "Optional NIGHTLY_LLM_API_BASE"
required: false
default: ""
api-version:
description: "Optional NIGHTLY_LLM_API_VERSION"
required: false
default: ""
deployment-name:
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
required: false
default: ""
custom-config-json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
required: false
@@ -68,7 +59,6 @@ runs:
DISABLE_TELEMETRY=true
INTEGRATION_TESTS_MODE=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AWS_REGION_NAME=us-west-2
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
EOF2
@@ -92,8 +82,6 @@ runs:
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
@@ -103,6 +91,11 @@ runs:
max_attempts: 2
retry_wait_seconds: 10
command: |
if [ -z "${MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -117,13 +110,10 @@ runs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e TEST_WEB_HOSTNAME=test-runner \
-e AWS_REGION_NAME=us-west-2 \
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
-e NIGHTLY_LLM_MODELS="${MODELS}" \
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \

View File

@@ -0,0 +1,44 @@
name: Nightly LLM Provider Chat Tests (OpenAI)
concurrency:
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
openai-provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
provider: openai
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
strict: true
secrets:
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [openai-provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: openai-provider-chat-test
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -1,56 +0,0 @@
name: Nightly LLM Provider Chat Tests
concurrency:
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
azure_api_key: ${{ secrets.AZURE_API_KEY }}
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: provider-chat-test
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -89,10 +89,6 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

View File

@@ -3,66 +3,33 @@ name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
openai_models:
description: "Comma-separated models for openai"
required: false
default: ""
provider:
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
required: true
type: string
anthropic_models:
description: "Comma-separated models for anthropic"
required: false
default: ""
type: string
bedrock_models:
description: "Comma-separated models for bedrock"
required: false
default: ""
type: string
vertex_ai_models:
description: "Comma-separated models for vertex_ai"
required: false
default: ""
type: string
azure_models:
description: "Comma-separated models for azure"
required: false
default: ""
type: string
ollama_models:
description: "Comma-separated models for ollama_chat"
required: false
default: ""
type: string
openrouter_models:
description: "Comma-separated models for openrouter"
required: false
default: ""
type: string
azure_api_base:
description: "API base for azure provider"
required: false
default: ""
models:
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
required: true
type: string
strict:
description: "Default NIGHTLY_LLM_STRICT passed to tests"
description: "Pass-through value for NIGHTLY_LLM_STRICT"
required: false
default: true
type: boolean
api_base:
description: "Optional NIGHTLY_LLM_API_BASE override"
required: false
default: ""
type: string
custom_config_json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
required: false
default: ""
type: string
secrets:
openai_api_key:
required: false
anthropic_api_key:
required: false
bedrock_api_key:
required: false
vertex_ai_custom_config_json:
required: false
azure_api_key:
required: false
ollama_api_key:
required: false
openrouter_api_key:
required: false
provider_api_key:
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
required: true
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
@@ -71,8 +38,29 @@ on:
permissions:
contents: read
env:
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
jobs:
validate-inputs:
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Validate required nightly provider inputs
run: |
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
build-backend-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -102,6 +90,7 @@ jobs:
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -130,6 +119,7 @@ jobs:
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -159,75 +149,11 @@ jobs:
provider-chat-test:
needs:
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
strategy:
fail-fast: false
matrix:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_secret: openai_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_secret: anthropic_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_secret: bedrock_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_secret: ""
custom_config_secret: vertex_ai_custom_config_json
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: azure
models: ${{ inputs.azure_models }}
api_key_secret: azure_api_key
custom_config_secret: ""
api_base: ${{ inputs.azure_api_base }}
api_version: "2025-04-01-preview"
deployment_name: ""
required: false
- provider: ollama_chat
models: ${{ inputs.ollama_models }}
api_key_secret: ollama_api_key
custom_config_secret: ""
api_base: "https://ollama.com"
api_version: ""
deployment_name: ""
required: false
- provider: openrouter
models: ${{ inputs.openrouter_models }}
api_key_secret: openrouter_api_key
custom_config_secret: ""
api_base: "https://openrouter.ai/api/v1"
api_version: ""
deployment_name: ""
required: false
[build-backend-image, build-model-server-image, build-integration-image]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
steps:
@@ -241,14 +167,12 @@ jobs:
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
api-base: ${{ matrix.api_base }}
api-version: ${{ matrix.api_version }}
deployment-name: ${{ matrix.deployment_name }}
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
models: ${{ env.NIGHTLY_LLM_MODELS }}
provider-api-key: ${{ secrets.provider_api_key }}
strict: ${{ env.NIGHTLY_LLM_STRICT }}
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
@@ -270,7 +194,7 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log

View File

@@ -322,7 +322,6 @@ def list_users(
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
@@ -366,7 +365,6 @@ def get_user(
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
@@ -723,7 +721,6 @@ def list_groups(
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
@@ -760,7 +757,6 @@ def get_group(
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):

View File

@@ -76,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -764,7 +764,7 @@ def process_single_user_file_project_sync(
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
user_file_project_sync_lock_key(user_file_id),
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)

View File

@@ -3,6 +3,7 @@ import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
@@ -162,11 +163,13 @@ class ChatStateContainer:
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
func: Callable[..., None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
*args: Any,
**kwargs: Any,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
@@ -177,18 +180,19 @@ def run_chat_loop_with_state_containers(
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
*args: Additional positional arguments for func
**kwargs: Additional keyword arguments for func
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
arg1, arg2, kwarg1=value1
)
for packet in packets:
# Process packets
@@ -197,7 +201,9 @@ def run_chat_loop_with_state_containers(
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
# Ensure state_container is passed explicitly, removing it from kwargs if present
kwargs_with_state = {**kwargs, "state_container": state_container}
func(emitter, *args, **kwargs_with_state)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(

View File

@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
context_image_files: list[ChatLoadedFile],
project_image_files: list[ChatLoadedFile],
additional_context: str | None,
token_counter: Callable[[str], int],
tool_id_to_name_map: dict[int, str],
@@ -541,11 +541,11 @@ def convert_chat_history(
)
# Add the user message with image files attached
# If this is the last USER message, also include context_image_files
# Note: context image file tokens are NOT counted in the token count
# If this is the last USER message, also include project_image_files
# Note: project image file tokens are NOT counted in the token count
if idx == last_user_message_idx:
if context_image_files:
image_files.extend(context_image_files)
if project_image_files:
image_files.extend(project_image_files)
if additional_context:
simple_messages.append(

View File

@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.chat.prompt_utils import build_reminder_message
from onyx.chat.prompt_utils import build_system_prompt
@@ -203,17 +203,17 @@ def _try_fallback_tool_extraction(
MAX_LLM_CYCLES = 6
def _build_context_file_citation_mapping(
file_metadata: list[ContextFileMetadata],
def _build_project_file_citation_mapping(
project_file_metadata: list[ProjectFileMetadata],
starting_citation_num: int = 1,
) -> CitationMapping:
"""Build citation mapping for context files.
"""Build citation mapping for project files.
Converts context file metadata into SearchDoc objects that can be cited.
Converts project file metadata into SearchDoc objects that can be cited.
Citation numbers start from the provided starting number.
Args:
file_metadata: List of context file metadata
project_file_metadata: List of project file metadata
starting_citation_num: Starting citation number (default: 1)
Returns:
@@ -221,7 +221,8 @@ def _build_context_file_citation_mapping(
"""
citation_mapping: CitationMapping = {}
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
# Create a SearchDoc for each project file
search_doc = SearchDoc(
document_id=file_meta.file_id,
chunk_ind=0,
@@ -241,28 +242,29 @@ def _build_context_file_citation_mapping(
def _build_project_message(
context_files: ExtractedContextFiles | None,
project_files: ExtractedProjectFiles | None,
token_counter: Callable[[str], int] | None,
) -> list[ChatMessageSimple]:
"""Build messages for context-injected / tool-backed files.
"""Build messages for project / tool-backed files.
Returns up to two messages:
1. The full-text files message (if file_texts is populated).
1. The full-text project files message (if project_file_texts is populated).
2. A lightweight metadata message for files the LLM should access via the
FileReaderTool (e.g. oversized files that don't fit in context).
FileReaderTool (e.g. oversized chat-attached files or project files that
don't fit in context).
"""
if not context_files:
if not project_files:
return []
messages: list[ChatMessageSimple] = []
if context_files.file_texts:
if project_files.project_file_texts:
messages.append(
_create_context_files_message(context_files, token_counter=None)
_create_project_files_message(project_files, token_counter=None)
)
if context_files.file_metadata_for_tool and token_counter:
if project_files.file_metadata_for_tool and token_counter:
messages.append(
_create_file_tool_metadata_message(
context_files.file_metadata_for_tool, token_counter
project_files.file_metadata_for_tool, token_counter
)
)
return messages
@@ -273,7 +275,7 @@ def construct_message_history(
custom_agent_prompt: ChatMessageSimple | None,
simple_chat_history: list[ChatMessageSimple],
reminder_message: ChatMessageSimple | None,
context_files: ExtractedContextFiles | None,
project_files: ExtractedProjectFiles | None,
available_tokens: int,
last_n_user_messages: int | None = None,
token_counter: Callable[[str], int] | None = None,
@@ -287,7 +289,7 @@ def construct_message_history(
# Build the project / file-metadata messages up front so we can use their
# actual token counts for the budget.
project_messages = _build_project_message(context_files, token_counter)
project_messages = _build_project_message(project_files, token_counter)
project_messages_tokens = sum(m.token_count for m in project_messages)
history_token_budget = available_tokens
@@ -443,17 +445,17 @@ def construct_message_history(
)
# Attach project images to the last user message
if context_files and context_files.image_files:
if project_files and project_files.project_image_files:
existing_images = last_user_message.image_files or []
last_user_message = ChatMessageSimple(
message=last_user_message.message,
token_count=last_user_message.token_count,
message_type=last_user_message.message_type,
image_files=existing_images + context_files.image_files,
image_files=existing_images + project_files.project_image_files,
)
# Build the final message list according to README ordering:
# [system], [history_before_last_user], [custom_agent], [context_files],
# [system], [history_before_last_user], [custom_agent], [project_files],
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
result = [system_prompt] if system_prompt else []
@@ -464,14 +466,14 @@ def construct_message_history(
if custom_agent_prompt:
result.append(custom_agent_prompt)
# 3. Add context files / file-metadata messages (inserted before last user message)
# 3. Add project files / file-metadata messages (inserted before last user message)
result.extend(project_messages)
# 4. Add forgotten-files metadata (right before the user's question)
if forgotten_files_message:
result.append(forgotten_files_message)
# 5. Add last user message (with context images attached)
# 5. Add last user message (with project images attached)
result.append(last_user_message)
# 6. Add messages after last user message (tool calls, responses, etc.)
@@ -545,11 +547,11 @@ def _create_file_tool_metadata_message(
)
def _create_context_files_message(
context_files: ExtractedContextFiles,
def _create_project_files_message(
project_files: ExtractedProjectFiles,
token_counter: Callable[[str], int] | None, # noqa: ARG001
) -> ChatMessageSimple:
"""Convert context files to a ChatMessageSimple message.
"""Convert project files to a ChatMessageSimple message.
Format follows the README specification for document representation.
"""
@@ -557,7 +559,7 @@ def _create_context_files_message(
# Format as documents JSON as described in README
documents_list = []
for idx, file_text in enumerate(context_files.file_texts, start=1):
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
documents_list.append(
{
"document": idx,
@@ -568,10 +570,10 @@ def _create_context_files_message(
documents_json = json.dumps({"documents": documents_list}, indent=2)
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
# Use pre-calculated token count from context_files
# Use pre-calculated token count from project_files
return ChatMessageSimple(
message=message_content,
token_count=context_files.total_token_count,
token_count=project_files.total_token_count,
message_type=MessageType.USER,
)
@@ -582,7 +584,7 @@ def run_llm_loop(
simple_chat_history: list[ChatMessageSimple],
tools: list[Tool],
custom_agent_prompt: str | None,
context_files: ExtractedContextFiles,
project_files: ExtractedProjectFiles,
persona: Persona | None,
user_memory_context: UserMemoryContext | None,
llm: LLM,
@@ -625,9 +627,9 @@ def run_llm_loop(
# Add project file citation mappings if project files are present
project_citation_mapping: CitationMapping = {}
if context_files.file_metadata:
project_citation_mapping = _build_context_file_citation_mapping(
context_files.file_metadata
if project_files.project_file_metadata:
project_citation_mapping = _build_project_file_citation_mapping(
project_files.project_file_metadata
)
citation_processor.update_citation_mapping(project_citation_mapping)
@@ -645,7 +647,7 @@ def run_llm_loop(
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
# One future workaround is to include the images as separate user messages with citation information and process those.
always_cite_documents: bool = bool(
context_files.use_as_search_filter or context_files.file_texts
project_files.project_as_filter or project_files.project_file_texts
)
should_cite_documents: bool = False
ran_image_gen: bool = False
@@ -786,7 +788,7 @@ def run_llm_loop(
custom_agent_prompt=custom_agent_prompt_msg,
simple_chat_history=simple_chat_history,
reminder_message=reminder_msg,
context_files=context_files,
project_files=project_files,
available_tokens=available_tokens,
token_counter=token_counter,
all_injected_file_metadata=all_injected_file_metadata,

View File

@@ -31,6 +31,13 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ProjectSearchConfig(BaseModel):
"""Configuration for search tool availability in project context."""
search_usage: SearchToolUsage
disable_forced_tool: bool
class CreateChatSessionID(BaseModel):
chat_session_id: UUID
@@ -125,8 +132,8 @@ class ChatMessageSimple(BaseModel):
file_id: str | None = None
class ContextFileMetadata(BaseModel):
"""Metadata for a context-injected file to enable citation support."""
class ProjectFileMetadata(BaseModel):
"""Metadata for a project file to enable citation support."""
file_id: str
filename: str
@@ -160,28 +167,20 @@ class ChatHistoryResult(BaseModel):
all_injected_file_metadata: dict[str, FileToolMetadata]
class ExtractedContextFiles(BaseModel):
"""Result of attempting to load user files (from a project or persona) into context."""
file_texts: list[str]
image_files: list[ChatLoadedFile]
use_as_search_filter: bool
class ExtractedProjectFiles(BaseModel):
project_file_texts: list[str]
project_image_files: list[ChatLoadedFile]
project_as_filter: bool
total_token_count: int
# Metadata for project files to enable citations
project_file_metadata: list[ProjectFileMetadata]
# None if not a project
project_uncapped_token_count: int | None
# Lightweight metadata for files exposed via FileReaderTool
# (populated when files don't fit in context and vector DB is disabled).
file_metadata: list[ContextFileMetadata]
uncapped_token_count: int | None
# (populated when files don't fit in context and vector DB is disabled)
file_metadata_for_tool: list[FileToolMetadata] = []
class SearchParams(BaseModel):
"""Resolved search filter IDs and search-tool usage for a chat turn."""
search_project_id: int | None
search_persona_id: int | None
search_usage: SearchToolUsage
class LlmStepResult(BaseModel):
reasoning: str | None
answer: str | None

View File

@@ -3,7 +3,6 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
An overview can be found in the README.md file in this directory.
"""
import io
import re
import traceback
from collections.abc import Callable
@@ -34,11 +33,11 @@ from onyx.chat.models import ChatBasicResponse
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import SearchParams
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ProjectSearchConfig
from onyx.chat.models import StreamingError
from onyx.chat.models import ToolCallResponse
from onyx.chat.prompt_utils import calculate_reserved_tokens
@@ -63,12 +62,11 @@ from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_for_persona
@@ -141,12 +139,12 @@ def _collect_available_file_ids(
pass
if project_id:
user_files = get_user_files_from_project(
project_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
for uf in user_files:
for uf in project_files:
user_file_ids.add(uf.id)
return _AvailableFiles(
@@ -194,67 +192,9 @@ def _convert_loaded_files_to_chat_files(
return chat_files
def resolve_context_user_files(
persona: Persona,
def _extract_project_file_texts_and_images(
project_id: int | None,
user_id: UUID | None,
db_session: Session,
) -> list[UserFile]:
"""Apply the precedence rule to decide which user files to load.
A custom persona fully supersedes the project. When a chat uses a
custom persona, the project is purely organisational — its files are
never loaded and never made searchable.
Custom persona → persona's own user_files (may be empty).
Default persona inside a project → project files.
Otherwise → empty list.
"""
if persona.id != DEFAULT_PERSONA_ID:
return list(persona.user_files) if persona.user_files else []
if project_id:
return get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return []
def _empty_extracted_context_files() -> ExtractedContextFiles:
return ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=False,
total_token_count=0,
file_metadata=[],
uncapped_token_count=None,
)
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
"""Extract text content from an InMemoryChatFile.
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
ingestion — decode directly.
DOC / CSV / other text types: the content is the original file bytes —
use extract_file_text which handles encoding detection and format parsing.
"""
try:
if f.file_type == ChatFileType.PLAIN_TEXT:
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
return extract_file_text(
file=io.BytesIO(f.content),
file_name=f.filename or "",
break_on_unprocessable=False,
)
except Exception:
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
return None
def extract_context_files(
user_files: list[UserFile],
llm_max_context_window: int,
reserved_token_count: int,
db_session: Session,
@@ -263,12 +203,8 @@ def extract_context_files(
# 60% of the LLM's max context window. The other benefit is that for projects with
# more files, this makes it so that we don't throw away the history too quickly every time.
max_llm_context_percentage: float = 0.6,
) -> ExtractedContextFiles:
"""Load user files into context if they fit; otherwise flag for search.
The caller is responsible for deciding *which* user files to pass in
(project files, persona files, etc.). This function only cares about
the all-or-nothing fit check and the actual content loading.
) -> ExtractedProjectFiles:
"""Extract text content from project files if they fit within the context window.
Args:
project_id: The project ID to load files from
@@ -277,95 +213,160 @@ def extract_context_files(
reserved_token_count: Number of tokens to reserve for other content
db_session: Database session
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
Returns:
ExtractedContextFiles containing:
- List of text content strings from context files (text files only)
- List of image files from context (ChatLoadedFile objects)
ExtractedProjectFiles containing:
- List of text content strings from project files (text files only)
- List of image files from project (ChatLoadedFile objects)
- Project id if the project should be provided as a filter in search, or None if not.
- Total token count of all extracted files
- File metadata for context files
- Uncapped token count of all extracted files
- File metadata for files that don't fit in context and vector DB is disabled
"""
# TODO(yuhong): I believe this is not handling all file types correctly.
# TODO I believe this is not handling all file types correctly.
project_as_filter = False
if not project_id:
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=None,
)
if not user_files:
return _empty_extracted_context_files()
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
max_actual_tokens = (
llm_max_context_window - reserved_token_count
) * max_llm_context_percentage
if aggregate_tokens >= max_actual_tokens:
tool_metadata = []
use_as_search_filter = not DISABLE_VECTOR_DB
if DISABLE_VECTOR_DB:
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
return ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=use_as_search_filter,
total_token_count=0,
file_metadata=[],
uncapped_token_count=aggregate_tokens,
file_metadata_for_tool=tool_metadata,
)
# Files fit — load them into context
user_file_map = {str(uf.id): uf for uf in user_files}
in_memory_files = load_in_memory_chat_files(
user_file_ids=[uf.id for uf in user_files],
# Calculate total token count for all user files in the project
project_tokens = get_project_token_count(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
file_texts: list[str] = []
image_files: list[ChatLoadedFile] = []
file_metadata: list[ContextFileMetadata] = []
project_file_texts: list[str] = []
project_image_files: list[ChatLoadedFile] = []
project_file_metadata: list[ProjectFileMetadata] = []
total_token_count = 0
if project_tokens < max_actual_tokens:
# Load project files into memory using cached plaintext when available
project_user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
if project_user_files:
# Create a mapping from file_id to UserFile for token count lookup
user_file_map = {str(file.id): file for file in project_user_files}
for f in in_memory_files:
uf = user_file_map.get(str(f.file_id))
if f.file_type.is_text_file():
text_content = _extract_text_from_in_memory_file(f)
if not text_content:
continue
file_texts.append(text_content)
file_metadata.append(
ContextFileMetadata(
file_id=str(f.file_id),
filename=f.filename or f"file_{f.file_id}",
file_content=text_content,
)
)
if uf and uf.token_count:
total_token_count += uf.token_count
elif f.file_type == ChatFileType.IMAGE:
token_count = uf.token_count if uf and uf.token_count else 0
total_token_count += token_count
image_files.append(
ChatLoadedFile(
file_id=f.file_id,
content=f.content,
file_type=f.file_type,
filename=f.filename,
content_text=None,
token_count=token_count,
)
project_file_ids = [file.id for file in project_user_files]
in_memory_project_files = load_in_memory_chat_files(
user_file_ids=project_file_ids,
db_session=db_session,
)
return ExtractedContextFiles(
file_texts=file_texts,
image_files=image_files,
use_as_search_filter=False,
# Extract text content from loaded files
for file in in_memory_project_files:
if file.file_type.is_text_file():
try:
text_content = file.content.decode("utf-8", errors="ignore")
# Strip null bytes
text_content = text_content.replace("\x00", "")
if text_content:
project_file_texts.append(text_content)
# Add metadata for citation support
project_file_metadata.append(
ProjectFileMetadata(
file_id=str(file.file_id),
filename=file.filename or f"file_{file.file_id}",
file_content=text_content,
)
)
# Add token count for text file
user_file = user_file_map.get(str(file.file_id))
if user_file and user_file.token_count:
total_token_count += user_file.token_count
except Exception:
# Skip files that can't be decoded
pass
elif file.file_type == ChatFileType.IMAGE:
# Convert InMemoryChatFile to ChatLoadedFile
user_file = user_file_map.get(str(file.file_id))
token_count = (
user_file.token_count
if user_file and user_file.token_count
else 0
)
total_token_count += token_count
chat_loaded_file = ChatLoadedFile(
file_id=file.file_id,
content=file.content,
file_type=file.file_type,
filename=file.filename,
content_text=None, # Images don't have text content
token_count=token_count,
)
project_image_files.append(chat_loaded_file)
else:
if DISABLE_VECTOR_DB:
# Without a vector DB we can't use project-as-filter search.
# Instead, build lightweight metadata so the LLM can call the
# FileReaderTool to inspect individual files on demand.
file_metadata_for_tool = _build_file_tool_metadata_for_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=project_tokens,
file_metadata_for_tool=file_metadata_for_tool,
)
project_as_filter = True
return ExtractedProjectFiles(
project_file_texts=project_file_texts,
project_image_files=project_image_files,
project_as_filter=project_as_filter,
total_token_count=total_token_count,
file_metadata=file_metadata,
uncapped_token_count=aggregate_tokens,
project_file_metadata=project_file_metadata,
project_uncapped_token_count=project_tokens,
)
APPROX_CHARS_PER_TOKEN = 4
def _build_file_tool_metadata_for_project(
    project_id: int,
    user_id: UUID | None,
    db_session: Session,
) -> list[FileToolMetadata]:
    """Build lightweight FileToolMetadata for every file in a project.

    Used when files are too large to fit in context and the vector DB is
    disabled, so the LLM needs to know which files it can read via the
    FileReaderTool.
    """
    files_in_project = get_user_files_from_project(
        project_id=project_id,
        user_id=user_id,
        db_session=db_session,
    )
    metadata: list[FileToolMetadata] = []
    for user_file in files_in_project:
        # Rough character estimate so the LLM can budget reads; a missing
        # token count is reported as zero characters.
        approx_chars = (user_file.token_count or 0) * APPROX_CHARS_PER_TOKEN
        metadata.append(
            FileToolMetadata(
                file_id=str(user_file.id),
                filename=user_file.name,
                approx_char_count=approx_chars,
            )
        )
    return metadata
def _build_file_tool_metadata_for_user_files(
user_files: list[UserFile],
) -> list[FileToolMetadata]:
@@ -380,46 +381,55 @@ def _build_file_tool_metadata_for_user_files(
]
def determine_search_params(
persona_id: int,
def _get_project_search_availability(
project_id: int | None,
extracted_context_files: ExtractedContextFiles,
) -> SearchParams:
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
persona_id: int | None,
loaded_project_files: bool,
project_has_files: bool,
forced_tool_id: int | None,
search_tool_id: int | None,
) -> ProjectSearchConfig:
"""Determine search tool availability based on project context.
A custom persona fully supersedes the project — project files are never
searchable and the search tool config is entirely controlled by the
persona. The project_id filter is only set for the default persona.
Search is disabled when ALL of the following are true:
- User is in a project
- Using the default persona (not a custom agent)
- Project files are already loaded in context
For the default persona inside a project:
- Files overflow → ENABLED (vector DB scopes to these files)
- Files fit → DISABLED (content already in prompt)
- No files at all → DISABLED (nothing to search)
When search is disabled and the user tried to force the search tool,
that forcing is also disabled.
Returns AUTO (follow persona config) in all other cases.
"""
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
# Not in a project, this should have no impact on search tool availability
if not project_id:
return ProjectSearchConfig(
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
)
search_project_id: int | None = None
search_persona_id: int | None = None
if extracted_context_files.use_as_search_filter:
if is_custom_persona:
search_persona_id = persona_id
else:
search_project_id = project_id
# Custom persona in project - let persona config decide
# Even if there are no files in the project, it's still guided by the persona config.
if persona_id != DEFAULT_PERSONA_ID:
return ProjectSearchConfig(
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
)
search_usage = SearchToolUsage.AUTO
if not is_custom_persona and project_id:
has_context_files = bool(extracted_context_files.uncapped_token_count)
files_loaded_in_context = bool(extracted_context_files.file_texts)
# If in a project with the default persona and the files have been already loaded into the context or
# there are no files in the project, disable search as there is nothing to search for.
if loaded_project_files or not project_has_files:
user_forced_search = (
forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
)
return ProjectSearchConfig(
search_usage=SearchToolUsage.DISABLED,
disable_forced_tool=user_forced_search,
)
if extracted_context_files.use_as_search_filter:
search_usage = SearchToolUsage.ENABLED
elif files_loaded_in_context or not has_context_files:
search_usage = SearchToolUsage.DISABLED
return SearchParams(
search_project_id=search_project_id,
search_persona_id=search_persona_id,
search_usage=search_usage,
# Default persona in a project with files, but also the files have not been loaded into the context already.
return ProjectSearchConfig(
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
)
@@ -651,37 +661,26 @@ def handle_stream_message_objects(
user_memory_context=prompt_memory_context,
)
# Determine which user files to use. A custom persona fully
# supersedes the project — project files are never loaded or
# searchable when a custom persona is in play. Only the default
# persona inside a project uses the project's files.
context_user_files = resolve_context_user_files(
persona=persona,
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
extracted_project_files = _extract_project_file_texts_and_images(
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
extracted_context_files = extract_context_files(
user_files=context_user_files,
llm_max_context_window=llm.config.max_input_tokens,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
search_params = determine_search_params(
persona_id=persona.id,
project_id=chat_session.project_id,
extracted_context_files=extracted_context_files,
)
# Also grant access to persona-attached user files for FileReaderTool
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# When the vector DB is disabled, persona-attached user_files have no
# search pipeline path. Inject them as file_metadata_for_tool so the
# LLM can read them via the FileReaderTool.
if DISABLE_VECTOR_DB and persona.user_files:
persona_file_metadata = _build_file_tool_metadata_for_user_files(
persona.user_files
)
# Merge persona file metadata into the extracted project files
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
# Build a mapping of tool_id to tool_name for history reconstruction
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
@@ -690,17 +689,30 @@ def handle_stream_message_objects(
None,
)
# Determine if search should be disabled for this project context
forced_tool_id = new_msg_req.forced_tool_id
if (
search_params.search_usage == SearchToolUsage.DISABLED
and forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
):
project_search_config = _get_project_search_availability(
project_id=chat_session.project_id,
persona_id=persona.id,
loaded_project_files=bool(extracted_project_files.project_file_texts),
project_has_files=bool(
extracted_project_files.project_uncapped_token_count
),
forced_tool_id=new_msg_req.forced_tool_id,
search_tool_id=search_tool_id,
)
if project_search_config.disable_forced_tool:
forced_tool_id = None
emitter = get_default_emitter()
# Also grant access to persona-attached user files
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
@@ -710,8 +722,11 @@ def handle_stream_message_objects(
llm=llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id=search_params.search_project_id,
persona_id=search_params.search_persona_id,
project_id=(
chat_session.project_id
if extracted_project_files.project_as_filter
else None
),
bypass_acl=bypass_acl,
slack_context=slack_context,
enable_slack_search=_should_enable_slack_search(
@@ -729,7 +744,7 @@ def handle_stream_message_objects(
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=search_params.search_usage,
search_usage_forcing_setting=project_search_config.search_usage,
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
@@ -768,7 +783,7 @@ def handle_stream_message_objects(
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
context_image_files=extracted_context_files.image_files,
project_image_files=extracted_project_files.project_image_files,
additional_context=additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
@@ -864,54 +879,46 @@ def handle_stream_message_objects(
# (user has already responded to a clarification question)
skip_clarification = is_last_assistant_message_clarification(chat_history)
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
# are just passed in immediately anyways, but the abstraction is cleaner this way.
yield from run_chat_loop_with_state_containers(
lambda emitter, state_container: run_deep_research_llm_loop(
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
llm=llm,
token_counter=token_counter,
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
),
run_deep_research_llm_loop,
llm_loop_completion_callback,
is_connected=check_is_connected,
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
llm=llm,
token_counter=token_counter,
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
)
else:
yield from run_chat_loop_with_state_containers(
lambda emitter, state_container: run_llm_loop(
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
context_files=extracted_context_files,
persona=persona,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
),
run_llm_loop,
llm_loop_completion_callback,
is_connected=check_is_connected, # Not passed through to run_llm_loop
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
except ValueError as e:

View File

@@ -23,6 +23,7 @@ from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import pkcs12
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
@@ -871,56 +872,6 @@ class SharepointConnector(
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
)
def _extract_tenant_domain_from_sites(self) -> str | None:
    """Derive the tenant domain from the configured site URLs.

    Site URLs look like https://{tenant}.sharepoint.com/sites/..., so the
    tenant domain is the first dot-separated label of the hostname. Returns
    None when no configured URL yields one.
    """
    for raw_url in self.sites:
        try:
            host = urlsplit(raw_url.strip()).hostname
        except ValueError:
            # Malformed URL — move on to the next configured site.
            continue
        if host:
            first_label = host.split(".")[0]
            if first_label:
                return first_label
    logger.warning(f"No tenant domain found from {len(self.sites)} sites")
    return None
def _resolve_tenant_domain_from_root_site(self) -> str:
    """Resolve the tenant domain via GET /v1.0/sites/root, which only
    requires Sites.Read.All (a permission the connector already needs).

    Raises ConnectorValidationError when the root site reports no hostname.
    """
    root = self.graph_client.sites.root.get().execute_query()
    root_hostname = root.site_collection.hostname
    if not root_hostname:
        raise ConnectorValidationError(
            "Could not determine tenant domain from root site"
        )
    # Hostname is {tenant}.sharepoint.com — the tenant is the first label.
    resolved_domain = root_hostname.split(".")[0]
    logger.info(
        "Resolved tenant domain '%s' from root site hostname '%s'",
        resolved_domain,
        root_hostname,
    )
    return resolved_domain
def _resolve_tenant_domain(self) -> str:
    """Determine the tenant domain, preferring site URLs over a Graph API
    call to avoid needing extra permissions."""
    domain_from_sites = self._extract_tenant_domain_from_sites()
    if not domain_from_sites:
        # No usable site URL — fall back to the Graph root-site lookup.
        logger.info("No site URLs available; resolving tenant domain from root site")
        return self._resolve_tenant_domain_from_root_site()
    logger.info(
        "Resolved tenant domain '%s' from site URLs",
        domain_from_sites,
    )
    return domain_from_sites
@property
def graph_client(self) -> GraphClient:
if self._graph_client is None:
@@ -1638,11 +1589,6 @@ class SharepointConnector(
sp_private_key = credentials.get("sp_private_key")
sp_certificate_password = credentials.get("sp_certificate_password")
if not sp_client_id:
raise ConnectorValidationError("Client ID is required")
if not sp_directory_id:
raise ConnectorValidationError("Directory (tenant) ID is required")
authority_url = f"{self.authority_host}/{sp_directory_id}"
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
@@ -1695,7 +1641,21 @@ class SharepointConnector(
_acquire_token_for_graph, environment=self._azure_environment
)
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
self.sp_tenant_domain = self._resolve_tenant_domain()
org = self.graph_client.organization.get().execute_query()
if not org or len(org) == 0:
raise ConnectorValidationError("No organization found")
tenant_info: Organization = org[
0
] # Access first item directly from collection
if not tenant_info.verified_domains:
raise ConnectorValidationError("No verified domains found for tenant")
sp_tenant_domain = tenant_info.verified_domains[0].name
if not sp_tenant_domain:
raise ConnectorValidationError("No verified domains found for tenant")
# remove the .onmicrosoft.com part
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
return None
def _get_drive_names_for_site(self, site_url: str) -> list[str]:

View File

@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine.iam_auth import ssl_context
from onyx.db.engine.sql_engine import ASYNC_DB_API
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import is_valid_schema_name
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = create_ssl_context_if_iam()
connect_args["ssl"] = ssl_context
engine_kwargs = {
"connect_args": connect_args,
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = create_ssl_context_if_iam()
cparams["ssl"] = ssl_context
return _ASYNC_ENGINE

View File

@@ -1,4 +1,3 @@
import functools
import os
import ssl
from typing import Any
@@ -49,9 +48,11 @@ def provide_iam_token(
configure_psycopg2_iam_auth(cparams, host, port, user, region)
@functools.cache
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
    """Create an SSL context if IAM authentication is enabled, else return None."""
    # IAM database auth requires TLS: build a default-verifying context using
    # the bundled CA certificate file. When IAM auth is disabled, returning
    # None lets callers connect without a custom SSL context.
    if USE_IAM_AUTH:
        return ssl.create_default_context(cafile=SSL_CERT_FILE)
    return None
ssl_context = create_ssl_context_if_iam()

View File

@@ -256,6 +256,9 @@ def create_update_persona(
try:
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
# Curators can edit default personas, but not make them
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
pass
@@ -332,7 +335,6 @@ def update_persona_shared(
db_session: Session,
group_ids: list[int] | None = None,
is_public: bool | None = None,
label_ids: list[int] | None = None,
) -> None:
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
@@ -342,7 +344,9 @@ def update_persona_shared(
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise PermissionError("You don't have permission to modify this persona")
raise HTTPException(
status_code=403, detail="You don't have permission to modify this persona"
)
versioned_update_persona_access = fetch_versioned_implementation(
"onyx.db.persona", "update_persona_access"
@@ -356,15 +360,6 @@ def update_persona_shared(
group_ids=group_ids,
)
if label_ids is not None:
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
persona.labels.clear()
persona.labels = labels
db_session.commit()
@@ -970,8 +965,6 @@ def upsert_persona(
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
# Fetch and attach hierarchy_nodes by IDs
hierarchy_nodes = None
@@ -1168,6 +1161,9 @@ def update_persona_is_default(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if not persona.is_public:
persona.is_public = True
persona.is_default_persona = is_default
db_session.commit()

View File

@@ -6,7 +6,6 @@ from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.models import Project__UserFile
from onyx.db.models import UserFile
@@ -58,19 +57,12 @@ def fetch_user_project_ids_for_user_files(
db_session: Session,
) -> dict[str, list[int]]:
"""Fetch user project ids for specified user files"""
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
)
rows = db_session.execute(stmt).all()
user_file_id_to_project_ids: dict[str, list[int]] = {
user_file_id: [] for user_file_id in user_file_ids
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
results = db_session.execute(stmt).scalars().all()
return {
str(user_file.id): [project.id for project in user_file.projects]
for user_file in results
}
for user_file_id, project_id in rows:
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
return user_file_id_to_project_ids
def fetch_persona_ids_for_user_files(

View File

@@ -139,7 +139,7 @@ def generate_final_report(
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
context_files=None,
project_files=None,
available_tokens=llm.config.max_input_tokens,
all_injected_file_metadata=all_injected_file_metadata,
)
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history + [reminder_message],
reminder_message=None,
context_files=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
all_injected_file_metadata=all_injected_file_metadata,
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=first_cycle_reminder_message,
context_files=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,

View File

@@ -405,7 +405,6 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID] | None = None
group_ids: list[int] | None = None
is_public: bool | None = None
label_ids: list[int] | None = None
# We notify each user when a user is shared with them
@@ -416,22 +415,14 @@ def share_persona(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_shared(
persona_id=persona_id,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
label_ids=persona_share_request.label_ids,
)
except PermissionError as e:
logger.exception("Failed to share persona")
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
logger.exception("Failed to share persona")
raise HTTPException(status_code=400, detail=str(e))
update_persona_shared(
persona_id=persona_id,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
)
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)

View File

@@ -1,27 +0,0 @@
"""Per-tenant request counter metric.
Increments a counter on every request, labelled by tenant, so Grafana can
answer "which tenant is generating the most traffic?"
"""
from prometheus_client import Counter
from prometheus_fastapi_instrumentator.metrics import Info
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_requests_by_tenant = Counter(
"onyx_api_requests_by_tenant_total",
"Total API requests by tenant",
["tenant_id", "method", "handler", "status"],
)
def per_tenant_request_callback(info: Info) -> None:
    """Increment the per-tenant request counter for every request."""
    # Fall back to "unknown" when no tenant is bound to the current context.
    current_tenant = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
    labelled_counter = _requests_by_tenant.labels(
        tenant_id=current_tenant,
        method=info.method,
        handler=info.modified_handler,
        status=info.modified_status,
    )
    labelled_counter.inc()

View File

@@ -32,7 +32,6 @@ from sqlalchemy.pool import QueuePool
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -73,7 +72,7 @@ _checkout_timeout_total = Counter(
_connections_held = Gauge(
"onyx_db_connections_held_by_endpoint",
"Number of DB connections currently held, by endpoint and engine",
["handler", "engine", "tenant_id"],
["handler", "engine"],
)
_hold_seconds = Histogram(
@@ -164,14 +163,10 @@ def _register_pool_events(engine: Engine, label: str) -> None:
conn_proxy: PoolProxiedConnection, # noqa: ARG001
) -> None:
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
conn_record.info["_metrics_endpoint"] = handler
conn_record.info["_metrics_tenant_id"] = tenant_id
conn_record.info["_metrics_checkout_time"] = time.monotonic()
_checkout_total.labels(engine=label).inc()
_connections_held.labels(
handler=handler, engine=label, tenant_id=tenant_id
).inc()
_connections_held.labels(handler=handler, engine=label).inc()
@event.listens_for(engine, "checkin")
def on_checkin(
@@ -179,12 +174,9 @@ def _register_pool_events(engine: Engine, label: str) -> None:
conn_record: ConnectionPoolEntry,
) -> None:
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
_checkin_total.labels(engine=label).inc()
_connections_held.labels(
handler=handler, engine=label, tenant_id=tenant_id
).dec()
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler, engine=label).observe(
time.monotonic() - start
@@ -207,12 +199,9 @@ def _register_pool_events(engine: Engine, label: str) -> None:
# Defensively clean up the held-connections gauge in case checkin
# doesn't fire after invalidation (e.g. hard pool shutdown).
handler = conn_record.info.pop("_metrics_endpoint", None)
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
if handler:
_connections_held.labels(
handler=handler, engine=label, tenant_id=tenant_id
).dec()
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
time.monotonic() - start

View File

@@ -11,11 +11,9 @@ SQLAlchemy connection pool metrics are registered separately via
"""
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
from sqlalchemy.exc import TimeoutError as SATimeoutError
from starlette.applications import Starlette
from onyx.server.metrics.per_tenant import per_tenant_request_callback
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
from onyx.server.metrics.slow_requests import slow_request_callback
@@ -61,15 +59,6 @@ def setup_prometheus_metrics(app: Starlette) -> None:
excluded_handlers=_EXCLUDED_HANDLERS,
)
# Explicitly create the default metrics (http_requests_total,
# http_request_duration_seconds, etc.) and add them first. The library
# skips creating defaults when ANY custom instrumentations are registered
# via .add(), so we must include them ourselves.
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
if default_callback:
instrumentator.add(default_callback)
instrumentator.add(slow_request_callback)
instrumentator.add(per_tenant_request_callback)
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)

View File

@@ -120,7 +120,7 @@ def generate_intermediate_report(
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
context_files=None,
project_files=None,
available_tokens=llm.config.max_input_tokens,
)
@@ -325,7 +325,7 @@ def run_research_agent_call(
custom_agent_prompt=None,
simple_chat_history=msg_history,
reminder_message=reminder_message,
context_files=None,
project_files=None,
available_tokens=llm.config.max_input_tokens,
)

View File

@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.7.3
pypdf==6.6.2
# via
# onyx
# unstructured-client

View File

@@ -12,7 +12,6 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.enums import HierarchyNodeType
from tests.daily.connectors.utils import load_all_from_connector
@@ -522,46 +521,3 @@ def test_sharepoint_connector_hierarchy_nodes(
f"Document {doc.semantic_identifier} should have "
"parent_hierarchy_raw_node_id set"
)
@pytest.fixture
def sharepoint_cert_credentials() -> dict[str, str]:
return {
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
"sp_client_id": os.environ["PERM_SYNC_SHAREPOINT_CLIENT_ID"],
"sp_private_key": os.environ["PERM_SYNC_SHAREPOINT_PRIVATE_KEY"],
"sp_certificate_password": os.environ[
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
],
"sp_directory_id": os.environ["PERM_SYNC_SHAREPOINT_DIRECTORY_ID"],
}
def test_resolve_tenant_domain_from_site_urls(
sharepoint_cert_credentials: dict[str, str],
) -> None:
"""Verify that certificate auth resolves the tenant domain from site URLs
without calling the /organization endpoint."""
site_url = os.environ["SHAREPOINT_SITE"]
connector = SharepointConnector(sites=[site_url])
connector.load_credentials(sharepoint_cert_credentials)
assert connector.sp_tenant_domain is not None
assert len(connector.sp_tenant_domain) > 0
# The tenant domain should match the first label of the site URL hostname
from urllib.parse import urlsplit
expected = urlsplit(site_url).hostname.split(".")[0] # type: ignore
assert connector.sp_tenant_domain == expected
def test_resolve_tenant_domain_from_root_site(
sharepoint_cert_credentials: dict[str, str],
) -> None:
"""Verify that certificate auth resolves the tenant domain via the root
site endpoint when no site URLs are configured."""
connector = SharepointConnector(sites=[])
connector.load_credentials(sharepoint_cert_credentials)
assert connector.sp_tenant_domain is not None
assert len(connector.sp_tenant_domain) > 0

View File

@@ -1,544 +0,0 @@
"""
External dependency unit tests for persona file sync.
Validates that:
1. The check_for_user_file_project_sync beat task picks up UserFiles with
needs_persona_sync=True (not just needs_project_sync).
2. The process_single_user_file_project_sync worker task reads persona
associations from the DB, passes persona_ids to the document index via
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
3. upsert_persona correctly marks affected UserFiles with
needs_persona_sync=True when file associations change.
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
since we only need to verify the arguments passed to update_single.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_for_user_file_project_sync,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file_project_sync,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
user_file_project_sync_lock_key,
)
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import upsert_persona
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _create_completed_user_file(
    db_session: Session,
    user: User,
    needs_persona_sync: bool = False,
    needs_project_sync: bool = False,
) -> UserFile:
    """Persist and return a UserFile in COMPLETED status owned by ``user``.

    Both sync flags default to False so each test opts in explicitly to the
    state it wants to exercise.
    """
    user_file = UserFile(
        id=uuid4(),
        user_id=user.id,
        file_id=f"test_file_{uuid4().hex[:8]}",
        name=f"test_{uuid4().hex[:8]}.txt",
        file_type="text/plain",
        status=UserFileStatus.COMPLETED,
        needs_persona_sync=needs_persona_sync,
        needs_project_sync=needs_project_sync,
        chunk_count=5,
    )
    # Commit then refresh so DB-side defaults are visible on the instance.
    db_session.add(user_file)
    db_session.commit()
    db_session.refresh(user_file)
    return user_file
def _create_test_persona(
    db_session: Session,
    user: User,
    user_files: list[UserFile] | None = None,
) -> Persona:
    """Insert a minimal Persona row directly (bypassing upsert_persona)."""
    persona_kwargs: dict[str, Any] = dict(
        name=f"Test Persona {uuid4().hex[:8]}",
        description="Test persona",
        num_chunks=10.0,
        chunks_above=0,
        chunks_below=0,
        llm_relevance_filter=False,
        llm_filter_extraction=False,
        recency_bias=RecencyBiasSetting.NO_DECAY,
        system_prompt="You are a test assistant",
        task_prompt="Answer the question",
        tools=[],
        document_sets=[],
        users=[user],
        groups=[],
        is_visible=True,
        is_public=True,
        display_priority=None,
        starter_messages=None,
        deleted=False,
        user_files=user_files or [],
        user_id=user.id,
    )
    new_persona = Persona(**persona_kwargs)
    # Commit then refresh so DB-side defaults are visible on the instance.
    db_session.add(new_persona)
    db_session.commit()
    db_session.refresh(new_persona)
    return new_persona
def _link_file_to_persona(
    db_session: Session, persona: Persona, user_file: UserFile
) -> None:
    """Associate ``user_file`` with ``persona`` via the join table."""
    db_session.add(
        Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
    )
    db_session.commit()
# Dotted path of the queue-depth helper; patched to 0 in _patch_task_app so
# the sweep task never skips enqueueing due to a (mock) backed-up queue.
_PATCH_QUEUE_DEPTH = (
    "onyx.background.celery.tasks.user_file_processing.tasks"
    ".get_user_file_project_sync_queue_depth"
)
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
    """Patch the ``app`` property on a bound Celery task.

    ``task.run.__self__`` recovers the underlying task instance; ``app`` is a
    property (a descriptor), so it must be patched on the instance's *type*,
    not on the instance itself.
    """
    task_instance = task.run.__self__
    with (
        patch.object(
            type(task_instance),
            "app",
            new_callable=PropertyMock,
            return_value=mock_app,
        ),
        # Report an empty queue so the sweep always proceeds to enqueue.
        patch(_PATCH_QUEUE_DEPTH, return_value=0),
    ):
        yield
# ---------------------------------------------------------------------------
# Test: check_for_user_file_project_sync picks up persona sync
# ---------------------------------------------------------------------------
class TestCheckSweepIncludesPersonaSync:
    """The beat task must pick up files needing persona sync, not just project sync."""

    def test_persona_sync_flag_enqueues_task(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
        user = create_test_user(db_session, "persona_sweep")
        uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
        mock_app = MagicMock()
        with _patch_task_app(check_for_user_file_project_sync, mock_app):
            check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
        # Collect every user_file_id the sweep dispatched via send_task.
        enqueued_ids = {
            call.kwargs["kwargs"]["user_file_id"]
            for call in mock_app.send_task.call_args_list
        }
        assert str(uf.id) in enqueued_ids

    def test_neither_flag_does_not_enqueue(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file with both flags False is not enqueued."""
        user = create_test_user(db_session, "no_sync")
        uf = _create_completed_user_file(db_session, user)
        mock_app = MagicMock()
        with _patch_task_app(check_for_user_file_project_sync, mock_app):
            check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
        enqueued_ids = {
            call.kwargs["kwargs"]["user_file_id"]
            for call in mock_app.send_task.call_args_list
        }
        assert str(uf.id) not in enqueued_ids

    def test_both_flags_enqueues_once(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file with BOTH flags True is enqueued exactly once."""
        user = create_test_user(db_session, "both_flags")
        uf = _create_completed_user_file(
            db_session, user, needs_persona_sync=True, needs_project_sync=True
        )
        mock_app = MagicMock()
        with _patch_task_app(check_for_user_file_project_sync, mock_app):
            check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
        # Deduplication check: both flags set must still yield a single task.
        matching_calls = [
            call
            for call in mock_app.send_task.call_args_list
            if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
        ]
        assert len(matching_calls) == 1
# ---------------------------------------------------------------------------
# Test: process_single_user_file_project_sync passes persona_ids to index
# ---------------------------------------------------------------------------
# Dotted paths (as seen from the worker task module) that the sync-task tests
# patch so no real Vespa / search-settings infrastructure is required.
_PATCH_GET_SETTINGS = (
    "onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
)
_PATCH_GET_INDICES = (
    "onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
)
_PATCH_HTTPX_INIT = (
    "onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
)
_PATCH_DISABLE_VDB = (
    "onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
)
class TestSyncTaskWritesPersonaIds:
    """The sync task reads persona associations and sends them to the index."""

    def test_passes_persona_ids_to_update_single(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """After linking a file to a persona, sync sends the persona ID."""
        user = create_test_user(db_session, "sync_persona")
        uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
        persona = _create_test_persona(db_session, user)
        _link_file_to_persona(db_session, persona, uf)
        # Mocked document index; only the arguments to update_single matter.
        mock_doc_index = MagicMock()
        mock_search_settings = MagicMock()
        mock_search_settings.primary = MagicMock()
        mock_search_settings.secondary = None
        # Clear any stale per-file lock from prior runs so the task proceeds.
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        lock_key = user_file_project_sync_lock_key(str(uf.id))
        redis_client.delete(lock_key)
        with (
            patch(_PATCH_DISABLE_VDB, False),
            patch(_PATCH_HTTPX_INIT),
            patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
            patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
        ):
            process_single_user_file_project_sync.run(
                user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
            )
        mock_doc_index.update_single.assert_called_once()
        call_args = mock_doc_index.update_single.call_args
        user_fields: VespaDocumentUserFields = call_args.kwargs["user_fields"]
        assert user_fields.personas is not None
        assert persona.id in user_fields.personas
        # First positional arg to update_single is the document ID (file ID).
        assert call_args.args[0] == str(uf.id)

    def test_clears_persona_sync_flag(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """After a successful sync the needs_persona_sync flag is cleared."""
        user = create_test_user(db_session, "sync_clear")
        uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        lock_key = user_file_project_sync_lock_key(str(uf.id))
        redis_client.delete(lock_key)
        # With the vector DB disabled no index call happens, but the DB flag
        # must still be cleared.
        with patch(_PATCH_DISABLE_VDB, True):
            process_single_user_file_project_sync.run(
                user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
            )
        db_session.refresh(uf)
        assert uf.needs_persona_sync is False

    def test_passes_both_project_and_persona_ids(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file linked to both a project and a persona gets both IDs."""
        from onyx.db.models import Project__UserFile
        from onyx.db.models import UserProject

        user = create_test_user(db_session, "sync_both")
        uf = _create_completed_user_file(
            db_session, user, needs_persona_sync=True, needs_project_sync=True
        )
        persona = _create_test_persona(db_session, user)
        _link_file_to_persona(db_session, persona, uf)
        project = UserProject(user_id=user.id, name="test-project", instructions="")
        db_session.add(project)
        db_session.commit()
        db_session.refresh(project)
        link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
        db_session.add(link)
        db_session.commit()
        mock_doc_index = MagicMock()
        mock_search_settings = MagicMock()
        mock_search_settings.primary = MagicMock()
        mock_search_settings.secondary = None
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        lock_key = user_file_project_sync_lock_key(str(uf.id))
        redis_client.delete(lock_key)
        with (
            patch(_PATCH_DISABLE_VDB, False),
            patch(_PATCH_HTTPX_INIT),
            patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
            patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
        ):
            process_single_user_file_project_sync.run(
                user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
            )
        call_kwargs = mock_doc_index.update_single.call_args.kwargs
        user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
        assert user_fields.personas is not None
        assert user_fields.user_projects is not None
        assert persona.id in user_fields.personas
        assert project.id in user_fields.user_projects
        # Both flags should be cleared
        db_session.refresh(uf)
        assert uf.needs_persona_sync is False
        assert uf.needs_project_sync is False

    def test_deleted_persona_excluded_from_ids(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
        user = create_test_user(db_session, "sync_deleted")
        uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
        persona = _create_test_persona(db_session, user)
        _link_file_to_persona(db_session, persona, uf)
        # Soft-delete after linking; the join row still exists.
        persona.deleted = True
        db_session.commit()
        mock_doc_index = MagicMock()
        mock_search_settings = MagicMock()
        mock_search_settings.primary = MagicMock()
        mock_search_settings.secondary = None
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        lock_key = user_file_project_sync_lock_key(str(uf.id))
        redis_client.delete(lock_key)
        with (
            patch(_PATCH_DISABLE_VDB, False),
            patch(_PATCH_HTTPX_INIT),
            patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
            patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
        ):
            process_single_user_file_project_sync.run(
                user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
            )
        call_kwargs = mock_doc_index.update_single.call_args.kwargs
        user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
        assert user_fields.personas is not None
        assert persona.id not in user_fields.personas
# ---------------------------------------------------------------------------
# Test: upsert_persona marks files for persona sync
# ---------------------------------------------------------------------------
class TestUpsertPersonaMarksSyncFlag:
    """upsert_persona must set needs_persona_sync on affected UserFiles."""

    def test_creating_persona_with_files_marks_sync(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """Creating a persona with file associations flags those files."""
        user = create_test_user(db_session, "upsert_create")
        uf = _create_completed_user_file(db_session, user)
        assert uf.needs_persona_sync is False
        upsert_persona(
            user=user,
            name=f"persona-{uuid4().hex[:8]}",
            description="test",
            num_chunks=10.0,
            llm_relevance_filter=False,
            llm_filter_extraction=False,
            recency_bias=RecencyBiasSetting.NO_DECAY,
            llm_model_provider_override=None,
            llm_model_version_override=None,
            starter_messages=None,
            system_prompt="test",
            task_prompt="test",
            datetime_aware=None,
            is_public=True,
            db_session=db_session,
            user_file_ids=[uf.id],
        )
        db_session.refresh(uf)
        assert uf.needs_persona_sync is True

    def test_updating_persona_files_marks_both_old_and_new(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """When file associations change, both the removed and added files are flagged."""
        user = create_test_user(db_session, "upsert_update")
        uf_old = _create_completed_user_file(db_session, user)
        uf_new = _create_completed_user_file(db_session, user)
        persona = upsert_persona(
            user=user,
            name=f"persona-{uuid4().hex[:8]}",
            description="test",
            num_chunks=10.0,
            llm_relevance_filter=False,
            llm_filter_extraction=False,
            recency_bias=RecencyBiasSetting.NO_DECAY,
            llm_model_provider_override=None,
            llm_model_version_override=None,
            starter_messages=None,
            system_prompt="test",
            task_prompt="test",
            datetime_aware=None,
            is_public=True,
            db_session=db_session,
            user_file_ids=[uf_old.id],
        )
        # Clear the flag from creation so we can observe the update
        uf_old.needs_persona_sync = False
        db_session.commit()
        assert persona.num_chunks is not None
        # Now update the persona to swap files
        upsert_persona(
            user=user,
            name=persona.name,
            description=persona.description,
            num_chunks=persona.num_chunks,
            llm_relevance_filter=persona.llm_relevance_filter,
            llm_filter_extraction=persona.llm_filter_extraction,
            recency_bias=persona.recency_bias,
            llm_model_provider_override=None,
            llm_model_version_override=None,
            starter_messages=None,
            system_prompt=persona.system_prompt,
            task_prompt=persona.task_prompt,
            datetime_aware=None,
            is_public=persona.is_public,
            db_session=db_session,
            persona_id=persona.id,
            user_file_ids=[uf_new.id],
        )
        db_session.refresh(uf_old)
        db_session.refresh(uf_new)
        assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
        assert uf_new.needs_persona_sync is True, "Added file should be flagged"

    def test_removing_all_files_marks_old_files(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """Removing all files from a persona flags the previously associated files."""
        user = create_test_user(db_session, "upsert_remove")
        uf = _create_completed_user_file(db_session, user)
        persona = upsert_persona(
            user=user,
            name=f"persona-{uuid4().hex[:8]}",
            description="test",
            num_chunks=10.0,
            llm_relevance_filter=False,
            llm_filter_extraction=False,
            recency_bias=RecencyBiasSetting.NO_DECAY,
            llm_model_provider_override=None,
            llm_model_version_override=None,
            starter_messages=None,
            system_prompt="test",
            task_prompt="test",
            datetime_aware=None,
            is_public=True,
            db_session=db_session,
            user_file_ids=[uf.id],
        )
        # Reset the creation-time flag so the update's effect is observable.
        uf.needs_persona_sync = False
        db_session.commit()
        assert persona.num_chunks is not None
        upsert_persona(
            user=user,
            name=persona.name,
            description=persona.description,
            num_chunks=persona.num_chunks,
            llm_relevance_filter=persona.llm_relevance_filter,
            llm_filter_extraction=persona.llm_filter_extraction,
            recency_bias=persona.recency_bias,
            llm_model_provider_override=None,
            llm_model_version_override=None,
            starter_messages=None,
            system_prompt=persona.system_prompt,
            task_prompt=persona.task_prompt,
            datetime_aware=None,
            is_public=persona.is_public,
            db_session=db_session,
            persona_id=persona.id,
            user_file_ids=[],
        )
        db_session.refresh(uf)
        assert uf.needs_persona_sync is True

View File

@@ -1,318 +0,0 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
Mocks the LLM tokenizer and file store since they are not relevant here.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.models import UserProject
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _create_user_file(db_session: Session, user: User) -> UserFile:
    """Persist and return a COMPLETED, single-chunk UserFile for ``user``."""
    new_file = UserFile(
        id=uuid4(),
        user_id=user.id,
        file_id=f"test_file_{uuid4().hex[:8]}",
        name=f"test_{uuid4().hex[:8]}.txt",
        file_type="text/plain",
        status=UserFileStatus.COMPLETED,
        chunk_count=1,
    )
    # Commit then refresh so DB-side defaults are visible on the instance.
    db_session.add(new_file)
    db_session.commit()
    db_session.refresh(new_file)
    return new_file
def _create_persona(db_session: Session, user: User) -> Persona:
    """Insert a minimal, public, visible Persona owned by ``user``.

    Built via a direct model insert (not upsert_persona) so no sync flags or
    side effects are triggered.
    """
    persona = Persona(
        name=f"Test Persona {uuid4().hex[:8]}",
        description="Test persona",
        num_chunks=10.0,
        chunks_above=0,
        chunks_below=0,
        llm_relevance_filter=False,
        llm_filter_extraction=False,
        recency_bias=RecencyBiasSetting.NO_DECAY,
        system_prompt="test",
        task_prompt="test",
        tools=[],
        document_sets=[],
        users=[user],
        groups=[],
        is_visible=True,
        is_public=True,
        display_priority=None,
        starter_messages=None,
        deleted=False,
        user_id=user.id,
    )
    db_session.add(persona)
    db_session.commit()
    db_session.refresh(persona)
    return persona
def _create_project(db_session: Session, user: User) -> UserProject:
    """Persist and return a UserProject with empty instructions for ``user``."""
    new_project = UserProject(
        user_id=user.id,
        name=f"project-{uuid4().hex[:8]}",
        instructions="",
    )
    # Commit then refresh so DB-side defaults are visible on the instance.
    db_session.add(new_project)
    db_session.commit()
    db_session.refresh(new_project)
    return new_project
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
    """Build a minimal IndexChunk whose source document ID matches the UserFile."""
    doc = Document(
        id=str(user_file.id),
        source=DocumentSource.USER_FILE,
        semantic_identifier=user_file.name,
        sections=[TextSection(text="test chunk content", link=None)],
        metadata={},
    )
    return IndexChunk(
        source_document=doc,
        chunk_id=0,
        blurb="test chunk",
        content="test chunk content",
        source_links={0: ""},
        image_file_id=None,
        section_continuation=False,
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        contextual_rag_reserved_tokens=0,
        doc_summary="",
        chunk_context="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        # Dummy 768-dim zero vector; embedding values are irrelevant to the
        # metadata being tested.
        embeddings=ChunkEmbedding(
            full_embedding=[0.0] * 768,
            mini_chunk_embeddings=[],
        ),
        title_embedding=None,
    )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestAdapterWritesBothMetadataFields:
    """build_metadata_aware_chunks must populate user_project AND personas."""

    # Each test patches get_default_llm to raise so the adapter's optional
    # LLM path is exercised as unavailable (no LLM in the test environment).
    @patch(
        "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
        side_effect=Exception("no LLM in test"),
    )
    def test_file_linked_to_persona_gets_persona_id(
        self,
        _mock_llm: MagicMock,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A persona-linked file yields chunks carrying that persona's ID."""
        user = create_test_user(db_session, "adapter_persona")
        uf = _create_user_file(db_session, user)
        persona = _create_persona(db_session, user)
        db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
        db_session.commit()
        adapter = UserFileIndexingAdapter(
            tenant_id=TEST_TENANT_ID, db_session=db_session
        )
        chunk = _make_index_chunk(uf)
        doc = chunk.source_document
        context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
        result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[chunk],
            chunk_content_scores=[1.0],
            tenant_id=TEST_TENANT_ID,
            context=context,
        )
        assert len(result.chunks) == 1
        aware_chunk = result.chunks[0]
        assert persona.id in aware_chunk.personas
        assert aware_chunk.user_project == []

    @patch(
        "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
        side_effect=Exception("no LLM in test"),
    )
    def test_file_linked_to_project_gets_project_id(
        self,
        _mock_llm: MagicMock,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A project-linked file yields chunks carrying that project's ID."""
        user = create_test_user(db_session, "adapter_project")
        uf = _create_user_file(db_session, user)
        project = _create_project(db_session, user)
        db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
        db_session.commit()
        adapter = UserFileIndexingAdapter(
            tenant_id=TEST_TENANT_ID, db_session=db_session
        )
        chunk = _make_index_chunk(uf)
        context = DocumentBatchPrepareContext(
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )
        result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[chunk],
            chunk_content_scores=[1.0],
            tenant_id=TEST_TENANT_ID,
            context=context,
        )
        assert len(result.chunks) == 1
        aware_chunk = result.chunks[0]
        assert project.id in aware_chunk.user_project
        assert aware_chunk.personas == []

    @patch(
        "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
        side_effect=Exception("no LLM in test"),
    )
    def test_file_linked_to_both_gets_both_ids(
        self,
        _mock_llm: MagicMock,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """Both association kinds on one file populate both metadata fields."""
        user = create_test_user(db_session, "adapter_both")
        uf = _create_user_file(db_session, user)
        persona = _create_persona(db_session, user)
        project = _create_project(db_session, user)
        db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
        db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
        db_session.commit()
        adapter = UserFileIndexingAdapter(
            tenant_id=TEST_TENANT_ID, db_session=db_session
        )
        chunk = _make_index_chunk(uf)
        context = DocumentBatchPrepareContext(
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )
        result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[chunk],
            chunk_content_scores=[1.0],
            tenant_id=TEST_TENANT_ID,
            context=context,
        )
        aware_chunk = result.chunks[0]
        assert persona.id in aware_chunk.personas
        assert project.id in aware_chunk.user_project

    @patch(
        "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
        side_effect=Exception("no LLM in test"),
    )
    def test_file_with_no_associations_gets_empty_lists(
        self,
        _mock_llm: MagicMock,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file with no links produces empty personas/user_project lists."""
        user = create_test_user(db_session, "adapter_empty")
        uf = _create_user_file(db_session, user)
        adapter = UserFileIndexingAdapter(
            tenant_id=TEST_TENANT_ID, db_session=db_session
        )
        chunk = _make_index_chunk(uf)
        context = DocumentBatchPrepareContext(
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )
        result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[chunk],
            chunk_content_scores=[1.0],
            tenant_id=TEST_TENANT_ID,
            context=context,
        )
        aware_chunk = result.chunks[0]
        assert aware_chunk.personas == []
        assert aware_chunk.user_project == []

    @patch(
        "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
        side_effect=Exception("no LLM in test"),
    )
    def test_multiple_personas_all_appear(
        self,
        _mock_llm: MagicMock,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file linked to multiple personas should have all their IDs."""
        user = create_test_user(db_session, "adapter_multi")
        uf = _create_user_file(db_session, user)
        persona_a = _create_persona(db_session, user)
        persona_b = _create_persona(db_session, user)
        db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
        db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
        db_session.commit()
        adapter = UserFileIndexingAdapter(
            tenant_id=TEST_TENANT_ID, db_session=db_session
        )
        chunk = _make_index_chunk(uf)
        context = DocumentBatchPrepareContext(
            updatable_docs=[chunk.source_document], id_to_boost_map={}
        )
        result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[chunk],
            chunk_content_scores=[1.0],
            tenant_id=TEST_TENANT_ID,
            context=context,
        )
        aware_chunk = result.chunks[0]
        assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -76,12 +76,9 @@ class ChatSessionManager:
user_performing_action: DATestUser,
persona_id: int = 0,
description: str = "Test chat session",
project_id: int | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id,
description=description,
project_id=project_id,
persona_id=persona_id, description=description
)
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session",

View File

@@ -1,79 +0,0 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestScimToken
from tests.integration.common_utils.test_models import DATestUser
class ScimTokenManager:
    """Test helper wrapping the admin SCIM-token endpoints and SCIM v2 calls."""

    @staticmethod
    def _token_from_payload(
        data: dict, raw_token: str | None
    ) -> DATestScimToken:
        """Map a token API response body onto the test model."""
        return DATestScimToken(
            id=data["id"],
            name=data["name"],
            token_display=data["token_display"],
            is_active=data["is_active"],
            created_at=data["created_at"],
            last_used_at=data.get("last_used_at"),
            raw_token=raw_token,
        )

    @staticmethod
    def create(
        name: str,
        user_performing_action: DATestUser,
    ) -> DATestScimToken:
        """Create a SCIM token; the response includes the raw secret once."""
        response = requests.post(
            f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
            json={"name": name},
            headers=user_performing_action.headers,
            timeout=60,
        )
        response.raise_for_status()
        data = response.json()
        return ScimTokenManager._token_from_payload(data, data["raw_token"])

    @staticmethod
    def get_active(
        user_performing_action: DATestUser,
    ) -> DATestScimToken | None:
        """Fetch the active SCIM token, or None when none exists (404)."""
        response = requests.get(
            f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
            headers=user_performing_action.headers,
            timeout=60,
        )
        if response.status_code == 404:
            return None
        response.raise_for_status()
        return ScimTokenManager._token_from_payload(response.json(), None)

    @staticmethod
    def get_scim_headers(raw_token: str) -> dict[str, str]:
        """General headers plus a Bearer authorization for the raw token."""
        headers = dict(GENERAL_HEADERS)
        headers["Authorization"] = f"Bearer {raw_token}"
        return headers

    @staticmethod
    def scim_get(
        path: str,
        raw_token: str,
    ) -> requests.Response:
        """GET a SCIM v2 path authenticated with ``raw_token``."""
        url = f"{API_SERVER_URL}/scim/v2{path}"
        return requests.get(
            url,
            headers=ScimTokenManager.get_scim_headers(raw_token),
            timeout=60,
        )

    @staticmethod
    def scim_get_no_auth(path: str) -> requests.Response:
        """GET a SCIM v2 path with no Authorization header (auth-failure tests)."""
        url = f"{API_SERVER_URL}/scim/v2{path}"
        return requests.get(url, headers=GENERAL_HEADERS, timeout=60)

View File

@@ -42,18 +42,6 @@ class DATestPAT(BaseModel):
last_used_at: str | None = None
class DATestScimToken(BaseModel):
    """SCIM bearer token model for testing."""

    id: int
    name: str
    raw_token: str | None = None  # Only present on initial creation
    # Masked form of the token safe to show in listings.
    token_display: str
    is_active: bool
    # Timestamps are serialized as strings by the API.
    created_at: str
    last_used_at: str | None = None
class DATestAPIKey(BaseModel):
api_key_id: int
api_key_display: str

View File

@@ -23,8 +23,6 @@ _ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
_ENV_API_VERSION = "NIGHTLY_LLM_API_VERSION"
_ENV_DEPLOYMENT_NAME = "NIGHTLY_LLM_DEPLOYMENT_NAME"
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
@@ -36,8 +34,6 @@ class NightlyProviderConfig(BaseModel):
model_names: list[str]
api_key: str | None
api_base: str | None
api_version: str | None
deployment_name: str | None
custom_config: dict[str, str] | None
strict: bool
@@ -49,29 +45,17 @@ def _env_true(env_var: str, default: bool = False) -> bool:
return value.strip().lower() in {"1", "true", "yes", "on"}
def _parse_models_env(env_var: str) -> list[str]:
raw_value = os.environ.get(env_var, "").strip()
if not raw_value:
return []
try:
parsed_json = json.loads(raw_value)
except json.JSONDecodeError:
parsed_json = None
if isinstance(parsed_json, list):
return [str(model).strip() for model in parsed_json if str(model).strip()]
return [part.strip() for part in raw_value.split(",") if part.strip()]
def _split_csv_env(env_var: str) -> list[str]:
return [
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
]
def _load_provider_config() -> NightlyProviderConfig:
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
model_names = _parse_models_env(_ENV_MODELS)
model_names = _split_csv_env(_ENV_MODELS)
api_key = os.environ.get(_ENV_API_KEY) or None
api_base = os.environ.get(_ENV_API_BASE) or None
api_version = os.environ.get(_ENV_API_VERSION) or None
deployment_name = os.environ.get(_ENV_DEPLOYMENT_NAME) or None
strict = _env_true(_ENV_STRICT, default=False)
custom_config: dict[str, str] | None = None
@@ -90,8 +74,6 @@ def _load_provider_config() -> NightlyProviderConfig:
model_names=model_names,
api_key=api_key,
api_base=api_base,
api_version=api_version,
deployment_name=deployment_name,
custom_config=custom_config,
strict=strict,
)
@@ -113,15 +95,10 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
message=f"{_ENV_MODELS} must include at least one model",
)
if config.provider != "ollama_chat" and not (
config.api_key or config.custom_config
):
if config.provider != "ollama_chat" and not config.api_key:
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_API_KEY} or {_ENV_CUSTOM_CONFIG_JSON} is required for "
f"provider '{config.provider}'"
),
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
)
if config.provider == "ollama_chat" and not (
@@ -132,22 +109,6 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
)
if config.provider == "azure":
if not config.api_base:
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_API_BASE} is required for provider '{config.provider}'"
),
)
if not config.api_version:
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_API_VERSION} is required for provider '{config.provider}'"
),
)
def _assert_integration_mode_enabled() -> None:
assert (
@@ -186,8 +147,6 @@ def _create_provider_payload(
model_name: str,
api_key: str | None,
api_base: str | None,
api_version: str | None,
deployment_name: str | None,
custom_config: dict[str, str] | None,
) -> dict:
return {
@@ -195,8 +154,6 @@ def _create_provider_payload(
"provider": provider,
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
"deployment_name": deployment_name,
"custom_config": custom_config,
"default_model_name": model_name,
"is_public": True,
@@ -298,8 +255,6 @@ def _create_and_test_provider_for_model(
model_name=model_name,
api_key=config.api_key,
api_base=resolved_api_base,
api_version=config.api_version,
deployment_name=config.deployment_name,
custom_config=config.custom_config,
)
@@ -358,21 +313,10 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
_seed_connector_for_search_tool(admin_user)
search_tool_id = _get_internal_search_tool_id(admin_user)
failures: list[str] = []
for model_name in config.model_names:
try:
_create_and_test_provider_for_model(
admin_user=admin_user,
config=config,
model_name=model_name,
search_tool_id=search_tool_id,
)
except BaseException as exc:
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
raise
failures.append(
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
)
if failures:
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
_create_and_test_provider_for_model(
admin_user=admin_user,
config=config,
model_name=model_name,
search_tool_id=search_tool_id,
)

View File

@@ -1,318 +0,0 @@
"""
Integration tests for the unified persona file context flow.
End-to-end tests that verify:
1. Files can be uploaded and attached to a persona via API.
2. The persona correctly reports its attached files.
3. A chat session with a file-bearing persona processes without error.
4. Precedence: custom persona files take priority over project files when
the chat session is inside a project.
These tests run against a real Onyx deployment (all services running).
File processing is asynchronous, so we poll the file status endpoint
until files reach COMPLETED before chatting.
"""
import time
import requests
from onyx.db.enums import UserFileStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.project import ProjectManager
from tests.integration.common_utils.test_file_utils import create_test_text_file
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Seconds to wait between successive status-poll requests.
FILE_PROCESSING_POLL_INTERVAL = 2


def _poll_file_statuses(
    user_file_ids: list[str],
    user: DATestUser,
    target_status: UserFileStatus = UserFileStatus.COMPLETED,
    timeout: int = MAX_DELAY,
) -> None:
    """Block until all files reach the target status or timeout expires.

    Args:
        user_file_ids: IDs of the uploaded user files to poll.
        user: User whose auth headers are sent to the status endpoint.
        target_status: Status every file must reach (default COMPLETED).
        timeout: Overall polling budget in seconds.

    Raises:
        TimeoutError: If any file has not reached the target status in time.
        requests.HTTPError: If the status endpoint returns an error response.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        response = requests.post(
            f"{API_SERVER_URL}/user/projects/file/statuses",
            json={"file_ids": user_file_ids},
            headers=user.headers,
            # Bound each HTTP call so a single hung request cannot stall
            # the loop indefinitely past the overall deadline.
            timeout=30,
        )
        response.raise_for_status()
        statuses = response.json()
        if all(f["status"] == target_status.value for f in statuses):
            return
        time.sleep(FILE_PROCESSING_POLL_INTERVAL)
    raise TimeoutError(
        f"Files {user_file_ids} did not reach {target_status.value} "
        f"within {timeout}s"
    )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_persona_with_files_chat_no_error(
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """Upload files, attach them to a persona, wait for processing,
    then send a chat message. Verify no error is returned."""
    # Upload files (creates UserFile records)
    text_file = create_test_text_file(
        "The secret project codename is NIGHTINGALE. "
        "It was started in 2024 by the Advanced Research division."
    )
    file_descriptors, error = FileManager.upload_files(
        files=[("nightingale_brief.txt", text_file)],
        user_performing_action=admin_user,
    )
    assert not error, f"File upload failed: {error}"
    assert len(file_descriptors) == 1
    user_file_id = file_descriptors[0]["user_file_id"]
    assert user_file_id is not None
    # Wait for file processing — processing is asynchronous server-side
    _poll_file_statuses([user_file_id], admin_user, timeout=120)
    # Create persona with the file attached
    persona = PersonaManager.create(
        user_performing_action=admin_user,
        name="Nightingale Agent",
        description="Agent with secret file",
        system_prompt="You are a helpful assistant with access to uploaded files.",
        user_file_ids=[user_file_id],
    )
    # Verify persona has the file
    persona_snapshots = PersonaManager.get_one(persona.id, admin_user)
    assert len(persona_snapshots) == 1
    assert user_file_id in persona_snapshots[0].user_file_ids
    # Chat with the persona — only asserts the request succeeds and a
    # non-empty answer comes back; content is not checked here.
    chat_session = ChatSessionManager.create(
        persona_id=persona.id,
        description="Test persona file context",
        user_performing_action=admin_user,
    )
    response = ChatSessionManager.send_message(
        chat_session_id=chat_session.id,
        message="What is the secret project codename?",
        user_performing_action=admin_user,
    )
    assert response.error is None, f"Chat should succeed, got error: {response.error}"
    assert len(response.full_message) > 0, "Response should not be empty"
def test_persona_without_files_still_works(
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """A persona with no attached files should still chat normally."""
    # Baseline: no user_file_ids passed at creation time.
    persona = PersonaManager.create(
        user_performing_action=admin_user,
        name="Blank Agent",
        description="No files attached",
        system_prompt="You are a helpful assistant.",
    )
    chat_session = ChatSessionManager.create(
        persona_id=persona.id,
        description="Test blank persona",
        user_performing_action=admin_user,
    )
    response = ChatSessionManager.send_message(
        chat_session_id=chat_session.id,
        message="Hello, how are you?",
        user_performing_action=admin_user,
    )
    # Only success and non-emptiness are asserted — no content checks.
    assert response.error is None
    assert len(response.full_message) > 0
def test_persona_files_override_project_files(
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """When a custom persona (with its own files) is used inside a project,
    the persona's files take precedence — the project's files are invisible.
    We verify this by putting different content in project vs persona files
    and checking which content the model responds with."""
    # Upload persona file (secret word: ALBATROSS)
    persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
    persona_fds, err1 = FileManager.upload_files(
        files=[("persona_secret.txt", persona_file)],
        user_performing_action=admin_user,
    )
    assert not err1
    persona_user_file_id = persona_fds[0]["user_file_id"]
    assert persona_user_file_id is not None
    # Create a project and upload project files (secret word: FLAMINGO)
    project = ProjectManager.create(
        name="Precedence Test Project",
        user_performing_action=admin_user,
    )
    project_files = [
        ("project_secret.txt", b"The project's secret word is FLAMINGO."),
    ]
    project_upload_result = ProjectManager.upload_files(
        project_id=project.id,
        files=project_files,
        user_performing_action=admin_user,
    )
    assert len(project_upload_result.user_files) == 1
    project_user_file_id = str(project_upload_result.user_files[0].id)
    # Wait for both persona and project file processing
    _poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
    _poll_file_statuses([project_user_file_id], admin_user, timeout=120)
    # Create persona with persona file
    persona = PersonaManager.create(
        user_performing_action=admin_user,
        name="Override Agent",
        description="Persona with its own files",
        system_prompt="You are a helpful assistant. Answer using the files.",
        user_file_ids=[persona_user_file_id],
    )
    # Create chat session inside the project but using the custom persona
    chat_session = ChatSessionManager.create(
        persona_id=persona.id,
        project_id=project.id,
        user_performing_action=admin_user,
    )
    response = ChatSessionManager.send_message(
        chat_session_id=chat_session.id,
        message="What is the secret word?",
        user_performing_action=admin_user,
    )
    assert response.error is None, f"Chat should succeed, got error: {response.error}"
    # The persona's file should be what the model sees, not the project's
    message_lower = response.full_message.lower()
    assert "albatross" in message_lower, (
        "Response should reference the persona file's secret word (ALBATROSS), "
        f"but got: {response.full_message}"
    )
def test_default_persona_in_project_uses_project_files(
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """When the default persona (id=0) is used inside a project,
    the project's files should be used for context."""
    project = ProjectManager.create(
        name="Default Persona Project",
        user_performing_action=admin_user,
    )
    project_files = [
        ("project_info.txt", b"The project mascot is a PANGOLIN."),
    ]
    upload_result = ProjectManager.upload_files(
        project_id=project.id,
        files=project_files,
        user_performing_action=admin_user,
    )
    assert len(upload_result.user_files) == 1
    # Wait for project file processing
    project_file_id = str(upload_result.user_files[0].id)
    _poll_file_statuses([project_file_id], admin_user, timeout=120)
    # Create chat session inside project using default persona (id=0)
    chat_session = ChatSessionManager.create(
        persona_id=0,
        project_id=project.id,
        user_performing_action=admin_user,
    )
    response = ChatSessionManager.send_message(
        chat_session_id=chat_session.id,
        message="What is the project mascot?",
        user_performing_action=admin_user,
    )
    assert response.error is None
    # The model should answer from the project file's content.
    assert "pangolin" in response.full_message.lower(), (
        "Response should reference the project file content (PANGOLIN), "
        f"but got: {response.full_message}"
    )
def test_custom_persona_no_files_in_project_ignores_project(
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """A custom persona with NO files, used inside a project with files,
    should NOT see the project's files. The project is purely organizational.
    We verify by asking about content only in the project file and checking
    the model does NOT reference it."""
    project = ProjectManager.create(
        name="Ignored Project",
        user_performing_action=admin_user,
    )
    project_upload_result = ProjectManager.upload_files(
        project_id=project.id,
        files=[("project_only.txt", b"The project secret is CAPYBARA.")],
        user_performing_action=admin_user,
    )
    assert len(project_upload_result.user_files) == 1
    project_user_file_id = str(project_upload_result.user_files[0].id)
    # Wait for project file processing
    _poll_file_statuses([project_user_file_id], admin_user, timeout=120)
    # Custom persona with no files — system prompt instructs the model to
    # admit ignorance so the negative assertion below is meaningful.
    persona = PersonaManager.create(
        user_performing_action=admin_user,
        name="No Files Agent",
        description="No files, project is irrelevant",
        system_prompt=(
            "You are a helpful assistant. If you do not have information "
            "to answer a question, say 'I do not have that information.'"
        ),
    )
    chat_session = ChatSessionManager.create(
        persona_id=persona.id,
        project_id=project.id,
        user_performing_action=admin_user,
    )
    response = ChatSessionManager.send_message(
        chat_session_id=chat_session.id,
        message="What is the project secret?",
        user_performing_action=admin_user,
    )
    assert response.error is None
    assert len(response.full_message) > 0
    # Negative check: project content must NOT leak into the answer.
    assert "capybara" not in response.full_message.lower(), (
        "Response should NOT reference the project file content (CAPYBARA) "
        "because the custom persona has no files and should not inherit "
        f"project files, but got: {response.full_message}"
    )

View File

@@ -1,166 +0,0 @@
"""Integration tests for SCIM token management.
Covers the admin token API and SCIM bearer-token authentication:
1. Token lifecycle: create, retrieve metadata, use for SCIM requests
2. Token rotation: creating a new token revokes previous tokens
3. Revoked tokens are rejected by SCIM endpoints
4. Non-admin users cannot manage SCIM tokens
5. SCIM requests without a token are rejected
6. Service discovery endpoints work without authentication
7. last_used_at is updated after a SCIM request
"""
import time
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
    """Create token → retrieve metadata → use for SCIM request."""
    token = ScimTokenManager.create(
        name="Test SCIM Token",
        user_performing_action=admin_user,
    )
    # The raw bearer value is returned at creation time and is prefixed.
    assert token.raw_token is not None
    assert token.raw_token.startswith("onyx_scim_")
    assert token.is_active is True
    assert "****" in token.token_display
    # GET returns the same metadata but raw_token is None because the
    # server only reveals the raw token once at creation time (it stores
    # only the SHA-256 hash).
    active = ScimTokenManager.get_active(user_performing_action=admin_user)
    assert active == token.model_copy(update={"raw_token": None})
    # Token works for SCIM requests
    response = ScimTokenManager.scim_get("/Users", token.raw_token)
    assert response.status_code == 200
    body = response.json()
    assert "Resources" in body
    assert body["totalResults"] >= 0
def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
    """Creating a new token automatically revokes the previous one."""
    first = ScimTokenManager.create(
        name="First Token",
        user_performing_action=admin_user,
    )
    assert first.raw_token is not None
    # Sanity check: the first token works before rotation.
    response = ScimTokenManager.scim_get("/Users", first.raw_token)
    assert response.status_code == 200
    # Create second token — should revoke first
    second = ScimTokenManager.create(
        name="Second Token",
        user_performing_action=admin_user,
    )
    assert second.raw_token is not None
    # Active token should now be the second one
    active = ScimTokenManager.get_active(user_performing_action=admin_user)
    assert active == second.model_copy(update={"raw_token": None})
    # First token rejected, second works
    assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
    assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
def test_scim_request_without_token_rejected(
    admin_user: DATestUser,  # noqa: ARG001
) -> None:
    """SCIM endpoints reject requests with no Authorization header."""
    response = ScimTokenManager.scim_get_no_auth("/Users")
    assert response.status_code == 401
def test_scim_request_with_bad_token_rejected(
    admin_user: DATestUser,  # noqa: ARG001
) -> None:
    """SCIM endpoints reject requests with an invalid token."""
    response = ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value")
    assert response.status_code == 401
def test_non_admin_cannot_create_token(
    admin_user: DATestUser,  # noqa: ARG001
) -> None:
    """Non-admin users get 403 when trying to create a SCIM token."""
    # Fresh basic (non-admin) user for the authorization check.
    basic_user = UserManager.create(name="scim_basic_user")
    response = requests.post(
        f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
        json={"name": "Should Fail"},
        headers=basic_user.headers,
        timeout=60,
    )
    assert response.status_code == 403
def test_non_admin_cannot_get_token(
    admin_user: DATestUser,  # noqa: ARG001
) -> None:
    """Non-admin users get 403 when trying to retrieve SCIM token metadata."""
    # Fresh basic (non-admin) user for the authorization check.
    basic_user = UserManager.create(name="scim_basic_user2")
    response = requests.get(
        f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
        headers=basic_user.headers,
        timeout=60,
    )
    assert response.status_code == 403
def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None:
    """GET active token returns 404 when no token exists."""
    # new_admin_user depends on the reset fixture, ensuring a clean DB
    # with no active SCIM tokens.
    active = ScimTokenManager.get_active(user_performing_action=new_admin_user)
    assert active is None
    # Also verify the raw endpoint behavior (404, not an empty body).
    response = requests.get(
        f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
        headers=new_admin_user.headers,
        timeout=60,
    )
    assert response.status_code == 404
def test_service_discovery_no_auth_required(
    admin_user: DATestUser,  # noqa: ARG001
) -> None:
    """Service discovery endpoints work without any authentication."""
    discovery_paths = ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]
    for path in discovery_paths:
        response = ScimTokenManager.scim_get_no_auth(path)
        assert response.status_code == 200, f"{path} returned {response.status_code}"
def test_last_used_at_updated_after_scim_request(
    admin_user: DATestUser,
) -> None:
    """last_used_at timestamp is updated after using the token."""
    token = ScimTokenManager.create(
        name="Last Used Token",
        user_performing_action=admin_user,
    )
    assert token.raw_token is not None
    # A freshly created token has never been used.
    active = ScimTokenManager.get_active(user_performing_action=admin_user)
    assert active is not None
    assert active.last_used_at is None
    # Make a SCIM request, then verify last_used_at is set
    assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
    # Brief pause so the server-side timestamp write lands before re-read.
    time.sleep(0.5)
    active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
    assert active_after is not None
    assert active_after.last_used_at is not None

View File

@@ -1,426 +0,0 @@
"""Tests for the unified context file extraction logic (Phase 5).
Covers:
- resolve_context_user_files: precedence rule (custom persona supersedes project)
- extract_context_files: all-or-nothing context window fit check
- Search filter / search_usage determination in the caller
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.process_message import determine_search_params
from onyx.chat.process_message import extract_context_files
from onyx.chat.process_message import resolve_context_user_files
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.models import UserFile
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.tools.models import SearchToolUsage
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_user_file(
    token_count: int = 100,
    name: str = "file.txt",
    file_id: str | None = None,
) -> UserFile:
    """Build a UserFile row with a fixed or freshly generated UUID."""
    if file_id:
        file_uuid = UUID(file_id)
    else:
        file_uuid = uuid4()
    return UserFile(
        id=file_uuid,
        file_id=str(file_uuid),
        name=name,
        token_count=token_count,
    )
def _make_persona(
persona_id: int,
user_files: list | None = None,
) -> MagicMock:
persona = MagicMock()
persona.id = persona_id
persona.user_files = user_files or []
return persona
def _make_in_memory_file(
    file_id: str,
    content: str = "hello world",
    file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
    filename: str = "file.txt",
) -> InMemoryChatFile:
    """Wrap string content into an InMemoryChatFile as UTF-8 bytes."""
    encoded = content.encode("utf-8")
    return InMemoryChatFile(
        file_id=file_id,
        content=encoded,
        file_type=file_type,
        filename=filename,
    )
# ===========================================================================
# resolve_context_user_files
# ===========================================================================
class TestResolveContextUserFiles:
    """Precedence rule: custom persona fully supersedes project."""

    def test_custom_persona_with_files_returns_persona_files(self) -> None:
        # Custom persona (non-default id) with files → those files win.
        persona_files = [_make_user_file(), _make_user_file()]
        persona = _make_persona(persona_id=42, user_files=persona_files)
        db_session = MagicMock()
        result = resolve_context_user_files(
            persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
        )
        assert result == persona_files

    def test_custom_persona_without_files_returns_empty(self) -> None:
        """Custom persona with no files should NOT fall through to project."""
        persona = _make_persona(persona_id=42, user_files=[])
        db_session = MagicMock()
        result = resolve_context_user_files(
            persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
        )
        assert result == []

    def test_custom_persona_none_files_returns_empty(self) -> None:
        """Custom persona with user_files=None should NOT fall through."""
        persona = _make_persona(persona_id=42, user_files=None)
        db_session = MagicMock()
        result = resolve_context_user_files(
            persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
        )
        assert result == []

    @patch("onyx.chat.process_message.get_user_files_from_project")
    def test_default_persona_in_project_returns_project_files(
        self, mock_get_files: MagicMock
    ) -> None:
        # Default persona inside a project → project files are the context.
        project_files = [_make_user_file(), _make_user_file()]
        mock_get_files.return_value = project_files
        persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
        user_id = uuid4()
        db_session = MagicMock()
        result = resolve_context_user_files(
            persona=persona, project_id=99, user_id=user_id, db_session=db_session
        )
        assert result == project_files
        mock_get_files.assert_called_once_with(
            project_id=99, user_id=user_id, db_session=db_session
        )

    def test_default_persona_no_project_returns_empty(self) -> None:
        # Default persona and no project → no context files at all.
        persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
        db_session = MagicMock()
        result = resolve_context_user_files(
            persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
        )
        assert result == []

    @patch("onyx.chat.process_message.get_user_files_from_project")
    def test_custom_persona_without_files_ignores_project(
        self, mock_get_files: MagicMock
    ) -> None:
        """Even with a project_id, custom persona means project is invisible."""
        persona = _make_persona(persona_id=7, user_files=[])
        db_session = MagicMock()
        result = resolve_context_user_files(
            persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
        )
        assert result == []
        # Project lookup must not even be attempted.
        mock_get_files.assert_not_called()
# ===========================================================================
# extract_context_files
# ===========================================================================
class TestExtractContextFiles:
    """All-or-nothing context window fit check."""

    def test_empty_user_files_returns_empty(self) -> None:
        # No files → empty result, no search-filter fallback.
        db_session = MagicMock()
        result = extract_context_files(
            user_files=[],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=db_session,
        )
        assert result.file_texts == []
        assert result.image_files == []
        assert result.use_as_search_filter is False
        assert result.uncapped_token_count is None

    @patch("onyx.chat.process_message.load_in_memory_chat_files")
    def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
        # 100 tokens fits well under the window → content is loaded inline.
        file_id = str(uuid4())
        uf = _make_user_file(token_count=100, file_id=file_id)
        mock_load.return_value = [
            _make_in_memory_file(file_id=file_id, content="file content")
        ]
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        assert result.file_texts == ["file content"]
        assert result.use_as_search_filter is False
        assert result.total_token_count == 100
        assert len(result.file_metadata) == 1
        assert result.file_metadata[0].file_id == file_id

    def test_files_overflow_context_not_loaded(self) -> None:
        """When aggregate tokens exceed 60% of available window, nothing is loaded."""
        uf = _make_user_file(token_count=7000)
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        assert result.file_texts == []
        assert result.image_files == []
        assert result.use_as_search_filter is True
        assert result.uncapped_token_count == 7000
        assert result.total_token_count == 0

    def test_overflow_boundary_exact(self) -> None:
        """Token count exactly at the 60% boundary should trigger overflow."""
        # Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
        uf = _make_user_file(token_count=6000)
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        assert result.use_as_search_filter is True

    @patch("onyx.chat.process_message.load_in_memory_chat_files")
    def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
        """Token count just under the 60% boundary should load files."""
        file_id = str(uuid4())
        uf = _make_user_file(token_count=5999, file_id=file_id)
        mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        assert result.use_as_search_filter is False
        assert result.file_texts == ["data"]

    @patch("onyx.chat.process_message.load_in_memory_chat_files")
    def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
        """Multiple small files that individually fit but collectively overflow."""
        files = [_make_user_file(token_count=2500) for _ in range(3)]
        # 3 * 2500 = 7500 > 6000 threshold
        result = extract_context_files(
            user_files=files,
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        assert result.use_as_search_filter is True
        assert result.file_texts == []
        # All-or-nothing: on overflow, file loading is skipped entirely.
        mock_load.assert_not_called()

    @patch("onyx.chat.process_message.load_in_memory_chat_files")
    def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
        """Reserved tokens shrink the available window."""
        file_id = str(uuid4())
        uf = _make_user_file(token_count=3000, file_id=file_id)
        # Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=5000,
            db_session=MagicMock(),
        )
        assert result.use_as_search_filter is True
        mock_load.assert_not_called()

    @patch("onyx.chat.process_message.load_in_memory_chat_files")
    def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
        # Images are routed into image_files, not file_texts.
        file_id = str(uuid4())
        uf = _make_user_file(token_count=50, file_id=file_id)
        mock_load.return_value = [
            InMemoryChatFile(
                file_id=file_id,
                content=b"\x89PNG",
                file_type=ChatFileType.IMAGE,
                filename="photo.png",
            )
        ]
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        assert len(result.image_files) == 1
        assert result.image_files[0].file_id == file_id
        assert result.file_texts == []
        assert result.total_token_count == 50

    @patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
    def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
        """When vector DB is disabled, overflow produces FileToolMetadata."""
        uf = _make_user_file(token_count=7000, name="bigfile.txt")
        result = extract_context_files(
            user_files=[uf],
            llm_max_context_window=10000,
            reserved_token_count=0,
            db_session=MagicMock(),
        )
        # Without a vector DB, search-filter fallback is unavailable; the
        # file is exposed to tools via metadata instead.
        assert result.use_as_search_filter is False
        assert len(result.file_metadata_for_tool) == 1
        assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
# ===========================================================================
# Search filter + search_usage determination
# ===========================================================================
class TestSearchFilterDetermination:
"""Verify that determine_search_params correctly resolves
search_project_id, search_persona_id, and search_usage based on
the extraction result and the precedence rule.
"""
@staticmethod
def _make_context(
use_as_search_filter: bool = False,
file_texts: list[str] | None = None,
uncapped_token_count: int | None = None,
) -> ExtractedContextFiles:
return ExtractedContextFiles(
file_texts=file_texts or [],
image_files=[],
use_as_search_filter=use_as_search_filter,
total_token_count=0,
file_metadata=[],
uncapped_token_count=uncapped_token_count,
)
def test_custom_persona_files_fit_no_filter(self) -> None:
"""Custom persona, files fit → no search filter, AUTO."""
result = determine_search_params(
persona_id=42,
project_id=99,
extracted_context_files=self._make_context(
file_texts=["content"],
uncapped_token_count=100,
),
)
assert result.search_project_id is None
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_custom_persona_files_overflow_persona_filter(self) -> None:
"""Custom persona, files overflow → persona_id filter, AUTO."""
result = determine_search_params(
persona_id=42,
project_id=99,
extracted_context_files=self._make_context(use_as_search_filter=True),
)
assert result.search_persona_id == 42
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_custom_persona_no_files_no_project_leak(self) -> None:
"""Custom persona (no files) in project → nothing leaks from project."""
result = determine_search_params(
persona_id=42,
project_id=99,
extracted_context_files=self._make_context(),
)
assert result.search_project_id is None
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_default_persona_project_files_fit_disables_search(self) -> None:
"""Default persona, project files fit → DISABLED."""
result = determine_search_params(
persona_id=DEFAULT_PERSONA_ID,
project_id=99,
extracted_context_files=self._make_context(
file_texts=["content"],
uncapped_token_count=100,
),
)
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.DISABLED
def test_default_persona_project_files_overflow_enables_search(self) -> None:
"""Default persona, project files overflow → ENABLED + project_id filter."""
result = determine_search_params(
persona_id=DEFAULT_PERSONA_ID,
project_id=99,
extracted_context_files=self._make_context(
use_as_search_filter=True,
uncapped_token_count=7000,
),
)
assert result.search_project_id == 99
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.ENABLED
def test_default_persona_no_project_auto(self) -> None:
    """Default persona outside any project: search stays on AUTO with no
    project filter."""
    params = determine_search_params(
        persona_id=DEFAULT_PERSONA_ID,
        project_id=None,
        extracted_context_files=self._make_context(),
    )
    assert params.search_usage == SearchToolUsage.AUTO
    assert params.search_project_id is None
def test_default_persona_project_no_files_disables_search(self) -> None:
    """Default persona in a project that contains no files at all:
    search is DISABLED."""
    params = determine_search_params(
        persona_id=DEFAULT_PERSONA_ID,
        project_id=99,
        extracted_context_files=self._make_context(),
    )
    assert params.search_usage == SearchToolUsage.DISABLED

View File

@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
@@ -74,20 +74,20 @@ def create_tool_response(
)
def create_context_files(
def create_project_files(
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
) -> ExtractedContextFiles:
"""Helper to create ExtractedContextFiles for testing."""
file_texts = [f"Project file {i} content" for i in range(num_files)]
file_metadata = [
ContextFileMetadata(
) -> ExtractedProjectFiles:
"""Helper to create ExtractedProjectFiles for testing."""
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
project_file_metadata = [
ProjectFileMetadata(
file_id=f"file_{i}",
filename=f"file_{i}.txt",
file_content=f"Project file {i} content",
)
for i in range(num_files)
]
image_files = [
project_image_files = [
ChatLoadedFile(
file_id=f"image_{i}",
content=b"",
@@ -98,13 +98,13 @@ def create_context_files(
)
for i in range(num_images)
]
return ExtractedContextFiles(
file_texts=file_texts,
image_files=image_files,
use_as_search_filter=False,
return ExtractedProjectFiles(
project_file_texts=project_file_texts,
project_image_files=project_image_files,
project_as_filter=False,
total_token_count=num_files * tokens_per_file,
file_metadata=file_metadata,
uncapped_token_count=num_files * tokens_per_file,
project_file_metadata=project_file_metadata,
project_uncapped_token_count=num_files * tokens_per_file,
)
@@ -121,14 +121,14 @@ class TestConstructMessageHistory:
user_msg2 = create_message("How are you?", MessageType.USER, 5)
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
context_files = create_context_files()
project_files = create_project_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -148,14 +148,14 @@ class TestConstructMessageHistory:
custom_agent = create_message("Custom instructions", MessageType.USER, 10)
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
context_files = create_context_files()
project_files = create_project_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -167,25 +167,25 @@ class TestConstructMessageHistory:
assert result[3] == custom_agent # Before last user message
assert result[4] == user_msg2
def test_with_context_files(self) -> None:
def test_with_project_files(self) -> None:
"""Test that project files are inserted before the last user message."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg1 = create_message("First message", MessageType.USER, 5)
user_msg2 = create_message("Second message", MessageType.USER, 5)
simple_chat_history = [user_msg1, user_msg2]
context_files = create_context_files(num_files=2, tokens_per_file=50)
project_files = create_project_files(num_files=2, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
# Should have: system, user1, context_files_message, user2
# Should have: system, user1, project_files_message, user2
assert len(result) == 4
assert result[0] == system_prompt
assert result[1] == user_msg1
@@ -202,14 +202,14 @@ class TestConstructMessageHistory:
reminder = create_message("Remember to cite sources", MessageType.USER, 10)
simple_chat_history = [user_msg]
context_files = create_context_files()
project_files = create_project_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=reminder,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -235,14 +235,14 @@ class TestConstructMessageHistory:
assistant_with_tool,
tool_response,
]
context_files = create_context_files()
project_files = create_project_files()
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -264,18 +264,18 @@ class TestConstructMessageHistory:
custom_agent = create_message("Custom", MessageType.USER, 10)
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
context_files = create_context_files(num_files=1, tokens_per_file=50)
project_files = create_project_files(num_files=1, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
# Should have: system, user1, custom_agent, context_files, user2, assistant_with_tool
# Should have: system, user1, custom_agent, project_files, user2, assistant_with_tool
assert len(result) == 6
assert result[0] == system_prompt
assert result[1] == user_msg1
@@ -292,14 +292,14 @@ class TestConstructMessageHistory:
user_msg2 = create_message("Second", MessageType.USER, 5)
simple_chat_history = [user_msg1, user_msg2]
context_files = create_context_files(num_files=0, num_images=2)
project_files = create_project_files(num_files=0, num_images=2)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -332,14 +332,14 @@ class TestConstructMessageHistory:
)
simple_chat_history = [user_msg]
context_files = create_context_files(num_files=0, num_images=1)
project_files = create_project_files(num_files=0, num_images=1)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -366,7 +366,7 @@ class TestConstructMessageHistory:
assistant_msg2,
user_msg3,
]
context_files = create_context_files()
project_files = create_project_files()
# Budget only allows last 3 messages + system (10 + 20 + 20 + 20 = 70 tokens)
result = construct_message_history(
@@ -374,7 +374,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=80,
)
@@ -395,7 +395,7 @@ class TestConstructMessageHistory:
tool_response = create_tool_response("tc_1", "tool_response", 20)
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool, tool_response]
context_files = create_context_files()
project_files = create_project_files()
# Budget only allows last user message and messages after + system
# (10 + 20 + 20 + 20 = 70 tokens)
@@ -404,7 +404,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=80,
)
@@ -432,7 +432,7 @@ class TestConstructMessageHistory:
assistant_msg1,
user_msg2,
]
context_files = create_context_files()
project_files = create_project_files()
# Remaining history budget is 10 tokens (30 total - 10 system - 10 last user):
# keeps [tool_response, assistant_msg1] from history_before_last_user,
@@ -442,7 +442,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=30,
)
@@ -461,7 +461,7 @@ class TestConstructMessageHistory:
user_msg2 = create_message("Latest question", MessageType.USER, 10)
simple_chat_history = [user_msg1, assistant_with_tool, tool_response, user_msg2]
context_files = create_context_files()
project_files = create_project_files()
# Remaining history budget is 25 tokens (45 total - 10 system - 10 last user):
# keeps both assistant_with_tool and tool_response in history_before_last_user.
@@ -470,7 +470,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=45,
)
@@ -487,18 +487,18 @@ class TestConstructMessageHistory:
reminder = create_message("Reminder", MessageType.USER, 10)
simple_chat_history: list[ChatMessageSimple] = []
context_files = create_context_files(num_files=1, tokens_per_file=50)
project_files = create_project_files(num_files=1, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=reminder,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
# Should have: system, custom_agent, context_files, reminder
# Should have: system, custom_agent, project_files, reminder
assert len(result) == 4
assert result[0] == system_prompt
assert result[1] == custom_agent
@@ -512,7 +512,7 @@ class TestConstructMessageHistory:
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 5)
simple_chat_history = [assistant_msg, assistant_with_tool]
context_files = create_context_files()
project_files = create_project_files()
with pytest.raises(ValueError, match="No user message found"):
construct_message_history(
@@ -520,7 +520,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -531,7 +531,7 @@ class TestConstructMessageHistory:
custom_agent = create_message("Custom", MessageType.USER, 50)
simple_chat_history = [user_msg]
context_files = create_context_files(num_files=1, tokens_per_file=100)
project_files = create_project_files(num_files=1, tokens_per_file=100)
# Total required: 50 (system) + 50 (custom) + 100 (project) + 50 (user) = 250
# But only 200 available
@@ -541,7 +541,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=200,
)
@@ -553,7 +553,7 @@ class TestConstructMessageHistory:
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 30)
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
context_files = create_context_files()
project_files = create_project_files()
# Budget: 50 tokens
# Required: 10 (system) + 30 (user2) + 30 (assistant_with_tool) = 70 tokens
@@ -566,7 +566,7 @@ class TestConstructMessageHistory:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=50,
)
@@ -592,20 +592,20 @@ class TestConstructMessageHistory:
assistant_with_tool,
tool_response,
]
context_files = create_context_files(num_files=2, tokens_per_file=20)
project_files = create_project_files(num_files=2, tokens_per_file=20)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=custom_agent,
simple_chat_history=simple_chat_history,
reminder_message=reminder,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
# Expected order:
# system, user1, assistant1, user2, assistant2,
# custom_agent, context_files, user3, assistant_with_tool, tool_response, reminder
# custom_agent, project_files, user3, assistant_with_tool, tool_response, reminder
assert len(result) == 11
assert result[0] == system_prompt
assert result[1] == user_msg1
@@ -622,20 +622,20 @@ class TestConstructMessageHistory:
assert result[9] == tool_response # After last user
assert result[10] == reminder # At the very end
def test_context_files_json_format(self) -> None:
def test_project_files_json_format(self) -> None:
"""Test that project files are formatted correctly as JSON."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg = create_message("Hello", MessageType.USER, 5)
simple_chat_history = [user_msg]
context_files = create_context_files(num_files=2, tokens_per_file=50)
project_files = create_project_files(num_files=2, tokens_per_file=50)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=context_files,
project_files=project_files,
available_tokens=1000,
)
@@ -692,7 +692,7 @@ class TestForgottenFileMetadata:
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
context_files=create_context_files(),
project_files=create_project_files(),
available_tokens=available_tokens,
token_counter=_simple_token_counter,
all_injected_file_metadata=all_injected_file_metadata,

View File

@@ -106,9 +106,6 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
patch(
"onyx.server.metrics.postgres_connection_pool.CURRENT_ENDPOINT_CONTEXTVAR"
) as mock_ctx,
patch(
"onyx.server.metrics.postgres_connection_pool.CURRENT_TENANT_ID_CONTEXTVAR"
) as mock_tenant_ctx,
patch(
"onyx.server.metrics.postgres_connection_pool._connections_held"
) as mock_gauge,
@@ -117,14 +114,12 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
mock_labels = MagicMock()
mock_gauge.labels.return_value = mock_labels
mock_ctx.get.return_value = "/api/chat/send-message"
mock_tenant_ctx.get.return_value = "tenant_xyz"
listeners["checkout"](None, conn_record, None)
assert conn_record.info["_metrics_endpoint"] == "/api/chat/send-message"
assert conn_record.info["_metrics_tenant_id"] == "tenant_xyz"
assert "_metrics_checkout_time" in conn_record.info
mock_gauge.labels.assert_called_with(
handler="/api/chat/send-message", engine="sync", tenant_id="tenant_xyz"
handler="/api/chat/send-message", engine="sync"
)
mock_labels.inc.assert_called_once()
@@ -149,7 +144,6 @@ def test_checkin_event_observes_hold_duration() -> None:
conn_record = _make_conn_record()
conn_record.info["_metrics_endpoint"] = "/api/search"
conn_record.info["_metrics_tenant_id"] = "tenant_abc"
conn_record.info["_metrics_checkout_time"] = time.monotonic() - 0.5
with (
@@ -168,9 +162,7 @@ def test_checkin_event_observes_hold_duration() -> None:
listeners["checkin"](None, conn_record)
mock_gauge.labels.assert_called_with(
handler="/api/search", engine="sync", tenant_id="tenant_abc"
)
mock_gauge.labels.assert_called_with(handler="/api/search", engine="sync")
mock_labels.dec.assert_called_once()
mock_hist.labels.assert_called_with(handler="/api/search", engine="sync")
mock_hist_labels.observe.assert_called_once()
@@ -180,12 +172,11 @@ def test_checkin_event_observes_hold_duration() -> None:
# conn_record.info should be cleaned up
assert "_metrics_endpoint" not in conn_record.info
assert "_metrics_tenant_id" not in conn_record.info
assert "_metrics_checkout_time" not in conn_record.info
def test_checkin_with_missing_endpoint_uses_unknown() -> None:
"""Verify checkin gracefully handles missing endpoint and tenant info."""
"""Verify checkin gracefully handles missing endpoint info."""
engine = MagicMock()
engine.pool = MagicMock()
listeners: dict[str, Any] = {}
@@ -216,9 +207,7 @@ def test_checkin_with_missing_endpoint_uses_unknown() -> None:
listeners["checkin"](None, conn_record)
mock_gauge.labels.assert_called_with(
handler="unknown", engine="sync", tenant_id="unknown"
)
mock_gauge.labels.assert_called_with(handler="unknown", engine="sync")
# --- setup_postgres_connection_pool_metrics tests ---

View File

@@ -10,7 +10,6 @@ from fastapi.testclient import TestClient
from prometheus_client import CollectorRegistry
from prometheus_client import Gauge
from onyx.server.metrics.per_tenant import per_tenant_request_callback
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.metrics.slow_requests import slow_request_callback
@@ -82,7 +81,7 @@ def test_setup_attaches_instrumentator_to_app() -> None:
inprogress_labels=True,
excluded_handlers=["/health", "/metrics", "/openapi.json"],
)
assert mock_instance.add.call_count == 3
mock_instance.add.assert_called_once()
mock_instance.instrument.assert_called_once_with(
app,
latency_lowr_buckets=(
@@ -101,56 +100,6 @@ def test_setup_attaches_instrumentator_to_app() -> None:
mock_instance.expose.assert_called_once_with(app)
def test_per_tenant_callback_increments_with_tenant_id() -> None:
"""Verify per-tenant callback reads tenant from contextvar and increments."""
with (
patch(
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
) as mock_ctx,
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
):
mock_labels = MagicMock()
mock_counter.labels.return_value = mock_labels
mock_ctx.get.return_value = "tenant_abc"
info = _make_info(
duration=0.1, method="POST", handler="/api/chat", status="200"
)
per_tenant_request_callback(info)
mock_counter.labels.assert_called_once_with(
tenant_id="tenant_abc",
method="POST",
handler="/api/chat",
status="200",
)
mock_labels.inc.assert_called_once()
def test_per_tenant_callback_falls_back_to_unknown() -> None:
"""Verify per-tenant callback uses 'unknown' when contextvar is None."""
with (
patch(
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
) as mock_ctx,
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
):
mock_labels = MagicMock()
mock_counter.labels.return_value = mock_labels
mock_ctx.get.return_value = None
info = _make_info(duration=0.1)
per_tenant_request_callback(info)
mock_counter.labels.assert_called_once_with(
tenant_id="unknown",
method="GET",
handler="/api/test",
status="200",
)
mock_labels.inc.assert_called_once()
def test_inprogress_gauge_increments_during_request() -> None:
"""Verify the in-progress gauge goes up while a request is in flight."""
registry = CollectorRegistry()

View File

@@ -163,16 +163,3 @@ Add clear comments:
- Any TODOs you add in the code must be accompanied by either the name/username
of the owner of that TODO, or an issue number for an issue referencing that
piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side
effects. Essentially every piece of meaningful logic should exist within some
function that has to be explicitly invoked. Acceptable exceptions to this may
include loading environment variables or setting up loggers.
- If you find yourself needing something like this, you may want that logic to
exist in a file dedicated for manual execution (contains `if __name__ ==
"__main__":`) which should not be imported by anything else.
- Related to the above, do not conflate Python scripts you intend to run from
the command line (contains `if __name__ == "__main__":`) with modules you
intend to import from elsewhere. If for some unlikely reason they have to be
the same file, any logic specific to executing the file (including imports)
should be contained in the `if __name__ == "__main__":` block.
- Generally these executable files exist in `backend/scripts/`.

View File

@@ -534,10 +534,9 @@ services:
required: false
# Below is needed for the `docker-out-of-docker` execution mode
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
user: root
volumes:
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
- /var/run/docker.sock:/var/run/docker.sock
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
# privileged: true

View File

@@ -92,7 +92,7 @@ backend = [
"python-gitlab==5.6.0",
"python-pptx==0.6.23",
"pypandoc_binary==1.16.2",
"pypdf==6.7.3",
"pypdf==6.6.2",
"pytest-mock==3.12.0",
"pytest-playwright==0.7.0",
"python-docx==1.1.2",

8
uv.lock generated
View File

@@ -4677,7 +4677,7 @@ requires-dist = [
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.3" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.6.2" },
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
@@ -5924,11 +5924,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.7.3"
version = "6.6.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/53/9b/63e767042fc852384dc71e5ff6f990ee4e1b165b1526cf3f9c23a4eebb47/pypdf-6.7.3.tar.gz", hash = "sha256:eca55c78d0ec7baa06f9288e2be5c4e8242d5cbb62c7a4b94f2716f8e50076d2", size = 5303304, upload-time = "2026-02-24T17:23:11.42Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/90/3308a9b8b46c1424181fdf3f4580d2b423c5471425799e7fc62f92d183f4/pypdf-6.7.3-py3-none-any.whl", hash = "sha256:cd25ac508f20b554a9fafd825186e3ba29591a69b78c156783c5d8a2d63a1c0a", size = 331263, upload-time = "2026-02-24T17:23:09.932Z" },
{ url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" },
]
[[package]]

View File

@@ -1,233 +0,0 @@
import "@opal/core/hoverable/styles.css";
import React, { createContext, useContext, useState, useCallback } from "react";
import { cn } from "@opal/utils";
import type { WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
// Context-per-group registry
// ---------------------------------------------------------------------------
/**
 * Lazily-populated registry mapping group names to dedicated React
 * contexts.
 *
 * Giving each group its own `React.Context<boolean | null>` means a
 * `Hoverable.Item` re-renders only when *its* group's hover state
 * changes — never when an unrelated group changes.
 *
 * The context default is `null` ("no provider found"), which lets
 * `Hoverable.Item` tell a missing Root apart from a Root reporting
 * "not hovered", and throw when `group` was explicitly specified.
 */
const hoverContextRegistry = new Map<string, React.Context<boolean | null>>();

function getOrCreateContext(group: string): React.Context<boolean | null> {
  const existing = hoverContextRegistry.get(group);
  if (existing) {
    return existing;
  }
  const created = createContext<boolean | null>(null);
  created.displayName = `HoverableContext(${group})`;
  hoverContextRegistry.set(group, created);
  return created;
}
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface HoverableRootProps
  extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
  children: React.ReactNode;
  // Name of the hover group whose state this Root provides.
  group: string;
}

// The only variant currently supported; see styles.css for its rules.
type HoverableItemVariant = "opacity-on-hover";

interface HoverableItemProps
  extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
  children: React.ReactNode;
  // Group to subscribe to; when omitted the item uses local CSS :hover.
  group?: string;
  // Visual effect applied on hover (defaults to "opacity-on-hover").
  variant?: HoverableItemVariant;
}
// ---------------------------------------------------------------------------
// HoverableRoot
// ---------------------------------------------------------------------------
/**
 * Hover-tracking container for a named group.
 *
 * Renders a `<div>` that listens for mouse-enter / mouse-leave and
 * publishes the resulting hover flag through the group's dedicated
 * React context.
 *
 * Nesting works because every `Hoverable.Root` mounts a fresh provider
 * that shadows any outer one — an inner `Hoverable.Item group="b"`
 * always reads the nearest matching provider, not an outer `group="a"`
 * provider.
 *
 * @example
 * ```tsx
 * <Hoverable.Root group="card">
 *   <Card>
 *     <Hoverable.Item group="card" variant="opacity-on-hover">
 *       <TrashIcon />
 *     </Hoverable.Item>
 *   </Card>
 * </Hoverable.Root>
 * ```
 */
function HoverableRoot({
  group,
  children,
  onMouseEnter: consumerMouseEnter,
  onMouseLeave: consumerMouseLeave,
  ...props
}: HoverableRootProps) {
  const [isHovered, setIsHovered] = useState(false);

  // Update our hover flag first, then forward the event to any handler
  // the consumer supplied on the Root itself.
  const handleMouseEnter = useCallback(
    (e: React.MouseEvent<HTMLDivElement>) => {
      setIsHovered(true);
      consumerMouseEnter?.(e);
    },
    [consumerMouseEnter]
  );

  const handleMouseLeave = useCallback(
    (e: React.MouseEvent<HTMLDivElement>) => {
      setIsHovered(false);
      consumerMouseLeave?.(e);
    },
    [consumerMouseLeave]
  );

  const GroupContext = getOrCreateContext(group);

  return (
    <GroupContext.Provider value={isHovered}>
      <div
        {...props}
        onMouseEnter={handleMouseEnter}
        onMouseLeave={handleMouseLeave}
      >
        {children}
      </div>
    </GroupContext.Provider>
  );
}
// ---------------------------------------------------------------------------
// HoverableItem
// ---------------------------------------------------------------------------
/**
 * An element whose visibility is driven by hover state.
 *
 * **Local mode** (`group` omitted): the item reacts to CSS `:hover` on
 * its own element. This is the core abstraction.
 *
 * **Group mode** (`group` provided): visibility follows the hover state
 * of the nearest matching `Hoverable.Root` ancestor, delivered through
 * React context. A missing matching Root is a programming error and
 * throws.
 *
 * Variant styling is applied through data-attributes (see `styles.css`).
 *
 * @example
 * ```tsx
 * // Local mode — hover on the item itself
 * <Hoverable.Item variant="opacity-on-hover">
 *   <TrashIcon />
 * </Hoverable.Item>
 *
 * // Group mode — hover on the Root reveals the item
 * <Hoverable.Root group="card">
 *   <Hoverable.Item group="card" variant="opacity-on-hover">
 *     <TrashIcon />
 *   </Hoverable.Item>
 * </Hoverable.Root>
 * ```
 *
 * @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
 */
function HoverableItem({
  group,
  variant = "opacity-on-hover",
  children,
  ...props
}: HoverableItemProps) {
  // Exactly one useContext call per render; only the context object
  // varies between renders, which React allows.
  const hoverState = useContext(
    group ? getOrCreateContext(group) : NOOP_CONTEXT
  );

  // `null` means the group's provider was never mounted above us.
  if (group && hoverState === null) {
    throw new Error(
      `Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
        `Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
    );
  }

  const isLocal = group === undefined;
  const isActive = !isLocal && Boolean(hoverState);

  return (
    <div
      {...props}
      className={cn("hoverable-item")}
      data-hoverable-variant={variant}
      data-hoverable-active={isActive ? "true" : undefined}
      data-hoverable-local={isLocal ? "true" : undefined}
    >
      {children}
    </div>
  );
}
/**
 * Stable context used when no group is specified (local mode).
 *
 * Its value is never provided anywhere, so `useContext` always yields
 * the `null` default and the item falls back to pure CSS `:hover`.
 */
const NOOP_CONTEXT = createContext<boolean | null>(null);

// ---------------------------------------------------------------------------
// Compound export
// ---------------------------------------------------------------------------

/**
 * Hoverable compound component for hover-to-reveal patterns.
 *
 * Provides two sub-components:
 *
 * - `Hoverable.Root` — A container that tracks hover state for a named group
 *   and provides it via React context.
 *
 * - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
 *   applies local CSS `:hover` for the variant effect. When `group` is
 *   specified, it reads hover state from the nearest matching
 *   `Hoverable.Root` — and throws if no matching Root is found.
 *
 * Supports nesting: a child `Hoverable.Root` shadows the parent's context,
 * so each group's items only respond to their own root's hover.
 *
 * @example
 * ```tsx
 * import { Hoverable } from "@opal/core";
 *
 * // Group mode — hovering the card reveals the trash icon
 * <Hoverable.Root group="card">
 *   <Card>
 *     <span>Card content</span>
 *     <Hoverable.Item group="card" variant="opacity-on-hover">
 *       <TrashIcon />
 *     </Hoverable.Item>
 *   </Card>
 * </Hoverable.Root>
 *
 * // Local mode — hovering the item itself reveals it
 * <Hoverable.Item variant="opacity-on-hover">
 *   <TrashIcon />
 * </Hoverable.Item>
 * ```
 */
const Hoverable = {
  Root: HoverableRoot,
  Item: HoverableItem,
};

export {
  Hoverable,
  type HoverableRootProps,
  type HoverableItemProps,
  type HoverableItemVariant,
};

View File

@@ -1,18 +0,0 @@
/* Hoverable — item transitions */
.hoverable-item {
  transition: opacity 200ms ease-in-out;
}

/* Hidden by default for this variant; one of the rules below fades it in. */
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
  opacity: 0;
}

/* Group mode — Root controls visibility via React context
   (the component sets data-hoverable-active="true" while the Root is hovered) */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
  opacity: 1;
}

/* Local mode — item handles its own :hover */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
  opacity: 1;
}

View File

@@ -1,11 +1,3 @@
/* Hoverable */
export {
Hoverable,
type HoverableRootProps,
type HoverableItemProps,
type HoverableItemVariant,
} from "@opal/core/hoverable/components";
/* Interactive */
export {
Interactive,

View File

@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
// Drag-handle glyph: two thin vertical bars on a 3×17 native viewBox.
const SvgHandle = ({ size = 16, ...props }: IconProps) => {
  // Scale the width to preserve the viewBox's 3:17 aspect ratio at the
  // requested height.
  const width = Math.round((size * 3) / 17);
  return (
    <svg
      width={width}
      height={size}
      viewBox="0 0 3 17"
      fill="none"
      xmlns="http://www.w3.org/2000/svg"
      {...props}
    >
      <path
        d="M0.5 0.5V16.5M2.5 0.5V16.5"
        stroke="currentColor"
        strokeLinecap="round"
      />
    </svg>
  );
};

export default SvgHandle;

View File

@@ -77,6 +77,7 @@ export { default as SvgFolderPartialOpen } from "@opal/icons/folder-partial-open
export { default as SvgFolderPlus } from "@opal/icons/folder-plus";
export { default as SvgGemini } from "@opal/icons/gemini";
export { default as SvgGlobe } from "@opal/icons/globe";
export { default as SvgHandle } from "@opal/icons/handle";
export { default as SvgHardDrive } from "@opal/icons/hard-drive";
export { default as SvgHashSmall } from "@opal/icons/hash-small";
export { default as SvgHash } from "@opal/icons/hash";
@@ -143,6 +144,7 @@ export { default as SvgSlack } from "@opal/icons/slack";
export { default as SvgSlash } from "@opal/icons/slash";
export { default as SvgSliders } from "@opal/icons/sliders";
export { default as SvgSlidersSmall } from "@opal/icons/sliders-small";
export { default as SvgSort } from "@opal/icons/sort";
export { default as SvgSparkle } from "@opal/icons/sparkle";
export { default as SvgStar } from "@opal/icons/star";
export { default as SvgStep1 } from "@opal/icons/step1";

View File

@@ -0,0 +1,27 @@
import type { IconProps } from "@opal/types";
const SvgSort = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 4.5H10M2 8H7M2 11.5H5"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M12 5V12M12 12L14 10M12 12L10 10"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgSort;

View File

@@ -10,7 +10,7 @@ export default function Main() {
<SettingsLayouts.Header
icon={SvgMcp}
title="MCP Actions"
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your agents."
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your assistants."
separator
/>
<SettingsLayouts.Body>

View File

@@ -10,7 +10,7 @@ export default function Main() {
<SettingsLayouts.Header
icon={SvgActions}
title="OpenAPI Actions"
description="Connect OpenAPI servers to add custom actions and tools for your agents."
description="Connect OpenAPI servers to add custom actions and tools for your assistants."
separator
/>
<SettingsLayouts.Body>

View File

@@ -170,7 +170,7 @@ export function PersonasTable({
{deleteModalOpen && personaToDelete && (
<ConfirmationModalLayout
icon={SvgAlertCircle}
title="Delete Agent"
title="Delete Assistant"
onClose={closeDeleteModal}
submit={<Button onClick={handleDeletePersona}>Delete</Button>}
>
@@ -183,15 +183,15 @@ export function PersonasTable({
const isDefault = personaToToggleDefault.is_default_persona;
const title = isDefault
? "Remove Featured Agent"
: "Set Featured Agent";
? "Remove Featured Assistant"
: "Set Featured Assistant";
const buttonText = isDefault ? "Remove Feature" : "Set as Featured";
const text = isDefault
? `Are you sure you want to remove the featured status of ${personaToToggleDefault.name}?`
: `Are you sure you want to set the featured status of ${personaToToggleDefault.name}?`;
const additionalText = isDefault
? `Removing "${personaToToggleDefault.name}" as a featured agent will not affect its visibility or accessibility.`
: `Setting "${personaToToggleDefault.name}" as a featured agent will make it public and visible to all users. This action cannot be undone.`;
? `Removing "${personaToToggleDefault.name}" as a featured assistant will not affect its visibility or accessibility.`
: `Setting "${personaToToggleDefault.name}" as a featured assistant will make it public and visible to all users. This action cannot be undone.`;
return (
<ConfirmationModalLayout
@@ -217,7 +217,7 @@ export function PersonasTable({
"Name",
"Description",
"Type",
"Featured Agent",
"Featured Assistant",
"Is Visible",
"Delete",
]}

View File

@@ -47,8 +47,8 @@ function MainContent({
return (
<div>
<Text className="mb-2">
Agents are a way to build custom search/question-answering experiences
for different use cases.
Assistants are a way to build custom search/question-answering
experiences for different use cases.
</Text>
<Text className="mt-2">They allow you to customize:</Text>
<div className="text-sm">
@@ -63,21 +63,21 @@ function MainContent({
<div>
<Separator />
<Title>Create an Agent</Title>
<Title>Create an Assistant</Title>
<CreateButton href="/app/agents/create?admin=true">
New Agent
New Assistant
</CreateButton>
<Separator />
<Title>Existing Agents</Title>
<Title>Existing Assistants</Title>
{totalItems > 0 ? (
<>
<SubLabel>
Agents will be displayed as options on the Chat / Search
interfaces in the order they are displayed below. Agents marked as
hidden will not be displayed. Editable agents are shown at the
top.
Assistants will be displayed as options on the Chat / Search
interfaces in the order they are displayed below. Assistants
marked as hidden will not be displayed. Editable assistants are
shown at the top.
</SubLabel>
<PersonasTable
personas={customPersonas}
@@ -96,21 +96,21 @@ function MainContent({
) : (
<div className="mt-6 p-8 border border-border rounded-lg bg-background-weak text-center">
<Text className="text-lg font-medium mb-2">
No custom agents yet
No custom assistants yet
</Text>
<Text className="text-subtle mb-3">
Create your first agent to:
Create your first assistant to:
</Text>
<ul className="text-subtle text-sm list-disc text-left inline-block mb-3">
<li>Build department-specific knowledge bases</li>
<li>Create specialized research agents</li>
<li>Create specialized research assistants</li>
<li>Set up compliance and policy advisors</li>
</ul>
<Text className="text-subtle text-sm mb-4">
...and so much more!
</Text>
<CreateButton href="/app/agents/create?admin=true">
Create Your First Agent
Create Your First Assistant
</CreateButton>
</div>
)}
@@ -128,13 +128,13 @@ export default function Page() {
return (
<>
<AdminPageTitle icon={SvgOnyxOctagon} title="Agents" />
<AdminPageTitle icon={SvgOnyxOctagon} title="Assistants" />
{isLoading && <ThreeDotsLoader />}
{error && (
<ErrorCallout
errorTitle="Failed to load agents"
errorTitle="Failed to load assistants"
errorMsg={
error?.info?.message ||
error?.info?.detail ||

View File

@@ -156,7 +156,7 @@ export const SlackChannelConfigCreationForm = ({
is: "assistant",
then: (schema) =>
schema.required(
"An agent is required when using the 'Agent' knowledge source"
"A persona is required when using the'Assistant' knowledge source"
),
}),
standard_answer_categories: Yup.array(),

View File

@@ -224,14 +224,14 @@ export function SlackChannelConfigFormFields({
<RadioGroupItemField
value="assistant"
id="assistant"
label="Search Agent"
label="Search Assistant"
sublabel="Control both the documents and the prompt to use for answering questions"
/>
<RadioGroupItemField
value="non_search_assistant"
id="non_search_assistant"
label="Non-Search Agent"
sublabel="Chat with an agent that does not use documents"
label="Non-Search Assistant"
sublabel="Chat with an assistant that does not use documents"
/>
</RadioGroup>
</div>
@@ -327,15 +327,15 @@ export function SlackChannelConfigFormFields({
<div className="mt-4">
<SubLabel>
<>
Select the search-enabled agent OnyxBot will use while answering
questions in Slack.
Select the search-enabled assistant OnyxBot will use while
answering questions in Slack.
{syncEnabledAssistants.length > 0 && (
<>
<br />
<span className="text-sm text-text-dark/80">
Note: Some of your agents have auto-synced connectors in
their document sets. You cannot select these agents as
they will not be able to answer questions in Slack.{" "}
Note: Some of your assistants have auto-synced connectors
in their document sets. You cannot select these assistants
as they will not be able to answer questions in Slack.{" "}
<button
type="button"
onClick={() =>
@@ -349,7 +349,7 @@ export function SlackChannelConfigFormFields({
{viewSyncEnabledAssistants
? "Hide un-selectable "
: "View all "}
agents
assistants
</button>
</span>
</>
@@ -367,7 +367,7 @@ export function SlackChannelConfigFormFields({
{viewSyncEnabledAssistants && syncEnabledAssistants.length > 0 && (
<div className="mt-4">
<p className="text-sm text-text-dark/80">
Un-selectable agents:
Un-selectable assistants:
</p>
<div className="mb-3 mt-2 flex gap-2 flex-wrap text-sm">
{syncEnabledAssistants.map(
@@ -394,15 +394,15 @@ export function SlackChannelConfigFormFields({
<div className="mt-4">
<SubLabel>
<>
Select the non-search agent OnyxBot will use while answering
Select the non-search assistant OnyxBot will use while answering
questions in Slack.
{syncEnabledAssistants.length > 0 && (
<>
<br />
<span className="text-sm text-text-dark/80">
Note: Some of your agents have auto-synced connectors in
their document sets. You cannot select these agents as
they will not be able to answer questions in Slack.{" "}
Note: Some of your assistants have auto-synced connectors
in their document sets. You cannot select these assistants
as they will not be able to answer questions in Slack.{" "}
<button
type="button"
onClick={() =>
@@ -416,7 +416,7 @@ export function SlackChannelConfigFormFields({
{viewSyncEnabledAssistants
? "Hide un-selectable "
: "View all "}
agents
assistants
</button>
</span>
</>
@@ -524,7 +524,7 @@ export function SlackChannelConfigFormFields({
name="is_ephemeral"
label="Respond to user in a private (ephemeral) message"
tooltip="If set, OnyxBot will respond only to the user in a private (ephemeral) message. If you also
chose 'Search' Agent above, selecting this option will make documents that are private to the user
chose 'Search' Assistant above, selecting this option will make documents that are private to the user
available for their queries."
/>

View File

@@ -1,7 +0,0 @@
"use client";
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
export default function Page() {
return <CodeInterpreterPage />;
}

View File

@@ -39,10 +39,10 @@ export function AdvancedOptions({
agents={agents}
isLoading={agentsLoading}
error={agentsError}
label="Agent Whitelist"
subtext="Restrict this provider to specific agents."
label="Assistant Whitelist"
subtext="Restrict this provider to specific assistants."
disabled={formikProps.values.is_public}
disabledMessage="This LLM Provider is public and available to all agents."
disabledMessage="This LLM Provider is public and available to all assistants."
/>
</div>
</>

View File

@@ -299,11 +299,11 @@ export default function Page({ params }: Props) {
});
refreshGuild();
toast.success(
personaId ? "Default agent updated" : "Default agent cleared"
personaId ? "Default assistant updated" : "Default assistant cleared"
);
} catch (err) {
toast.error(
err instanceof Error ? err.message : "Failed to update agent"
err instanceof Error ? err.message : "Failed to update assistant"
);
} finally {
setIsUpdating(false);
@@ -355,7 +355,7 @@ export default function Page({ params }: Props) {
<InputSelect.Trigger placeholder="Select agent" />
<InputSelect.Content>
<InputSelect.Item value="default">
Default Agent
Default Assistant
</InputSelect.Item>
{personas.map((persona) => (
<InputSelect.Item

View File

@@ -47,8 +47,6 @@ export interface RendererResult {
// Whether this renderer supports collapsible mode (collapse button shown only when true)
supportsCollapsible?: boolean;
/** Whether the step should remain collapsible even in single-step timelines */
alwaysCollapsible?: boolean;
/** Whether the result should be wrapped by timeline UI or rendered as-is */
timelineLayout?: TimelineLayout;
}

View File

@@ -50,9 +50,7 @@ export function TimelineStepComposer({
header={result.status}
isExpanded={result.isExpanded}
onToggle={result.onToggle}
collapsible={
collapsible && (!isSingleStep || !!result.alwaysCollapsible)
}
collapsible={collapsible && !isSingleStep}
supportsCollapsible={result.supportsCollapsible}
isLastStep={index === results.length - 1 && isLastStep}
isFirstStep={index === 0 && isFirstStep}

View File

@@ -54,7 +54,7 @@ export function TimelineRow({
isHover={isHover}
/>
)}
<div className="flex-1 min-w-0">{children}</div>
<div className="flex-1">{children}</div>
</div>
);
}

View File

@@ -138,7 +138,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
{stdout && (
<div className="rounded-md bg-background-neutral-02 p-3">
<div className="text-xs font-semibold mb-1 text-text-03">Output:</div>
<pre className="text-sm whitespace-pre-wrap font-mono text-text-01 overflow-x-auto">
<pre className="text-sm whitespace-pre-wrap font-mono text-text-01">
{stdout}
</pre>
</div>
@@ -150,7 +150,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
<div className="text-xs font-semibold mb-1 text-status-error-05">
Error:
</div>
<pre className="text-sm whitespace-pre-wrap font-mono text-status-error-05 overflow-x-auto">
<pre className="text-sm whitespace-pre-wrap font-mono text-status-error-05">
{stderr}
</pre>
</div>
@@ -181,7 +181,6 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
status,
content,
supportsCollapsible: true,
alwaysCollapsible: true,
},
]);
}
@@ -192,7 +191,6 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
icon: SvgTerminal,
status,
supportsCollapsible: true,
alwaysCollapsible: true,
content: (
<FadingEdgeContainer
direction="bottom"

View File

@@ -427,7 +427,7 @@ export const GroupDisplay = ({
<Separator />
<h2 className="text-xl font-bold mt-8 mb-2">Agents</h2>
<h2 className="text-xl font-bold mt-8 mb-2">Assistants</h2>
<div>
{userGroup.document_sets.length > 0 ? (
@@ -445,7 +445,7 @@ export const GroupDisplay = ({
</div>
) : (
<>
<Text>No Agents in this group...</Text>
<Text>No Assistants in this group...</Text>
</>
)}
</div>

View File

@@ -152,14 +152,14 @@ export function PersonaMessagesChart({
} else if (selectedPersonaId === undefined) {
content = (
<div className="h-80 text-text-500 flex flex-col">
<p className="m-auto">Select an agent to view analytics</p>
<p className="m-auto">Select an assistant to view analytics</p>
</div>
);
} else if (!personaMessagesData?.length) {
content = (
<div className="h-80 text-text-500 flex flex-col">
<p className="m-auto">
No data found for selected agent in the specified time range
No data found for selected assistant in the specified time range
</p>
</div>
);
@@ -178,9 +178,11 @@ export function PersonaMessagesChart({
return (
<CardSection className="mt-8">
<Title>Agent Analytics</Title>
<Title>Assistant Analytics</Title>
<div className="flex flex-col gap-4">
<Text>Messages and unique users per day for the selected agent</Text>
<Text>
Messages and unique users per day for the selected assistant
</Text>
<div className="flex items-center gap-4">
<Select
value={selectedPersonaId?.toString() ?? ""}
@@ -189,14 +191,14 @@ export function PersonaMessagesChart({
}}
>
<SelectTrigger className="flex w-full max-w-xs">
<SelectValue placeholder="Select an agent to display" />
<SelectValue placeholder="Select an assistant to display" />
</SelectTrigger>
<SelectContent>
<div className="flex items-center px-2 pb-2 sticky top-0 bg-background border-b">
<Search className="h-4 w-4 mr-2 shrink-0 opacity-50" />
<input
className="flex h-8 w-full rounded-sm bg-transparent py-3 text-sm outline-none placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50"
placeholder="Search agents..."
placeholder="Search assistants..."
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
onClick={(e) => e.stopPropagation()}

View File

@@ -146,7 +146,7 @@ export function AssistantStats({ assistantId }: { assistantId: number }) {
return (
<Card className="w-full">
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
<p className="text-base font-normal text-2xl">Agent Analytics</p>
<p className="text-base font-normal text-2xl">Assistant Analytics</p>
<AdminDateRangeSelector
value={dateRange}
onValueChange={setDateRange}

View File

@@ -72,7 +72,6 @@ export function ClientLayout({
enableEnterpriseSS={enableEnterprise}
/>
<div
data-main-container
className={cn(
"flex flex-1 flex-col min-w-0 min-h-0 overflow-y-auto",
!hasOwnLayout && "py-10 px-4 md:px-12"

View File

@@ -41,14 +41,8 @@ export default function AccessRestricted() {
const [error, setError] = useState<string | null>(null);
const { data: license } = useLicense();
const hadPreviousLicense = license?.has_license === true;
const showRenewalMessage = NEXT_PUBLIC_CLOUD_ENABLED || hadPreviousLicense;
const initialModalMessage = showRenewalMessage
? NEXT_PUBLIC_CLOUD_ENABLED
? "Your access to Onyx has been temporarily suspended due to a lapse in your subscription."
: "Your access to Onyx has been temporarily suspended due to a lapse in your license."
: "An Enterprise license is required to use Onyx. Your data is protected and will be available once a license is activated.";
// Distinguish between "never had a license" vs "license lapsed"
const hasLicenseLapsed = license?.has_license === true;
const handleResubscribe = async () => {
setIsLoading(true);
@@ -78,7 +72,11 @@ export default function AccessRestricted() {
<SvgLock className="stroke-status-error-05 w-[1.5rem] h-[1.5rem]" />
</div>
<Text text03>{initialModalMessage}</Text>
<Text text03>
{hasLicenseLapsed
? "Your access to Onyx has been temporarily suspended due to a lapse in your subscription."
: "An Enterprise license is required to use Onyx. Your data is protected and will be available once a license is activated."}
</Text>
{NEXT_PUBLIC_CLOUD_ENABLED ? (
<>
@@ -113,7 +111,7 @@ export default function AccessRestricted() {
) : (
<>
<Text text03>
{hadPreviousLicense
{hasLicenseLapsed
? "To reinstate your access and continue using Onyx, please contact your system administrator to renew your license."
: "To get started, please contact your system administrator to obtain an Enterprise license."}
</Text>
@@ -123,8 +121,8 @@ export default function AccessRestricted() {
<Link className={linkClassName} href="/admin/billing">
Admin Billing
</Link>{" "}
page to {hadPreviousLicense ? "renew" : "activate"} your license,
sign up through Stripe or reach out to{" "}
page to {hasLicenseLapsed ? "renew" : "activate"} your license, sign
up through Stripe or reach out to{" "}
<a className={linkClassName} href="mailto:support@onyx.app">
support@onyx.app
</a>

View File

@@ -12,17 +12,17 @@ export default function NoAssistantModal() {
return (
<Modal open>
<Modal.Content width="sm" height="sm">
<Modal.Header icon={SvgUser} title="No Agent Available" />
<Modal.Header icon={SvgUser} title="No Assistant Available" />
<Modal.Body>
<Text as="p">
You currently have no agent configured. To use this feature, you
You currently have no assistant configured. To use this feature, you
need to take action.
</Text>
{isAdmin ? (
<>
<Text as="p">
As an administrator, you can create a new agent by visiting the
admin panel.
As an administrator, you can create a new assistant by visiting
the admin panel.
</Text>
<Button className="w-full" href="/admin/assistants">
Go to Admin Panel
@@ -30,7 +30,8 @@ export default function NoAssistantModal() {
</>
) : (
<Text as="p">
Please contact your administrator to configure an agent for you.
Please contact your administrator to configure an assistant for
you.
</Text>
)}
</Modal.Body>

View File

@@ -1,44 +0,0 @@
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
const HEALTH_ENDPOINT = "/api/admin/code-interpreter/health";
const STATUS_ENDPOINT = "/api/admin/code-interpreter";
interface CodeInterpreterHealth {
healthy: boolean;
}
interface CodeInterpreterStatus {
enabled: boolean;
}
export default function useCodeInterpreter() {
const {
data: healthData,
error: healthError,
isLoading: isHealthLoading,
mutate: refetchHealth,
} = useSWR<CodeInterpreterHealth>(HEALTH_ENDPOINT, errorHandlingFetcher, {
refreshInterval: 30000,
});
const {
data: statusData,
error: statusError,
isLoading: isStatusLoading,
mutate: refetchStatus,
} = useSWR<CodeInterpreterStatus>(STATUS_ENDPOINT, errorHandlingFetcher);
function refetch() {
refetchHealth();
refetchStatus();
}
return {
isHealthy: healthData?.healthy ?? false,
isEnabled: statusData?.enabled ?? false,
isLoading: isHealthLoading || isStatusLoading,
error: healthError || statusError,
refetch,
};
}

View File

@@ -1,15 +0,0 @@
const UPDATE_ENDPOINT = "/api/admin/code-interpreter";
interface CodeInterpreterUpdateRequest {
enabled: boolean;
}
export async function updateCodeInterpreter(
request: CodeInterpreterUpdateRequest
): Promise<Response> {
return fetch(UPDATE_ENDPOINT, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(request),
});
}

View File

@@ -131,8 +131,7 @@ export async function updateAgentSharedStatus(
userIds: string[],
groupIds: number[],
isPublic: boolean | undefined,
isPaidEnterpriseFeaturesEnabled: boolean,
labelIds?: number[]
isPaidEnterpriseFeaturesEnabled: boolean
): Promise<null | string> {
// MIT versions should not send group_ids - warn if caller provided non-empty groups
if (!isPaidEnterpriseFeaturesEnabled && groupIds.length > 0) {
@@ -153,7 +152,6 @@ export async function updateAgentSharedStatus(
// Only include group_ids for enterprise versions
group_ids: isPaidEnterpriseFeaturesEnabled ? groupIds : undefined,
is_public: isPublic,
label_ids: labelIds,
}),
});
@@ -168,63 +166,3 @@ export async function updateAgentSharedStatus(
return "Network error. Please check your connection and try again.";
}
}
/**
* Updates the labels assigned to an agent via the share endpoint.
*
* @param agentId - The ID of the agent to update
* @param labelIds - Array of label IDs to assign to the agent
* @returns null on success, or an error message string on failure
*/
export async function updateAgentLabels(
agentId: number,
labelIds: number[]
): Promise<string | null> {
try {
const response = await fetch(`/api/persona/${agentId}/share`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ label_ids: labelIds }),
});
if (response.ok) {
return null;
}
const errorMessage = (await response.json()).detail || "Unknown error";
return errorMessage;
} catch (error) {
console.error("updateAgentLabels: Network error", error);
return "Network error. Please check your connection and try again.";
}
}
/**
* Updates the featured (default) status of an agent.
*
* @param agentId - The ID of the agent to update
* @param isFeatured - Whether the agent should be featured
* @returns null on success, or an error message string on failure
*/
export async function updateAgentFeaturedStatus(
agentId: number,
isFeatured: boolean
): Promise<string | null> {
try {
const response = await fetch(`/api/admin/persona/${agentId}/default`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ is_default_persona: isFeatured }),
});
if (response.ok) {
return null;
}
const errorMessage = (await response.json()).detail || "Unknown error";
return errorMessage;
} catch (error) {
console.error("updateAgentFeaturedStatus: Network error", error);
return "Network error. Please check your connection and try again.";
}
}

View File

@@ -257,27 +257,19 @@ export const useLabels = () => {
return mutate("/api/persona/labels");
};
const createLabel = async (name: string): Promise<PersonaLabel | null> => {
const createLabel = async (name: string) => {
const response = await fetch("/api/persona/labels", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ name }),
});
if (!response.ok) {
return null;
if (response.ok) {
const newLabel = await response.json();
mutate("/api/persona/labels", [...(labels || []), newLabel], false);
}
const newLabel: PersonaLabel = await response.json();
mutate(
"/api/persona/labels",
(currentLabels: PersonaLabel[] | undefined) => [
...(currentLabels || []),
newLabel,
],
false
);
return newLabel;
return response;
};
const updateLabel = async (id: number, name: string) => {

View File

@@ -1,51 +1,29 @@
import Text from "@/refresh-components/texts/Text";
import { SvgX } from "@opal/icons";
import { Button } from "@opal/components";
import type { IconProps } from "@opal/types";
export interface ChipProps {
children?: string;
icon?: React.FunctionComponent<IconProps>;
onRemove?: () => void;
smallLabel?: boolean;
}
/**
* A simple chip/tag component for displaying metadata.
* Supports an optional remove button via the `onRemove` prop.
*
* @example
* ```tsx
* <Chip>Tag Name</Chip>
* <Chip icon={SvgUser}>John Doe</Chip>
* <Chip onRemove={() => removeTag(id)}>Removable</Chip>
* ```
*/
export default function Chip({
children,
icon: Icon,
onRemove,
smallLabel = true,
}: ChipProps) {
export default function Chip({ children, icon: Icon }: ChipProps) {
return (
<div className="flex items-center gap-1 px-1.5 py-0.5 rounded-08 bg-background-tint-02">
{Icon && <Icon size={12} className="text-text-03" />}
{children && (
<Text figureSmallLabel={smallLabel} text03>
<Text figureSmallLabel text03>
{children}
</Text>
)}
{onRemove && (
<Button
onClick={(e) => {
e.stopPropagation();
onRemove();
}}
prominence="tertiary"
icon={SvgX}
size="xs"
/>
)}
</div>
);
}

View File

@@ -39,9 +39,9 @@ const useTabsContext = () => {
*
* Contained (default):
* ┌─────────────────────────────────────────────────┐
* │ ┌──────────┐ ╔══════════╗ ┌──────────┐
* │ │ Tab 1 │ ║ Tab 2 ║ │ Tab 3 │ │ ← gray background
* │ └──────────┘ ╚══════════╝ └──────────┘
* │ ┌──────────┐ ╔══════════╗ ┌──────────┐ │
* │ │ Tab 1 │ ║ Tab 2 ║ │ Tab 3 │ │ ← gray background
* │ └──────────┘ ╚══════════╝ └──────────┘ │
* └─────────────────────────────────────────────────┘
* ↑ active tab (white bg, shadow)
*
@@ -49,7 +49,7 @@ const useTabsContext = () => {
* Tab 1 Tab 2 Tab 3 [Action]
* ╔═════╗
* ║ ║ ↑ optional rightContent
* ────────────╨═════╨─────────────────────────────
* ────────────╨═════╨─────────────────────────────
* ↑ sliding indicator under active tab
*
* @example

View File

@@ -1,125 +0,0 @@
"use client";
import * as React from "react";
import { cn } from "@/lib/utils";
import Chip from "@/refresh-components/Chip";
import {
innerClasses,
textClasses,
Variants,
wrapperClasses,
} from "@/refresh-components/inputs/styles";
import type { IconProps } from "@opal/types";
export interface ChipItem {
id: string;
label: string;
}
export interface InputChipFieldProps {
chips: ChipItem[];
onRemoveChip: (id: string) => void;
onAdd: (value: string) => void;
value: string;
onChange: (value: string) => void;
placeholder?: string;
disabled?: boolean;
variant?: Variants;
icon?: React.FunctionComponent<IconProps>;
className?: string;
}
/**
* A tag/chip input field that renders chips inline alongside a text input.
*
* Pressing Enter adds a chip via `onAdd`. Pressing Backspace on an empty
* input removes the last chip. Each chip has a remove button.
*
* @example
* ```tsx
* <InputChipField
* chips={[{ id: "1", label: "Search" }]}
* onRemoveChip={(id) => remove(id)}
* onAdd={(value) => add(value)}
* value={inputValue}
* onChange={setInputValue}
* placeholder="Add labels..."
* icon={SvgTag}
* />
* ```
*/
function InputChipField({
chips,
onRemoveChip,
onAdd,
value,
onChange,
placeholder,
disabled = false,
variant = "primary",
icon: Icon,
className,
}: InputChipFieldProps) {
const inputRef = React.useRef<HTMLInputElement>(null);
function handleKeyDown(e: React.KeyboardEvent<HTMLInputElement>) {
if (disabled) {
return;
}
if (e.key === "Enter") {
e.preventDefault();
e.stopPropagation();
const trimmed = value.trim();
if (trimmed) {
onAdd(trimmed);
}
}
if (e.key === "Backspace" && value === "") {
const lastChip = chips[chips.length - 1];
if (lastChip) {
onRemoveChip(lastChip.id);
}
}
}
return (
<div
className={cn(
"flex flex-row items-center flex-wrap gap-1 p-1.5 rounded-08 cursor-text w-full",
wrapperClasses[variant],
className
)}
onClick={() => inputRef.current?.focus()}
>
{Icon && <Icon size={16} className="text-text-04 shrink-0" />}
{chips.map((chip) => (
<Chip
key={chip.id}
onRemove={disabled ? undefined : () => onRemoveChip(chip.id)}
smallLabel={false}
>
{chip.label}
</Chip>
))}
<input
ref={inputRef}
type="text"
disabled={disabled}
value={value}
onChange={(e) => onChange(e.target.value)}
onKeyDown={handleKeyDown}
placeholder={chips.length === 0 ? placeholder : undefined}
className={cn(
"flex-1 min-w-[80px] h-[1.5rem] bg-transparent p-0.5 focus:outline-none",
innerClasses[variant],
textClasses[variant]
)}
/>
</div>
);
}
export default InputChipField;

View File

@@ -12,8 +12,6 @@ import {
SvgX,
SvgXOctagon,
} from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
const containerClasses = {
flash: {
default: {
@@ -220,7 +218,6 @@ export interface MessageProps extends React.HTMLAttributes<HTMLDivElement> {
// Features:
icon?: boolean;
iconComponent?: IconFunctionComponent;
actions?: boolean | string;
close?: boolean;
@@ -247,7 +244,6 @@ function MessageInner(
description,
icon = true,
iconComponent,
actions,
close = true,
@@ -283,9 +279,8 @@ function MessageInner(
const textClass = useMemo(() => textClasses[type].text, [type]);
const descriptionClass = useMemo(() => textClasses[type].description, [type]);
const IconComponent = iconComponent
? iconComponent
: level === "success"
const IconComponent =
level === "success"
? SvgCheckCircle
: level === "warning"
? SvgAlertTriangle

View File

@@ -179,7 +179,7 @@ export default function ActionLineItem({
)}
{isSearchToolAndNotInProject && (
<Button
<IconButton
icon={
isSearchToolWithNoConnectors ? SvgSettings : SvgChevronRight
}
@@ -188,8 +188,11 @@ export default function ActionLineItem({
router.push("/admin/add-connector");
else onSourceManagementOpen?.();
})}
prominence="tertiary"
size="sm"
internal
className={cn(
isSearchToolWithNoConnectors &&
"invisible group-hover/LineItem:visible"
)}
tooltip={
isSearchToolWithNoConnectors
? "Add Connectors"

View File

@@ -0,0 +1,361 @@
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgChevronLeft, SvgChevronRight } from "@opal/icons";
type PaginationSize = "lg" | "md" | "sm";
/**
* Minimal page navigation showing `currentPage / totalPages` with prev/next arrows.
* Use when you only need simple forward/backward navigation.
*/
interface SimplePaginationProps {
type: "simple";
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** When `true`, displays the word "pages" after the page indicator. */
showUnits?: boolean;
/** When `false`, hides the page indicator between the prev/next arrows. Defaults to `true`. */
showPageIndicator?: boolean;
/** Controls button and text sizing. Defaults to `"lg"`. */
size?: PaginationSize;
className?: string;
}
/**
* Item-count pagination showing `currentItems of totalItems` with optional page
* controls and a "Go to" button. Use inside table footers that need to communicate
* how many items the user is viewing.
*/
interface CountPaginationProps {
type: "count";
/** Number of items displayed per page. Used to compute the visible range. */
pageSize: number;
/** Total number of items across all pages. */
totalItems: number;
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** When `false`, hides the page number between the prev/next arrows (arrows still visible). Defaults to `true`. */
showPageIndicator?: boolean;
/** When `true`, renders a "Go to" button. Requires `onGoTo`. */
showGoTo?: boolean;
/** Callback invoked when the "Go to" button is clicked. */
onGoTo?: () => void;
/** When `true`, displays the word "items" after the total count. */
showUnits?: boolean;
/** Controls button and text sizing. Defaults to `"lg"`. */
size?: PaginationSize;
className?: string;
}
/**
* Numbered page-list pagination with clickable page buttons and ellipsis
* truncation for large page counts. Does not support `"sm"` size.
*/
interface ListPaginationProps {
type: "list";
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** When `false`, hides the page buttons between the prev/next arrows. Defaults to `true`. */
showPageIndicator?: boolean;
/** Controls button and text sizing. Defaults to `"lg"`. Only `"lg"` and `"md"` are supported. */
size?: Exclude<PaginationSize, "sm">;
className?: string;
}
/**
* Discriminated union of all pagination variants.
* Use the `type` prop to select between `"simple"`, `"count"`, and `"list"`.
*/
export type PaginationProps =
| SimplePaginationProps
| CountPaginationProps
| ListPaginationProps;
/**
 * Builds the entries to render for list-style pagination.
 *
 * Up to 7 pages are listed verbatim; beyond that, the first and last pages
 * are always present and a sliding window around `currentPage` fills the
 * middle, with `"start-ellipsis"` / `"end-ellipsis"` markers for the gaps.
 *
 * @param currentPage 1-based page the user is currently on.
 * @param totalPages  Total number of pages.
 * @returns Page numbers interleaved with ellipsis marker strings.
 */
function getPageNumbers(currentPage: number, totalPages: number) {
  const MAX_VISIBLE = 7;
  // Small page counts: show every page, no truncation needed.
  if (totalPages <= MAX_VISIBLE) {
    const all: (number | string)[] = [];
    for (let page = 1; page <= totalPages; page++) {
      all.push(page);
    }
    return all;
  }
  // Window of pages centered on the current page, clamped so it never
  // collides with the always-visible first and last pages.
  let windowStart = Math.max(2, currentPage - 1);
  let windowEnd = Math.min(totalPages - 1, currentPage + 1);
  if (currentPage <= 3) {
    // Near the start: widen the window rightwards so 7 slots stay filled.
    windowEnd = 5;
  } else if (currentPage >= totalPages - 2) {
    // Near the end: widen the window leftwards symmetrically.
    windowStart = totalPages - 4;
  }
  const entries: (number | string)[] = [1];
  if (windowStart > 2) {
    entries.push("start-ellipsis");
  }
  for (let page = windowStart; page <= windowEnd; page++) {
    entries.push(page);
  }
  if (windowEnd < totalPages - 1) {
    entries.push("end-ellipsis");
  }
  entries.push(totalPages);
  return entries;
}
/**
 * Maps the size flag and text variant to the matching boolean props on
 * `Text`: mono vs muted styling, on the secondary (small) or main-UI scale.
 */
function sizedTextProps(isSmall: boolean, variant: "mono" | "muted") {
  const mono = variant === "mono";
  if (isSmall) {
    return mono ? { secondaryMono: true } : { secondaryBody: true };
  }
  return mono ? { mainUiMono: true } : { mainUiMuted: true };
}
/** Props for the prev/next arrow wrapper shared by all pagination variants. */
interface NavButtonsProps {
  /** The 1-based current page; the "previous" arrow disables at 1. */
  currentPage: number;
  /** Total page count; the "next" arrow disables at this value. */
  totalPages: number;
  /** Invoked with the target page number when an arrow is clicked. */
  onPageChange: (page: number) => void;
  /** Sizing forwarded to both arrow buttons. */
  size: PaginationSize;
  /** Optional content rendered between the two arrows (e.g. a page indicator). */
  children?: React.ReactNode;
}
/**
 * Prev/next arrow buttons shared by all pagination variants. Renders the
 * optional `children` (page indicator) between the two arrows. The back
 * arrow disables on the first page, the forward arrow on the last.
 */
function NavButtons({
  currentPage,
  totalPages,
  onPageChange,
  size,
  children,
}: NavButtonsProps) {
  const atFirstPage = currentPage <= 1;
  const atLastPage = currentPage >= totalPages;
  const goToPrevious = () => onPageChange(currentPage - 1);
  const goToNext = () => onPageChange(currentPage + 1);
  return (
    <>
      <Button
        icon={SvgChevronLeft}
        onClick={goToPrevious}
        disabled={atFirstPage}
        size={size}
        prominence="tertiary"
      />
      {children}
      <Button
        icon={SvgChevronRight}
        onClick={goToNext}
        disabled={atLastPage}
        size={size}
        prominence="tertiary"
      />
    </>
  );
}
/**
 * Table pagination component with three variants: `simple`, `count`, and `list`.
 * Dispatches on the `type` discriminant of the props union and renders the
 * matching variant implementation.
 */
export default function Pagination(props: PaginationProps) {
  if (props.type === "simple") {
    return <SimplePaginationInner {...props} />;
  }
  if (props.type === "count") {
    return <CountPaginationInner {...props} />;
  }
  return <ListPaginationInner {...props} />;
}
/**
 * Renders the `"simple"` variant: prev/next arrows around an optional
 * `current / total` indicator and an optional "pages" unit label.
 */
function SimplePaginationInner({
  currentPage,
  totalPages,
  onPageChange,
  showUnits,
  showPageIndicator = true,
  size = "lg",
  className,
}: SimplePaginationProps) {
  // "sm" switches the indicator to the secondary (denser) type scale.
  const isSmall = size === "sm";
  return (
    <div className={cn("flex items-center gap-1", className)}>
      <NavButtons
        currentPage={currentPage}
        totalPages={totalPages}
        onPageChange={onPageChange}
        size={size}
      >
        {showPageIndicator && (
          <>
            {/* "current / total" — the separator is rendered in muted type */}
            <Text {...sizedTextProps(isSmall, "mono")} text03>
              {currentPage}
              <Text as="span" {...sizedTextProps(isSmall, "muted")} text03>
                /
              </Text>
              {totalPages}
            </Text>
            {showUnits && (
              <Text {...sizedTextProps(isSmall, "muted")} text03>
                pages
              </Text>
            )}
          </>
        )}
      </NavButtons>
    </div>
  );
}
/**
 * Renders the `"count"` variant: an item-range summary ("X~Y of Z" with an
 * optional "items" unit label), prev/next arrows with an optional
 * current-page indicator, and an optional "Go to" button.
 */
function CountPaginationInner({
  pageSize,
  totalItems,
  currentPage,
  totalPages,
  onPageChange,
  showPageIndicator = true,
  showGoTo,
  onGoTo,
  showUnits,
  size = "lg",
  className,
}: CountPaginationProps) {
  // "sm" switches all text to the secondary (denser) type scale.
  const isSmall = size === "sm";
  // Clamp the displayed range start to totalItems so an empty table shows
  // "0~0 of 0" instead of the nonsensical "1~0 of 0".
  const rangeStart = Math.min((currentPage - 1) * pageSize + 1, totalItems);
  const rangeEnd = Math.min(currentPage * pageSize, totalItems);
  const currentItems = `${rangeStart}~${rangeEnd}`;
  return (
    <div className={cn("flex items-center gap-1", className)}>
      <Text {...sizedTextProps(isSmall, "mono")} text03>
        {currentItems}
      </Text>
      <Text {...sizedTextProps(isSmall, "muted")} text03>
        of
      </Text>
      <Text {...sizedTextProps(isSmall, "mono")} text03>
        {totalItems}
      </Text>
      {showUnits && (
        <Text {...sizedTextProps(isSmall, "muted")} text03>
          items
        </Text>
      )}
      <NavButtons
        currentPage={currentPage}
        totalPages={totalPages}
        onPageChange={onPageChange}
        size={size}
      >
        {showPageIndicator && (
          <Text {...sizedTextProps(isSmall, "mono")} text03>
            {currentPage}
          </Text>
        )}
      </NavButtons>
      {/* "Go to" requires both the flag and the callback to render */}
      {showGoTo && onGoTo && (
        <Button onClick={onGoTo} size={size} prominence="tertiary">
          Go to
        </Button>
      )}
    </div>
  );
}
/**
 * Renders the `"list"` variant: prev/next arrows around a row of clickable
 * page-number buttons, with ellipsis markers for truncated ranges (see
 * `getPageNumbers`). The active page is highlighted via `transient`.
 */
function ListPaginationInner({
  currentPage,
  totalPages,
  onPageChange,
  showPageIndicator = true,
  size = "lg",
  className,
}: ListPaginationProps) {
  // Mixed list of page numbers and "start-ellipsis"/"end-ellipsis" markers.
  const pageNumbers = getPageNumbers(currentPage, totalPages);
  return (
    <div className={cn("flex items-center gap-1", className)}>
      <NavButtons
        currentPage={currentPage}
        totalPages={totalPages}
        onPageChange={onPageChange}
        size={size}
      >
        {showPageIndicator && (
          <div className="flex items-center">
            {pageNumbers.map((page) => {
              // String entries are the ellipsis markers — plain text, not buttons.
              if (typeof page === "string") {
                return (
                  <Text
                    key={page}
                    mainUiMuted={size === "lg"}
                    secondaryBody={size === "md"}
                    text03
                  >
                    ...
                  </Text>
                );
              }
              const pageNum = page as number;
              const isActive = pageNum === currentPage;
              // The page number is rendered through the Button's icon render
              // prop so the label text sits inside the icon slot's layout box.
              return (
                <Button
                  key={pageNum}
                  onClick={() => onPageChange(pageNum)}
                  size={size}
                  prominence="tertiary"
                  transient={isActive}
                  icon={({ className: iconClassName }) => (
                    <div
                      className={cn(
                        iconClassName,
                        "flex flex-col justify-center"
                      )}
                    >
                      {size === "lg" ? (
                        <Text
                          mainUiBody={isActive}
                          mainUiMuted={!isActive}
                          text04={isActive}
                          text02={!isActive}
                        >
                          {pageNum}
                        </Text>
                      ) : (
                        <Text
                          secondaryAction={isActive}
                          secondaryBody={!isActive}
                          text04={isActive}
                          text02={!isActive}
                        >
                          {pageNum}
                        </Text>
                      )}
                    </div>
                  )}
                />
              );
            })}
          </div>
        )}
      </NavButtons>
    </div>
  );
}

View File

@@ -0,0 +1,116 @@
import { cn } from "@/lib/utils";
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import { SvgHandle } from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
/** Sort state of a sortable column header. */
type SortDirection = "none" | "ascending" | "descending";
/**
 * A table header cell with optional sort controls and a resize handle indicator.
 * Renders as a `<th>` element with Figma-matched typography and spacing.
 */
interface TableHeadProps {
  /** Header label content. */
  children: React.ReactNode;
  /** Current sort state. When omitted, no sort button is shown. */
  sorted?: SortDirection;
  /** Called when the sort button is clicked. Required to show the sort button. */
  onSort?: () => void;
  /** When `true`, renders a thin resize handle on the right edge. */
  resizable?: boolean;
  /** Override the sort icon for this column. Receives the current sort state and
   * returns the icon component to render. Note: the sort button only renders
   * when this is provided alongside `onSort`. */
  icon?: (sorted: SortDirection) => IconFunctionComponent;
  /** Text alignment for the column. Defaults to `"left"`. */
  alignment?: "left" | "center" | "right";
  /** Cell density. `"small"` uses tighter padding for denser layouts. */
  size?: "regular" | "small";
  /** Additional classes on the outer `<th>` element. */
  className?: string;
}
/** Maps column alignment to the text-alignment class applied on the `<th>`. */
const alignmentThClass = {
  left: "text-left",
  center: "text-center",
  right: "text-right",
} as const;
/** Maps column alignment to the flex justification of the inner content row. */
const alignmentFlexClass = {
  left: "justify-start",
  center: "justify-center",
  right: "justify-end",
} as const;
/**
 * Table header cell primitive. Displays a column label with optional sort
 * functionality (label + hover-revealed sort button) and, when `resizable`,
 * a hover-revealed resize handle pinned to the right edge.
 */
export default function TableHead({
  children,
  sorted,
  onSort,
  icon: iconFn,
  resizable,
  alignment = "left",
  size = "regular",
  className,
}: TableHeadProps) {
  const isSmall = size === "small";
  // Alias of the icon render function; the sort button renders only when
  // both `onSort` and this function are provided.
  const resolvedIcon = iconFn;
  return (
    <th
      className={cn(
        "group relative",
        alignmentThClass[alignment],
        isSmall ? "p-1.5" : "px-2 py-1",
        // NOTE(review): `group-hover:` targets descendants of a hovered
        // `.group` ancestor, while this <th> itself carries `group` — confirm
        // the intended hover scope (plain `hover:` may have been meant).
        "border-b border-transparent group-hover:border-border-03",
        className
      )}
    >
      <div
        className={cn("flex items-center gap-1", alignmentFlexClass[alignment])}
      >
        <div className={isSmall ? "py-1" : "py-2"}>
          <Text
            mainUiAction={!isSmall}
            secondaryAction={isSmall}
            text04
            className="truncate"
          >
            {children}
          </Text>
        </div>
        {/* Sort button: hidden until the header is hovered */}
        <div
          className={cn(
            !isSmall && "py-1.5",
            "opacity-0 group-hover:opacity-100 transition-opacity"
          )}
        >
          {onSort && resolvedIcon && (
            <Button
              icon={resolvedIcon(sorted ?? "none")}
              onClick={onSort}
              tooltip="Sort"
              tooltipSide="top"
              prominence="internal"
              size="sm"
            />
          )}
        </div>
      </div>
      {/* Resize handle indicator on the right edge, revealed on hover */}
      {resizable && (
        <div
          className={cn(
            "absolute right-0 top-0 flex h-full items-center",
            "text-border-02",
            "opacity-0 group-hover:opacity-100",
            "cursor-col-resize"
          )}
        >
          <SvgHandle size={22} className="stroke-border-02" />
        </div>
      )}
    </th>
  );
}

View File

@@ -0,0 +1,295 @@
"use client";
import { cn } from "@/lib/utils";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import Pagination from "@/refresh-components/table/Pagination";
import { SvgEye, SvgXCircle } from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
/** Footer density options: `"regular"` (default) or the tighter `"small"`. */
type FooterSize = "regular" | "small";
/** Row-selection state of the table the footer describes. */
type SelectionState = "none" | "partial" | "all";
/**
 * Footer mode for tables with selectable rows.
 * Displays a selection message on the left (with optional view/clear actions)
 * and a `count`-type pagination on the right.
 */
interface FooterSelectionModeProps {
  /** Discriminant selecting the selection-mode footer. */
  mode: "selection";
  /** Whether the table supports selecting multiple rows. */
  multiSelect: boolean;
  /** Current selection state: `"none"`, `"partial"`, or `"all"`. */
  selectionState: SelectionState;
  /** Number of currently selected items. */
  selectedCount: number;
  /** When `true`, renders a qualifier checkbox on the far left. */
  showQualifier?: boolean;
  /** Controlled checked state for the qualifier checkbox. */
  qualifierChecked?: boolean;
  /** Called when the qualifier checkbox value changes. */
  onQualifierChange?: (checked: boolean) => void;
  /** If provided, renders a "View" icon button when items are selected. */
  onView?: () => void;
  /** If provided, renders a "Clear" icon button when items are selected. */
  onClear?: () => void;
  /** Number of items displayed per page. */
  pageSize: number;
  /** First item number in the current page (e.g. `1`). */
  rangeStart: number;
  /** Last item number in the current page (e.g. `25`). */
  rangeEnd: number;
  /** Total number of items across all pages. */
  totalItems: number;
  /** The 1-based current page number. */
  currentPage: number;
  /** Total number of pages. */
  totalPages: number;
  /** Called when the user navigates to a different page. */
  onPageChange: (page: number) => void;
  /** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
  size?: FooterSize;
  className?: string;
}
/**
 * Footer mode for read-only tables (no row selection).
 * Displays "Showing X~Y of Z" on the left and a `list`-type pagination
 * on the right.
 */
interface FooterSummaryModeProps {
  /** Discriminant selecting the summary-mode footer. */
  mode: "summary";
  /** Number of items displayed per page.
   * NOTE(review): not read by the Footer component itself in summary mode —
   * confirm whether external callers rely on it before removing. */
  pageSize: number;
  /** First item number in the current page (e.g. `1`). */
  rangeStart: number;
  /** Last item number in the current page (e.g. `25`). */
  rangeEnd: number;
  /** Total number of items across all pages. */
  totalItems: number;
  /** When `true`, renders a qualifier checkbox on the far left. */
  showQualifier?: boolean;
  /** Controlled checked state for the qualifier checkbox. */
  qualifierChecked?: boolean;
  /** Called when the qualifier checkbox value changes. */
  onQualifierChange?: (checked: boolean) => void;
  /** The 1-based current page number. */
  currentPage: number;
  /** Total number of pages. */
  totalPages: number;
  /** Called when the user navigates to a different page. */
  onPageChange: (page: number) => void;
  /** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
  size?: FooterSize;
  className?: string;
}
/**
 * Discriminated union of footer modes.
 * Use `mode: "selection"` for tables with selectable rows, or
 * `mode: "summary"` for read-only tables.
 */
export type FooterProps = FooterSelectionModeProps | FooterSummaryModeProps;
/**
 * Builds the left-hand status message for selection-mode footers.
 *
 * @param state Current selection state.
 * @param multi Whether the table allows multi-select.
 * @param count Number of selected items (only used for multi-select).
 * @returns Human-readable selection summary.
 */
function getSelectionMessage(
  state: SelectionState,
  multi: boolean,
  count: number
): string {
  if (state === "none") {
    return multi ? "Select items to continue" : "Select an item to continue";
  }
  if (!multi) return "Item selected";
  // Bug fix: use the singular form when exactly one item is selected
  // (previously rendered "1 items selected").
  return count === 1 ? "1 item selected" : `${count} items selected`;
}
/**
 * Table footer combining status information on the left with pagination on the
 * right. Use `mode: "selection"` for tables with selectable rows, or
 * `mode: "summary"` for read-only tables.
 */
export default function Footer(props: FooterProps) {
  const { size = "regular", className } = props;
  const isSmall = size === "small";
  return (
    <div
      className={cn(
        "flex w-full items-center justify-between border-t border-border-01",
        isSmall ? "min-h-[2.25rem]" : "min-h-[2.75rem]",
        className
      )}
    >
      {/* Left side: optional qualifier checkbox + mode-specific status */}
      <div className="flex items-center gap-1 px-1">
        {props.showQualifier && (
          <div className="flex items-center px-1">
            <Checkbox
              checked={props.qualifierChecked}
              // Indeterminate only applies in selection mode with a partial selection.
              indeterminate={
                props.mode === "selection" && props.selectionState === "partial"
              }
              onCheckedChange={props.onQualifierChange}
            />
          </div>
        )}
        {props.mode === "selection" ? (
          <SelectionLeft
            selectionState={props.selectionState}
            multiSelect={props.multiSelect}
            selectedCount={props.selectedCount}
            onView={props.onView}
            onClear={props.onClear}
            isSmall={isSmall}
          />
        ) : (
          <SummaryLeft
            rangeStart={props.rangeStart}
            rangeEnd={props.rangeEnd}
            totalItems={props.totalItems}
            isSmall={isSmall}
          />
        )}
      </div>
      {/* Right side: count pagination (selection) or list pagination (summary) */}
      <div className="flex items-center gap-2 px-1 py-2">
        {props.mode === "selection" ? (
          <Pagination
            type="count"
            pageSize={props.pageSize}
            totalItems={props.totalItems}
            currentPage={props.currentPage}
            totalPages={props.totalPages}
            onPageChange={props.onPageChange}
            showUnits
            size={isSmall ? "sm" : "md"}
          />
        ) : (
          <Pagination
            type="list"
            currentPage={props.currentPage}
            totalPages={props.totalPages}
            onPageChange={props.onPageChange}
            size={isSmall ? "md" : "lg"}
          />
        )}
      </div>
    </div>
  );
}
/** Props for the selection-mode left-hand status area. */
interface SelectionLeftProps {
  /** Current selection state; `"none"` renders the prompt message. */
  selectionState: SelectionState;
  /** Whether the table allows multi-select (affects message wording). */
  multiSelect: boolean;
  /** Number of selected items, shown in the multi-select message. */
  selectedCount: number;
  /** If provided, renders a "View" icon button when items are selected. */
  onView?: () => void;
  /** If provided, renders a "Clear" icon button when items are selected. */
  onClear?: () => void;
  /** Whether the footer is in its small (denser) size. */
  isSmall: boolean;
}
/**
 * Left-hand side of a selection-mode footer: a selection message plus
 * optional View/Clear action buttons that appear once something is selected.
 */
function SelectionLeft({
  selectionState,
  multiSelect,
  selectedCount,
  onView,
  onClear,
  isSmall,
}: SelectionLeftProps) {
  const message = getSelectionMessage(
    selectionState,
    multiSelect,
    selectedCount
  );
  // Emphasized text and action buttons only once a selection exists.
  const hasSelection = selectionState !== "none";
  return (
    <div className="flex flex-row gap-1 items-center justify-center w-fit flex-shrink-0 h-fit px-1">
      {isSmall ? (
        <Text
          secondaryAction={hasSelection}
          secondaryBody={!hasSelection}
          text03
        >
          {message}
        </Text>
      ) : (
        <Text mainUiBody={hasSelection} mainUiMuted={!hasSelection} text03>
          {message}
        </Text>
      )}
      {hasSelection && (
        <div className="flex flex-row items-center w-fit flex-shrink-0 h-fit">
          {onView && (
            <Button
              icon={SvgEye}
              onClick={onView}
              tooltip="View"
              size="md"
              prominence="tertiary"
            />
          )}
          {onClear && (
            <Button
              icon={SvgXCircle}
              onClick={onClear}
              tooltip="Clear selection"
              size="md"
              prominence="tertiary"
            />
          )}
        </div>
      )}
    </div>
  );
}
/** Props for the summary-mode left-hand status area. */
interface SummaryLeftProps {
  /** First item number in the current page. */
  rangeStart: number;
  /** Last item number in the current page. */
  rangeEnd: number;
  /** Total number of items across all pages. */
  totalItems: number;
  /** Whether the footer is in its small (denser) size. */
  isSmall: boolean;
}
/**
 * Left-hand side of a summary-mode footer: "Showing X~Y of Z", with the
 * numbers in mono type on either the secondary (small) or main-UI scale.
 */
function SummaryLeft({
  rangeStart,
  rangeEnd,
  totalItems,
  isSmall,
}: SummaryLeftProps) {
  return (
    <Section
      flexDirection="row"
      gap={0.25}
      alignItems="center"
      width="fit"
      height="fit"
    >
      {isSmall ? (
        <Text secondaryBody text03>
          Showing{" "}
          <Text as="span" secondaryMono text03>
            {rangeStart}~{rangeEnd}
          </Text>{" "}
          of{" "}
          <Text as="span" secondaryMono text03>
            {totalItems}
          </Text>
        </Text>
      ) : (
        <Text mainUiMuted text03>
          Showing{" "}
          <Text as="span" mainUiMono text03>
            {rangeStart}~{rangeEnd}
          </Text>{" "}
          of{" "}
          <Text as="span" mainUiMono text03>
            {totalItems}
          </Text>
        </Text>
      )}
    </Section>
  );
}

View File

@@ -13,7 +13,6 @@ import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
import InputTypeInElementField from "@/refresh-components/form/InputTypeInElementField";
import InputDatePickerField from "@/refresh-components/form/InputDatePickerField";
import Message from "@/refresh-components/messages/Message";
import Separator from "@/refresh-components/Separator";
import * as InputLayouts from "@/layouts/input-layouts";
import { useFormikContext } from "formik";
@@ -57,7 +56,6 @@ import {
SvgLock,
SvgOnyxOctagon,
SvgSliders,
SvgUsers,
SvgTrash,
} from "@opal/icons";
import CustomAgentAvatar, {
@@ -88,7 +86,6 @@ import ShareAgentModal from "@/sections/modals/ShareAgentModal";
import AgentKnowledgePane from "@/sections/knowledge/AgentKnowledgePane";
import { ValidSources } from "@/lib/types";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { useUser } from "@/providers/UserProvider";
interface AgentIconEditorProps {
existingAgent?: FullPersona | null;
@@ -453,8 +450,6 @@ export default function AgentEditorPage({
const shareAgentModal = useCreateModal();
const deleteAgentModal = useCreateModal();
const settings = useSettingsContext();
const { isAdmin, isCurator } = useUser();
const canUpdateFeaturedStatus = isAdmin || isCurator;
const vectorDbEnabled = settings?.settings.vector_db_enabled !== false;
// LLM Model Selection
@@ -656,8 +651,6 @@ export default function AgentEditorPage({
shared_user_ids: existingAgent?.users?.map((user) => user.id) ?? [],
shared_group_ids: existingAgent?.groups ?? [],
is_public: existingAgent?.is_public ?? true,
label_ids: existingAgent?.labels?.map((l) => l.id) ?? [],
is_default_persona: existingAgent?.is_default_persona ?? false,
};
const validationSchema = Yup.object().shape({
@@ -819,8 +812,8 @@ export default function AgentEditorPage({
uploaded_image_id: values.uploaded_image_id,
icon_name: values.icon_name,
search_start_date: values.knowledge_cutoff_date || null,
label_ids: values.label_ids,
is_default_persona: values.is_default_persona,
label_ids: null,
is_default_persona: false,
// display_priority: ...,
user_file_ids: values.enable_knowledge ? values.user_file_ids : [],
@@ -1002,10 +995,6 @@ export default function AgentEditorPage({
(fileId: string) =>
fileStatusMap.get(fileId) === UserFileStatus.PROCESSING
);
const isShared =
values.is_public ||
values.shared_user_ids.length > 0 ||
values.shared_group_ids.length > 0;
return (
<>
@@ -1065,20 +1054,10 @@ export default function AgentEditorPage({
userIds={values.shared_user_ids}
groupIds={values.shared_group_ids}
isPublic={values.is_public}
isFeatured={values.is_default_persona}
labelIds={values.label_ids}
onShare={(
userIds,
groupIds,
isPublic,
isFeatured,
labelIds
) => {
onShare={(userIds, groupIds, isPublic) => {
setFieldValue("shared_user_ids", userIds);
setFieldValue("shared_group_ids", groupIds);
setFieldValue("is_public", isPublic);
setFieldValue("is_default_persona", isFeatured);
setFieldValue("label_ids", labelIds);
shareAgentModal.toggle(false);
}}
/>
@@ -1381,36 +1360,17 @@ export default function AgentEditorPage({
<Card>
<InputLayouts.Horizontal
title="Share This Agent"
description="with other users, groups, or everyone in your organization."
description="Share this agent with other users, groups, or everyone in your organization."
center
>
<Button
secondary
leftIcon={isShared ? SvgUsers : SvgLock}
leftIcon={SvgLock}
onClick={() => shareAgentModal.toggle(true)}
>
Share
</Button>
</InputLayouts.Horizontal>
{canUpdateFeaturedStatus && (
<>
<InputLayouts.Horizontal
name="is_default_persona"
title="Feature This Agent"
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
>
<SwitchField name="is_default_persona" />
</InputLayouts.Horizontal>
{values.is_default_persona && !isShared && (
<Message
static
close={false}
className="w-full"
text="This agent is private to you and will only be featured for yourself."
/>
)}
</>
)}
</Card>
<Card>

View File

@@ -425,7 +425,7 @@ export default function AgentsNavigationPage() {
>
<SettingsLayouts.Header
icon={SvgOnyxOctagon}
title="Agents"
title="Agents & Assistants"
description="Customize AI behavior and knowledge for you and your team's use cases."
rightChildren={
<Button

View File

@@ -1,241 +0,0 @@
"use client";
import React, { useState } from "react";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Card, type CardProps } from "@/refresh-components/cards";
import {
SvgArrowExchange,
SvgCheckCircle,
SvgRefreshCw,
SvgTerminal,
SvgUnplug,
SvgXOctagon,
} from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import useCodeInterpreter from "@/hooks/useCodeInterpreter";
import { updateCodeInterpreter } from "@/lib/admin/code-interpreter/svc";
import { ContentAction } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
/** Props for the Code Interpreter status card. */
interface CodeInterpreterCardProps {
  /** Card background/emphasis variant, forwarded to `Card`. */
  variant?: CardProps["variant"];
  /** Main title shown in the card. */
  title: string;
  /** Optional suffix appended to the title (e.g. "(Disconnected)"). */
  middleText?: string;
  /** Intended strikethrough styling for the title (currently unused — see TODO). */
  strikethrough?: boolean;
  /** Content rendered on the right side of the card. */
  rightContent: React.ReactNode;
}
/**
 * Status card for the Code Interpreter runtime: icon, title (optionally
 * suffixed with `middleText`), fixed description, and caller-supplied
 * right-hand content.
 */
function CodeInterpreterCard({
  variant,
  title,
  middleText,
  strikethrough,
  rightContent,
}: CodeInterpreterCardProps) {
  return (
    // TODO (@raunakab): Allow Content to accept strikethrough and middleText
    <Card variant={variant} padding={0.5}>
      <ContentAction
        icon={SvgTerminal}
        title={middleText ? `${title} ${middleText}` : title}
        description="Built-in Python runtime"
        variant="section"
        sizePreset="main-ui"
        rightChildren={rightContent}
      />
    </Card>
  );
}
/** "Checking..." label with a spinner, shown while the health check is in flight. */
function CheckingStatus() {
  return (
    <Section
      flexDirection="row"
      justifyContent="end"
      alignItems="center"
      gap={0.25}
      padding={0.5}
    >
      <Text mainUiAction text03>
        Checking...
      </Text>
      <SimpleLoader />
    </Section>
  );
}
/** Props for the connection health indicator. */
interface ConnectionStatusProps {
  /** Whether the runtime's health check passed. */
  healthy: boolean;
  /** When `true`, renders the in-flight "Checking..." state instead. */
  isLoading: boolean;
}
/**
 * Health indicator: "Connected" with a success icon when healthy,
 * "Connection Lost" with an error icon otherwise, or a loading state.
 */
function ConnectionStatus({ healthy, isLoading }: ConnectionStatusProps) {
  if (isLoading) {
    return <CheckingStatus />;
  }
  const label = healthy ? "Connected" : "Connection Lost";
  const Icon = healthy ? SvgCheckCircle : SvgXOctagon;
  const iconColor = healthy ? "text-status-success-05" : "text-status-error-05";
  return (
    <Section
      flexDirection="row"
      justifyContent="end"
      alignItems="center"
      gap={0.25}
      padding={0.5}
    >
      <Text mainUiAction text03>
        {label}
      </Text>
      <Icon size={16} className={iconColor} />
    </Section>
  );
}
/** Props for the disconnect/refresh action row. */
interface ActionButtonsProps {
  /** Invoked when the "Disconnect" button is clicked. */
  onDisconnect: () => void;
  /** Invoked when the "Refresh" button is clicked. */
  onRefresh: () => void;
  /** Disables both buttons (e.g. while a health check is loading). */
  disabled?: boolean;
}
/** Icon-button row offering Disconnect and Refresh actions for the runtime. */
function ActionButtons({
  onDisconnect,
  onRefresh,
  disabled,
}: ActionButtonsProps) {
  return (
    <Section
      flexDirection="row"
      justifyContent="end"
      alignItems="center"
      gap={0.25}
      padding={0.25}
    >
      <Button
        prominence="tertiary"
        size="sm"
        icon={SvgUnplug}
        onClick={onDisconnect}
        tooltip="Disconnect"
        disabled={disabled}
      />
      <Button
        prominence="tertiary"
        size="sm"
        icon={SvgRefreshCw}
        onClick={onRefresh}
        tooltip="Refresh"
        disabled={disabled}
      />
    </Section>
  );
}
/**
 * Admin settings page for the Code Interpreter runtime. Shows the current
 * connection state, lets admins disconnect (behind a confirmation modal),
 * reconnect, or refresh the health check.
 */
export default function CodeInterpreterPage() {
  const { isHealthy, isEnabled, isLoading, refetch } = useCodeInterpreter();
  const [showDisconnectModal, setShowDisconnectModal] = useState(false);
  // True only while a reconnect request is in flight (drives the spinner).
  const [isReconnecting, setIsReconnecting] = useState(false);
  // Enable (reconnect) or disable (disconnect) the interpreter, then refetch
  // status on success.
  // NOTE(review): errors thrown by updateCodeInterpreter propagate — there is
  // no catch, only a finally that resets the spinner. Confirm this is intended.
  async function handleToggle(enabled: boolean) {
    const action = enabled ? "reconnect" : "disconnect";
    setIsReconnecting(enabled);
    try {
      const response = await updateCodeInterpreter({ enabled });
      if (!response.ok) {
        toast.error(`Failed to ${action} Code Interpreter`);
        return;
      }
      setShowDisconnectModal(false);
      refetch();
    } finally {
      setIsReconnecting(false);
    }
  }
  return (
    <SettingsLayouts.Root>
      <SettingsLayouts.Header
        icon={SvgTerminal}
        title="Code Interpreter"
        description="Safe and sandboxed Python runtime available to your LLM. See docs for more details."
        separator
      />
      <SettingsLayouts.Body>
        {/* While loading, optimistically show the connected-card layout */}
        {isEnabled || isLoading ? (
          <CodeInterpreterCard
            title="Code Interpreter"
            variant={isHealthy ? "primary" : "secondary"}
            strikethrough={!isHealthy}
            rightContent={
              <Section
                flexDirection="column"
                justifyContent="center"
                alignItems="end"
                gap={0}
                padding={0}
              >
                <ConnectionStatus healthy={isHealthy} isLoading={isLoading} />
                <ActionButtons
                  onDisconnect={() => setShowDisconnectModal(true)}
                  onRefresh={refetch}
                  disabled={isLoading}
                />
              </Section>
            }
          />
        ) : (
          <CodeInterpreterCard
            variant="secondary"
            title="Code Interpreter"
            middleText="(Disconnected)"
            strikethrough={true}
            rightContent={
              <Section flexDirection="row" alignItems="center" padding={0.5}>
                {isReconnecting ? (
                  <CheckingStatus />
                ) : (
                  <Button
                    prominence="tertiary"
                    rightIcon={SvgArrowExchange}
                    onClick={() => handleToggle(true)}
                  >
                    Reconnect
                  </Button>
                )}
              </Section>
            }
          />
        )}
      </SettingsLayouts.Body>
      {/* Disconnect requires explicit confirmation */}
      {showDisconnectModal && (
        <ConfirmationModalLayout
          icon={SvgUnplug}
          title="Disconnect Code Interpreter"
          onClose={() => setShowDisconnectModal(false)}
          submit={
            <Button variant="danger" onClick={() => handleToggle(false)}>
              Disconnect
            </Button>
          }
        >
          <Text as="p" text03>
            All running sessions connected to{" "}
            <Text as="span" mainContentEmphasis text03>
              Code Interpreter
            </Text>{" "}
            will stop working. Note that this will not remove any data from your
            runtime. You can reconnect to this runtime later if needed.
          </Text>
        </ConfirmationModalLayout>
      )}
    </SettingsLayouts.Root>
  );
}

View File

@@ -11,11 +11,7 @@ import { cn, noProp } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { Route } from "next";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import {
checkUserOwnsAssistant,
updateAgentSharedStatus,
updateAgentFeaturedStatus,
} from "@/lib/agents";
import { checkUserOwnsAssistant, updateAgentSharedStatus } from "@/lib/agents";
import { useUser } from "@/providers/UserProvider";
import {
SvgActions,
@@ -47,9 +43,8 @@ export default function AgentCard({ agent }: AgentCardProps) {
() => pinnedAgents.some((pinnedAgent) => pinnedAgent.id === agent.id),
[agent.id, pinnedAgents]
);
const { user, isAdmin, isCurator } = useUser();
const { user } = useUser();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const canUpdateFeaturedStatus = isAdmin || isCurator;
const isOwnedByUser = checkUserOwnsAssistant(user, agent);
const shareAgentModal = useCreateModal();
const agentViewerModal = useCreateModal();
@@ -63,49 +58,26 @@ export default function AgentCard({ agent }: AgentCardProps) {
route({ agentId: agent.id });
}, [pinned, togglePinnedAgent, agent, route]);
// Handle sharing agent
const handleShare = useCallback(
async (
userIds: string[],
groupIds: number[],
isPublic: boolean,
isFeatured: boolean,
labelIds: number[]
) => {
const shareError = await updateAgentSharedStatus(
async (userIds: string[], groupIds: number[], isPublic: boolean) => {
const error = await updateAgentSharedStatus(
agent.id,
userIds,
groupIds,
isPublic,
isPaidEnterpriseFeaturesEnabled,
labelIds
isPaidEnterpriseFeaturesEnabled
);
if (shareError) {
toast.error(`Failed to share agent: ${shareError}`);
return;
if (error) {
toast.error(`Failed to share agent: ${error}`);
} else {
// Revalidate the agent data to reflect the changes
refreshAgent();
shareAgentModal.toggle(false);
}
if (canUpdateFeaturedStatus) {
const featuredError = await updateAgentFeaturedStatus(
agent.id,
isFeatured
);
if (featuredError) {
toast.error(`Failed to update featured status: ${featuredError}`);
refreshAgent();
return;
}
}
refreshAgent();
shareAgentModal.toggle(false);
},
[
agent.id,
canUpdateFeaturedStatus,
isPaidEnterpriseFeaturesEnabled,
refreshAgent,
]
[agent.id, isPaidEnterpriseFeaturesEnabled, refreshAgent]
);
return (
@@ -116,8 +88,6 @@ export default function AgentCard({ agent }: AgentCardProps) {
userIds={fullAgent?.users?.map((u) => u.id) ?? []}
groupIds={fullAgent?.groups ?? []}
isPublic={fullAgent?.is_public ?? false}
isFeatured={fullAgent?.is_default_persona ?? false}
labelIds={fullAgent?.labels?.map((l) => l.id) ?? []}
onShare={handleShare}
/>
</shareAgentModal.Provider>

View File

@@ -119,7 +119,7 @@ export default function NewTenantModal({
: `Your request to join ${tenantInfo.number_of_users} other users of ${APP_DOMAIN} has been approved.`;
const description = isInvite
? `By accepting this invitation, you will join the existing ${APP_DOMAIN} team and lose access to your current team. Note: you will lose access to your current agents, prompts, chats, and connected sources.`
? `By accepting this invitation, you will join the existing ${APP_DOMAIN} team and lose access to your current team. Note: you will lose access to your current assistants, prompts, chats, and connected sources.`
: `To finish joining your team, please reauthenticate with ${user?.email}.`;
return (

View File

@@ -1,18 +1,15 @@
"use client";
import { useCallback, useMemo, useState } from "react";
import { useMemo } from "react";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import {
SvgLink,
SvgOrganization,
SvgShare,
SvgTag,
SvgUsers,
SvgX,
} from "@opal/icons";
import InputChipField from "@/refresh-components/inputs/InputChipField";
import Message from "@/refresh-components/messages/Message";
import Tabs from "@/refresh-components/Tabs";
import { Card } from "@/refresh-components/cards";
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
@@ -29,8 +26,6 @@ import { useUser } from "@/providers/UserProvider";
import { Formik, useFormikContext } from "formik";
import { useAgent } from "@/hooks/useAgents";
import { Button as OpalButton } from "@opal/components";
import { useLabels } from "@/lib/hooks";
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
const YOUR_ORGANIZATION_TAB = "Your Organization";
const USERS_AND_GROUPS_TAB = "Users & Groups";
@@ -43,8 +38,6 @@ interface ShareAgentFormValues {
selectedUserIds: string[];
selectedGroupIds: number[];
isPublic: boolean;
isFeatured: boolean;
labelIds: number[];
}
// ============================================================================
@@ -60,15 +53,12 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
useFormikContext<ShareAgentFormValues>();
const { data: usersData } = useShareableUsers({ includeApiKeys: true });
const { data: groupsData } = useShareableGroups();
const { user: currentUser, isAdmin, isCurator } = useUser();
const { user: currentUser } = useUser();
const { agent: fullAgent } = useAgent(agentId ?? null);
const shareAgentModal = useModal();
const { labels: allLabels, createLabel } = useLabels();
const [labelInputValue, setLabelInputValue] = useState("");
const acceptedUsers = usersData ?? [];
const groups = groupsData ?? [];
const canUpdateFeaturedStatus = isAdmin || isCurator;
// Create options for InputComboBox from all accepted users and groups
const comboBoxOptions = useMemo(() => {
@@ -147,50 +137,6 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
);
}
const selectedLabels: PersonaLabel[] = useMemo(() => {
if (!allLabels) return [];
return allLabels.filter((label) => values.labelIds.includes(label.id));
}, [allLabels, values.labelIds]);
function handleRemoveLabel(labelId: number) {
setFieldValue(
"labelIds",
values.labelIds.filter((id) => id !== labelId)
);
}
const addLabel = useCallback(
async (name: string) => {
const trimmed = name.trim();
if (!trimmed) return;
const existing = allLabels?.find(
(l) => l.name.toLowerCase() === trimmed.toLowerCase()
);
if (existing) {
if (!values.labelIds.includes(existing.id)) {
setFieldValue("labelIds", [...values.labelIds, existing.id]);
}
} else {
const newLabel = await createLabel(trimmed);
if (newLabel) {
setFieldValue("labelIds", [...values.labelIds, newLabel.id]);
}
}
setLabelInputValue("");
},
[allLabels, values.labelIds, setFieldValue, createLabel]
);
const chipItems = useMemo(
() =>
selectedLabels.map((label) => ({
id: String(label.id),
label: label.name,
})),
[selectedLabels]
);
return (
<Modal.Content width="sm" height="lg">
<Modal.Header icon={SvgShare} title="Share Agent" onClose={handleClose} />
@@ -284,56 +230,15 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
</Section>
)}
</Section>
{values.isPublic && (
<Section>
<Message
iconComponent={SvgOrganization}
close={false}
static
className="w-full"
text="This agent is public to your organization."
description="Everyone in your organization has access to this agent."
/>
</Section>
)}
</Tabs.Content>
<Tabs.Content value={YOUR_ORGANIZATION_TAB} padding={0.5}>
<Section gap={1} alignItems="stretch">
<InputLayouts.Horizontal
title="Publish This Agent"
description="Make this agent available to everyone in your organization."
>
<SwitchField name="isPublic" />
</InputLayouts.Horizontal>
{canUpdateFeaturedStatus && (
<>
<div className="border-t border-border-02" />
<InputLayouts.Horizontal
title="Feature This Agent"
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
>
<SwitchField name="isFeatured" />
</InputLayouts.Horizontal>
</>
)}
<InputChipField
chips={chipItems}
onRemoveChip={(id) => handleRemoveLabel(Number(id))}
onAdd={addLabel}
value={labelInputValue}
onChange={setLabelInputValue}
placeholder="Add labels..."
icon={SvgTag}
/>
<Text secondaryBody text04>
Add labels and categories to help people better discover this
agent.
</Text>
</Section>
<InputLayouts.Horizontal
title="Publish This Agent"
description="Make this agent available to everyone in your organization."
>
<SwitchField name="isPublic" />
</InputLayouts.Horizontal>
</Tabs.Content>
</Tabs>
</Card>
@@ -373,15 +278,7 @@ export interface ShareAgentModalProps {
userIds: string[];
groupIds: number[];
isPublic: boolean;
isFeatured: boolean;
labelIds: number[];
onShare?: (
userIds: string[],
groupIds: number[],
isPublic: boolean,
isFeatured: boolean,
labelIds: number[]
) => void;
onShare?: (userIds: string[], groupIds: number[], isPublic: boolean) => void;
}
export default function ShareAgentModal({
@@ -389,8 +286,6 @@ export default function ShareAgentModal({
userIds,
groupIds,
isPublic,
isFeatured,
labelIds,
onShare,
}: ShareAgentModalProps) {
const shareAgentModal = useModal();
@@ -399,18 +294,10 @@ export default function ShareAgentModal({
selectedUserIds: userIds,
selectedGroupIds: groupIds,
isPublic: isPublic,
isFeatured: isFeatured,
labelIds: labelIds,
};
function handleSubmit(values: ShareAgentFormValues) {
onShare?.(
values.selectedUserIds,
values.selectedGroupIds,
values.isPublic,
values.isFeatured,
values.labelIds
);
onShare?.(values.selectedUserIds, values.selectedGroupIds, values.isPublic);
}
return (

View File

@@ -50,7 +50,6 @@ import {
SvgPaintBrush,
SvgDiscordMono,
SvgWallet,
SvgTerminal,
} from "@opal/icons";
import SvgMcp from "@opal/icons/mcp";
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
@@ -92,7 +91,7 @@ const custom_assistants_items = (
) => {
const items = [
{
name: "Agents",
name: "Assistants",
icon: SvgOnyxOctagon,
link: "/admin/assistants",
},
@@ -166,7 +165,7 @@ const collections = (
]
: []),
{
name: "Custom Agents",
name: "Custom Assistants",
items: custom_assistants_items(isCurator, enableEnterprise),
},
...(isCurator && enableEnterprise
@@ -208,11 +207,6 @@ const collections = (
icon: SvgImage,
link: "/admin/configuration/image-generation",
},
{
name: "Code Interpreter",
icon: SvgTerminal,
link: "/admin/configuration/code-interpreter",
},
...(!enableCloud && vectorDbEnabled
? [
{

View File

@@ -29,12 +29,12 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
pageTitle: "Add Connector",
},
{
name: "Custom Agents - Agents",
name: "Custom Assistants - Assistants",
path: "assistants",
pageTitle: "Agents",
pageTitle: "Assistants",
options: {
paragraphText:
"Agents are a way to build custom search/question-answering experiences for different use cases.",
"Assistants are a way to build custom search/question-answering experiences for different use cases.",
},
},
{
@@ -52,7 +52,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
},
},
{
name: "Custom Agents - Slack Bots",
name: "Custom Assistants - Slack Bots",
path: "bots",
pageTitle: "Slack Bots",
options: {
@@ -61,7 +61,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
},
},
{
name: "Custom Agents - Standard Answers",
name: "Custom Assistants - Standard Answers",
path: "standard-answer",
pageTitle: "Standard Answers",
},
@@ -101,12 +101,12 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
pageTitle: "Search Settings",
},
{
name: "Custom Agents - MCP Actions",
name: "Custom Assistants - MCP Actions",
path: "actions/mcp",
pageTitle: "MCP Actions",
},
{
name: "Custom Agents - OpenAPI Actions",
name: "Custom Assistants - OpenAPI Actions",
path: "actions/open-api",
pageTitle: "OpenAPI Actions",
},

View File

@@ -1,268 +0,0 @@
import { test, expect } from "@playwright/test";
import type { Page } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
const CODE_INTERPRETER_URL = "/admin/configuration/code-interpreter";
const API_STATUS_URL = "**/api/admin/code-interpreter";
const API_HEALTH_URL = "**/api/admin/code-interpreter/health";
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/**
* Intercept the status (GET /) and health (GET /health) endpoints with the
* given values so the page renders deterministically.
*
* Also handles PUT requests — by default they succeed (200). Pass
* `putStatus` to simulate failures.
*/
async function mockCodeInterpreterApi(
page: Page,
opts: { enabled: boolean; healthy: boolean; putStatus?: number }
) {
const putStatus = opts.putStatus ?? 200;
await page.route(API_HEALTH_URL, async (route) => {
await route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ healthy: opts.healthy }),
});
});
await page.route(API_STATUS_URL, async (route) => {
if (route.request().method() === "PUT") {
await route.fulfill({
status: putStatus,
contentType: "application/json",
body:
putStatus >= 400
? JSON.stringify({ detail: "Server Error" })
: JSON.stringify(null),
});
} else {
await route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ enabled: opts.enabled }),
});
}
});
}
/**
* The disconnect icon button is an icon-only opal Button whose tooltip text
* is not exposed as an accessible name. Locate it by finding the first
* icon-only button (no label span) inside the card area.
*/
function getDisconnectIconButton(page: Page) {
return page
.locator("button:has(.opal-button):not(:has(.opal-button-label))")
.first();
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
test.describe("Code Interpreter Admin Page", () => {
test.beforeEach(async ({ page }) => {
await page.context().clearCookies();
await loginAs(page, "admin");
});
test("page loads with header and description", async ({ page }) => {
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
/^Code Interpreter/,
{ timeout: 10000 }
);
await expect(page.getByText("Built-in Python runtime")).toBeVisible();
});
test("shows Connected status when enabled and healthy", async ({ page }) => {
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
});
test("shows Connection Lost when enabled but unhealthy", async ({ page }) => {
await mockCodeInterpreterApi(page, { enabled: true, healthy: false });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByText("Connection Lost")).toBeVisible({
timeout: 10000,
});
});
test("shows Reconnect button when disabled", async ({ page }) => {
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
timeout: 10000,
});
await expect(page.getByText("(Disconnected)")).toBeVisible();
});
test("disconnect flow opens modal and sends PUT request", async ({
page,
}) => {
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
// Click the disconnect icon button
await getDisconnectIconButton(page).click();
// Modal should appear
await expect(page.getByText("Disconnect Code Interpreter")).toBeVisible();
await expect(
page.getByText("All running sessions connected to")
).toBeVisible();
// Click the danger Disconnect button in the modal
const modal = page.getByRole("dialog");
await modal.getByRole("button", { name: "Disconnect" }).click();
// Modal should close after successful disconnect
await expect(page.getByText("Disconnect Code Interpreter")).not.toBeVisible(
{ timeout: 5000 }
);
});
test("disconnect modal can be closed without disconnecting", async ({
page,
}) => {
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
// Open modal
await getDisconnectIconButton(page).click();
await expect(page.getByText("Disconnect Code Interpreter")).toBeVisible();
// Close modal via Cancel button
const modal = page.getByRole("dialog");
await modal.getByRole("button", { name: "Cancel" }).click();
// Modal should be gone, page still shows Connected
await expect(
page.getByText("Disconnect Code Interpreter")
).not.toBeVisible();
await expect(page.getByText("Connected")).toBeVisible();
});
test("reconnect flow sends PUT with enabled=true", async ({ page }) => {
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
timeout: 10000,
});
// Intercept the PUT and verify the payload
const putPromise = page.waitForRequest(
(req) =>
req.url().includes("/api/admin/code-interpreter") &&
req.method() === "PUT"
);
await page.getByRole("button", { name: "Reconnect" }).click();
const putReq = await putPromise;
expect(putReq.postDataJSON()).toEqual({ enabled: true });
});
test("shows Checking... while reconnect is in progress", async ({ page }) => {
// Use a single route handler that delays PUT responses
await page.route(API_HEALTH_URL, async (route) => {
await route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ healthy: false }),
});
});
await page.route(API_STATUS_URL, async (route) => {
if (route.request().method() === "PUT") {
await new Promise((resolve) => setTimeout(resolve, 2000));
await route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify(null),
});
} else {
await route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ enabled: false }),
});
}
});
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
timeout: 10000,
});
await page.getByRole("button", { name: "Reconnect" }).click();
// Should show Checking... while the request is in flight
await expect(page.getByText("Checking...")).toBeVisible({ timeout: 3000 });
});
test("shows error toast when disconnect fails", async ({ page }) => {
await mockCodeInterpreterApi(page, {
enabled: true,
healthy: true,
putStatus: 500,
});
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
// Open modal and click disconnect
await getDisconnectIconButton(page).click();
const modal = page.getByRole("dialog");
await modal.getByRole("button", { name: "Disconnect" }).click();
// Error toast should appear
await expect(
page.getByText("Failed to disconnect Code Interpreter")
).toBeVisible({ timeout: 5000 });
});
test("shows error toast when reconnect fails", async ({ page }) => {
await mockCodeInterpreterApi(page, {
enabled: false,
healthy: false,
putStatus: 500,
});
await page.goto(CODE_INTERPRETER_URL);
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
timeout: 10000,
});
await page.getByRole("button", { name: "Reconnect" }).click();
// Error toast should appear
await expect(
page.getByText("Failed to reconnect Code Interpreter")
).toBeVisible({ timeout: 5000 });
// Reconnect button should reappear (not stuck in Checking...)
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
timeout: 5000,
});
});
});

View File

@@ -46,7 +46,7 @@ test.skip("User changes password and logs in with new password", async ({
// Verify successful login
await expect(page).toHaveURL("http://localhost:3000/app");
await expect(page.getByText("Explore Agents")).toBeVisible();
await expect(page.getByText("Explore Assistants")).toBeVisible();
});
test.use({ storageState: "admin2_auth.json" });
@@ -115,5 +115,5 @@ test.skip("Admin resets own password and logs in with new password", async ({
// Verify successful login
await expect(page).toHaveURL("http://localhost:3000/app");
await expect(page.getByText("Explore Agents")).toBeVisible();
await expect(page.getByText("Explore Assistants")).toBeVisible();
});

View File

@@ -16,7 +16,7 @@ import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
// Tool-related test selectors now imported from shared utils
test.describe("Default Agent Tests", () => {
test.describe("Default Assistant Tests", () => {
let imageGenConfigId: string | null = null;
test.beforeAll(async ({ browser }) => {
@@ -69,7 +69,7 @@ test.describe("Default Agent Tests", () => {
});
test.describe("Greeting Message Display", () => {
test("should display greeting message when opening new chat with default agent", async ({
test("should display greeting message when opening new chat with default assistant", async ({
page,
}) => {
// Look for greeting message - should be one from the predefined list
@@ -95,21 +95,23 @@ test.describe("Default Agent Tests", () => {
expect(GREETING_MESSAGES).toContain(greetingAfterReload?.trim());
});
test("greeting should only appear for default agent", async ({ page }) => {
// First verify greeting appears for default agent
test("greeting should only appear for default assistant", async ({
page,
}) => {
// First verify greeting appears for default assistant
const greetingElement = await page.waitForSelector(
'[data-testid="onyx-logo"]',
{ timeout: 5000 }
);
expect(greetingElement).toBeTruthy();
// Create a custom agent to test non-default behavior
// Create a custom assistant to test non-default behavior
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
await page.locator('input[name="name"]').fill("Custom Test Agent");
await page.locator('input[name="name"]').fill("Custom Test Assistant");
await page
.locator('textarea[name="description"]')
.fill("Test Description");
@@ -118,17 +120,17 @@ test.describe("Default Agent Tests", () => {
.fill("Test Instructions");
await page.getByRole("button", { name: "Create" }).click();
// Wait for agent to be created and selected
await verifyAssistantIsChosen(page, "Custom Test Agent");
// Wait for assistant to be created and selected
await verifyAssistantIsChosen(page, "Custom Test Assistant");
// Greeting should NOT appear for custom agent
// Greeting should NOT appear for custom assistant
const customGreeting = await page.$('[data-testid="onyx-logo"]');
expect(customGreeting).toBeNull();
});
});
test.describe("Default Agent Branding", () => {
test("should display Onyx logo for default agent", async ({ page }) => {
test.describe("Default Assistant Branding", () => {
test("should display Onyx logo for default assistant", async ({ page }) => {
// Look for Onyx logo
const logoElement = await page.waitForSelector(
'[data-testid="onyx-logo"]',
@@ -136,23 +138,23 @@ test.describe("Default Agent Tests", () => {
);
expect(logoElement).toBeTruthy();
// Should NOT show agent name for default agent
// Should NOT show assistant name for default assistant
const assistantNameElement = await page.$(
'[data-testid="assistant-name-display"]'
);
expect(assistantNameElement).toBeNull();
});
test("custom agents should show name and icon instead of logo", async ({
test("custom assistants should show name and icon instead of logo", async ({
page,
}) => {
// Create a custom agent
// Create a custom assistant
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
await page.locator('input[name="name"]').fill("Custom Agent");
await page.locator('input[name="name"]').fill("Custom Assistant");
await page
.locator('textarea[name="description"]')
.fill("Test Description");
@@ -161,16 +163,16 @@ test.describe("Default Agent Tests", () => {
.fill("Test Instructions");
await page.getByRole("button", { name: "Create" }).click();
// Wait for agent to be created and selected
await verifyAssistantIsChosen(page, "Custom Agent");
// Wait for assistant to be created and selected
await verifyAssistantIsChosen(page, "Custom Assistant");
// Should show agent name and icon, not Onyx logo
// Should show assistant name and icon, not Onyx logo
const assistantNameElement = await page.waitForSelector(
'[data-testid="assistant-name-display"]',
{ timeout: 5000 }
);
const nameText = await assistantNameElement.textContent();
expect(nameText).toContain("Custom Agent");
expect(nameText).toContain("Custom Assistant");
// Onyx logo should NOT be shown
const logoElement = await page.$('[data-testid="onyx-logo"]');
@@ -179,8 +181,10 @@ test.describe("Default Agent Tests", () => {
});
test.describe("Starter Messages", () => {
test("default agent should NOT have starter messages", async ({ page }) => {
// Check that starter messages container does not exist for default agent
test("default assistant should NOT have starter messages", async ({
page,
}) => {
// Check that starter messages container does not exist for default assistant
const starterMessagesContainer = await page.$(
'[data-testid="starter-messages"]'
);
@@ -191,14 +195,18 @@ test.describe("Default Agent Tests", () => {
expect(starterButtons.length).toBe(0);
});
test("custom agents should display starter messages", async ({ page }) => {
// Create a custom agent with starter messages
test("custom assistants should display starter messages", async ({
page,
}) => {
// Create a custom assistant with starter messages
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
await page.locator('input[name="name"]').fill("Test Agent with Starters");
await page
.locator('input[name="name"]')
.fill("Test Assistant with Starters");
await page
.locator('textarea[name="description"]')
.fill("Test Description");
@@ -211,9 +219,9 @@ test.describe("Default Agent Tests", () => {
await page.getByRole("button", { name: "Create" }).click();
// Wait for assistant to be created and selected
await verifyAssistantIsChosen(page, "Test Agent with Starters");
await verifyAssistantIsChosen(page, "Test Assistant with Starters");
// Starter messages container might exist but be empty for custom agents
// Starter messages container might exist but be empty for custom assistants
const starterMessagesContainer = await page.$(
'[data-testid="starter-messages"]'
);
@@ -222,22 +230,24 @@ test.describe("Default Agent Tests", () => {
const starterButtons = await page.$$(
'[data-testid^="starter-message-"]'
);
// Custom agent without configured starter messages should have none
// Custom assistant without configured starter messages should have none
expect(starterButtons.length).toBe(0);
}
});
});
test.describe("Agent Selection", () => {
test("default agent should be selected for new chats", async ({ page }) => {
// Verify the input placeholder indicates default agent (Onyx)
test.describe("Assistant Selection", () => {
test("default assistant should be selected for new chats", async ({
page,
}) => {
// Verify the input placeholder indicates default assistant (Onyx)
await verifyDefaultAssistantIsChosen(page);
});
test("default agent should NOT appear in agent selector", async ({
test("default assistant should NOT appear in assistant selector", async ({
page,
}) => {
// Open agent selector
// Open assistant selector
await page.getByTestId("AppSidebar/more-agents").click();
// Wait for modal or assistant list to appear
@@ -246,13 +256,13 @@ test.describe("Default Agent Tests", () => {
.getByLabel("AgentsPage/new-agent-button")
.waitFor({ state: "visible", timeout: 5000 });
// Look for default agent by name - it should NOT be there
// Look for default assistant by name - it should NOT be there
const assistantElements = await page.$$('[data-testid^="assistant-"]');
const assistantTexts = await Promise.all(
assistantElements.map((el) => el.textContent())
);
// Check that the default agent is not in the list
// Check that "Assistant" (the default assistant name) is not in the list
const hasDefaultAssistant = assistantTexts.some(
(text) =>
text?.includes("Assistant") &&
@@ -265,16 +275,16 @@ test.describe("Default Agent Tests", () => {
await page.keyboard.press("Escape");
});
test("should be able to switch from default to custom agent", async ({
test("should be able to switch from default to custom assistant", async ({
page,
}) => {
// Create a custom agent
// Create a custom assistant
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
await page.locator('input[name="name"]').fill("Switch Test Agent");
await page.locator('input[name="name"]').fill("Switch Test Assistant");
await page
.locator('textarea[name="description"]')
.fill("Test Description");
@@ -283,13 +293,13 @@ test.describe("Default Agent Tests", () => {
.fill("Test Instructions");
await page.getByRole("button", { name: "Create" }).click();
// Verify switched to custom agent
await verifyAssistantIsChosen(page, "Switch Test Agent");
// Verify switched to custom assistant
await verifyAssistantIsChosen(page, "Switch Test Assistant");
// Start new chat to go back to default
await startNewChat(page);
// Should be back to default agent
// Should be back to default assistant
await verifyDefaultAssistantIsChosen(page);
});
});
@@ -369,7 +379,7 @@ test.describe("Default Agent Tests", () => {
);
}
// Enable the tools in default agent config via API
// Enable the tools in default assistant config via API
// Get current tools to find their IDs
const toolsListResp = await page.request.get(
"http://localhost:3000/api/tool"
@@ -532,7 +542,7 @@ test.describe("Default Agent Tests", () => {
});
});
test.describe("End-to-End Default Agent Flow", () => {
test.describe("End-to-End Default Assistant Flow", () => {
let imageGenConfigId: string | null = null;
test.beforeAll(async ({ browser }) => {
@@ -574,7 +584,7 @@ test.describe("End-to-End Default Agent Flow", () => {
}
});
test("complete user journey with default agent", async ({ page }) => {
test("complete user journey with default assistant", async ({ page }) => {
// Clear cookies and log in as a random user
await page.context().clearCookies();
await loginAsRandomUser(page);
@@ -601,7 +611,7 @@ test.describe("End-to-End Default Agent Flow", () => {
// Start a new chat
await startNewChat(page);
// Verify we're back to default agent with greeting
// Verify we're back to default assistant with greeting
await expect(page.locator('[data-testid="onyx-logo"]')).toBeVisible();
});
});