mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-27 12:45:51 +00:00
Compare commits
2 Commits
main
...
table-prim
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7efdd51028 | ||
|
|
8ca23de9e6 |
@@ -9,8 +9,7 @@ inputs:
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: false
|
||||
default: ""
|
||||
required: true
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
@@ -18,14 +17,6 @@ inputs:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
api-version:
|
||||
description: "Optional NIGHTLY_LLM_API_VERSION"
|
||||
required: false
|
||||
default: ""
|
||||
deployment-name:
|
||||
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
@@ -68,7 +59,6 @@ runs:
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AWS_REGION_NAME=us-west-2
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
@@ -92,8 +82,6 @@ runs:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
|
||||
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
@@ -103,6 +91,11 @@ runs:
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -117,13 +110,10 @@ runs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AWS_REGION_NAME=us-west-2 \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
|
||||
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
|
||||
44
.github/workflows/nightly-llm-provider-chat-openai.yml
vendored
Normal file
44
.github/workflows/nightly-llm-provider-chat-openai.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
@@ -1,56 +0,0 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
|
||||
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
|
||||
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
|
||||
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -89,10 +89,6 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
@@ -3,66 +3,33 @@ name: Reusable Nightly LLM Provider Chat Tests
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
type: string
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
bedrock_models:
|
||||
description: "Comma-separated models for bedrock"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
vertex_ai_models:
|
||||
description: "Comma-separated models for vertex_ai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_models:
|
||||
description: "Comma-separated models for azure"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
ollama_models:
|
||||
description: "Comma-separated models for ollama_chat"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
openrouter_models:
|
||||
description: "Comma-separated models for openrouter"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_api_base:
|
||||
description: "API base for azure provider"
|
||||
required: false
|
||||
default: ""
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
type: string
|
||||
strict:
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
@@ -71,8 +38,29 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -102,6 +90,7 @@ jobs:
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -130,6 +119,7 @@ jobs:
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -159,75 +149,11 @@ jobs:
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -241,14 +167,12 @@ jobs:
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
@@ -270,7 +194,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
@@ -322,7 +322,6 @@ def list_users(
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -366,7 +365,6 @@ def get_user(
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
@@ -723,7 +721,6 @@ def list_groups(
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -760,7 +757,6 @@ def get_group(
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
|
||||
@@ -76,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
@@ -764,7 +764,7 @@ def process_single_user_file_project_sync(
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
@@ -162,11 +163,13 @@ class ChatStateContainer:
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
func: Callable[..., None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
@@ -177,18 +180,19 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
*args: Additional positional arguments for func
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
arg1, arg2, kwarg1=value1
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
@@ -197,7 +201,9 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
# Ensure state_container is passed explicitly, removing it from kwargs if present
|
||||
kwargs_with_state = {**kwargs, "state_container": state_container}
|
||||
func(emitter, *args, **kwargs_with_state)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
|
||||
@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -541,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
|
||||
@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -203,17 +203,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for context files.
|
||||
"""Build citation mapping for project files.
|
||||
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
file_metadata: List of context file metadata
|
||||
project_file_metadata: List of project file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -221,7 +221,8 @@ def _build_context_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -241,28 +242,29 @@ def _build_context_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
context_files: ExtractedContextFiles | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
"""Build messages for project / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
"""
|
||||
if not context_files:
|
||||
if not project_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if context_files.file_texts:
|
||||
if project_files.project_file_texts:
|
||||
messages.append(
|
||||
_create_context_files_message(context_files, token_counter=None)
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
)
|
||||
if context_files.file_metadata_for_tool and token_counter:
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
context_files.file_metadata_for_tool, token_counter
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
@@ -273,7 +275,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -287,7 +289,7 @@ def construct_message_history(
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(context_files, token_counter)
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
@@ -443,17 +445,17 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if context_files and context_files.image_files:
|
||||
if project_files and project_files.project_image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + context_files.image_files,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [context_files],
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
@@ -464,14 +466,14 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add context files / file-metadata messages (inserted before last user message)
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with context images attached)
|
||||
# 5. Add last user message (with project images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
@@ -545,11 +547,11 @@ def _create_file_tool_metadata_message(
|
||||
)
|
||||
|
||||
|
||||
def _create_context_files_message(
|
||||
context_files: ExtractedContextFiles,
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -557,7 +559,7 @@ def _create_context_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
@@ -568,10 +570,10 @@ def _create_context_files_message(
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
# Use pre-calculated token count from context_files
|
||||
# Use pre-calculated token count from project_files
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=context_files.total_token_count,
|
||||
token_count=project_files.total_token_count,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
@@ -582,7 +584,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
context_files: ExtractedContextFiles,
|
||||
project_files: ExtractedProjectFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -625,9 +627,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if context_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
context_files.file_metadata
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -645,7 +647,7 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
context_files.use_as_search_filter or context_files.file_texts
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
@@ -786,7 +788,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt=custom_agent_prompt_msg,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder_msg,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -31,6 +31,13 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -125,8 +132,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -160,28 +167,20 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
|
||||
@@ -3,7 +3,6 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -34,11 +33,11 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -63,12 +62,11 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -141,12 +139,12 @@ def _collect_available_file_ids(
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
user_files = get_user_files_from_project(
|
||||
project_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in user_files:
|
||||
for uf in project_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
@@ -194,67 +192,9 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def resolve_context_user_files(
|
||||
persona: Persona,
|
||||
def _extract_project_file_texts_and_images(
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -263,12 +203,8 @@ def extract_context_files(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -277,95 +213,160 @@ def extract_context_files(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
total_token_count=total_token_count,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -380,46 +381,55 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def determine_search_params(
|
||||
persona_id: int,
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
extracted_context_files: ExtractedContextFiles,
|
||||
) -> SearchParams:
|
||||
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
|
||||
A custom persona fully supersedes the project — project files are never
|
||||
searchable and the search tool config is entirely controlled by the
|
||||
persona. The project_id filter is only set for the default persona.
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
|
||||
For the default persona inside a project:
|
||||
- Files overflow → ENABLED (vector DB scopes to these files)
|
||||
- Files fit → DISABLED (content already in prompt)
|
||||
- No files at all → DISABLED (nothing to search)
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
"""
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
)
|
||||
|
||||
|
||||
@@ -651,37 +661,26 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -690,17 +689,30 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -710,8 +722,11 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -729,7 +744,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -768,7 +783,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -864,54 +879,46 @@ def handle_stream_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
|
||||
# are just passed in immediately anyways, but the abstraction is cleaner this way.
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
lambda emitter, state_container: run_deep_research_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
),
|
||||
run_deep_research_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
lambda emitter, state_container: run_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
),
|
||||
run_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -23,6 +23,7 @@ from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
|
||||
@@ -871,56 +872,6 @@ class SharepointConnector(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
def _extract_tenant_domain_from_sites(self) -> str | None:
|
||||
"""Extract the tenant domain from configured site URLs.
|
||||
|
||||
Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
|
||||
tenant domain is the first label of the hostname.
|
||||
"""
|
||||
for site_url in self.sites:
|
||||
try:
|
||||
hostname = urlsplit(site_url.strip()).hostname
|
||||
except ValueError:
|
||||
continue
|
||||
if not hostname:
|
||||
continue
|
||||
tenant = hostname.split(".")[0]
|
||||
if tenant:
|
||||
return tenant
|
||||
logger.warning(f"No tenant domain found from {len(self.sites)} sites")
|
||||
return None
|
||||
|
||||
def _resolve_tenant_domain_from_root_site(self) -> str:
|
||||
"""Resolve tenant domain via GET /v1.0/sites/root which only requires
|
||||
Sites.Read.All (a permission the connector already needs)."""
|
||||
root_site = self.graph_client.sites.root.get().execute_query()
|
||||
hostname = root_site.site_collection.hostname
|
||||
if not hostname:
|
||||
raise ConnectorValidationError(
|
||||
"Could not determine tenant domain from root site"
|
||||
)
|
||||
tenant_domain = hostname.split(".")[0]
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from root site hostname '%s'",
|
||||
tenant_domain,
|
||||
hostname,
|
||||
)
|
||||
return tenant_domain
|
||||
|
||||
def _resolve_tenant_domain(self) -> str:
|
||||
"""Determine the tenant domain, preferring site URLs over a Graph API
|
||||
call to avoid needing extra permissions."""
|
||||
from_sites = self._extract_tenant_domain_from_sites()
|
||||
if from_sites:
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from site URLs",
|
||||
from_sites,
|
||||
)
|
||||
return from_sites
|
||||
|
||||
logger.info("No site URLs available; resolving tenant domain from root site")
|
||||
return self._resolve_tenant_domain_from_root_site()
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
@@ -1638,11 +1589,6 @@ class SharepointConnector(
|
||||
sp_private_key = credentials.get("sp_private_key")
|
||||
sp_certificate_password = credentials.get("sp_certificate_password")
|
||||
|
||||
if not sp_client_id:
|
||||
raise ConnectorValidationError("Client ID is required")
|
||||
if not sp_directory_id:
|
||||
raise ConnectorValidationError("Directory (tenant) ID is required")
|
||||
|
||||
authority_url = f"{self.authority_host}/{sp_directory_id}"
|
||||
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
@@ -1695,7 +1641,21 @@ class SharepointConnector(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
raise ConnectorValidationError("No organization found")
|
||||
|
||||
tenant_info: Organization = org[
|
||||
0
|
||||
] # Access first item directly from collection
|
||||
if not tenant_info.verified_domains:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
|
||||
sp_tenant_domain = tenant_info.verified_domains[0].name
|
||||
if not sp_tenant_domain:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
# remove the .onmicrosoft.com part
|
||||
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.db.engine.iam_auth import ssl_context
|
||||
from onyx.db.engine.sql_engine import ASYNC_DB_API
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = create_ssl_context_if_iam()
|
||||
connect_args["ssl"] = ssl_context
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = create_ssl_context_if_iam()
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import functools
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any
|
||||
@@ -49,9 +48,11 @@ def provide_iam_token(
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
@@ -256,6 +256,9 @@ def create_update_persona(
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
@@ -332,7 +335,6 @@ def update_persona_shared(
|
||||
db_session: Session,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -342,7 +344,9 @@ def update_persona_shared(
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
@@ -356,15 +360,6 @@ def update_persona_shared(
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
persona.labels.clear()
|
||||
persona.labels = labels
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -970,8 +965,6 @@ def upsert_persona(
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
|
||||
# Fetch and attach hierarchy_nodes by IDs
|
||||
hierarchy_nodes = None
|
||||
@@ -1168,6 +1161,9 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if not persona.is_public:
|
||||
persona.is_public = True
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
|
||||
@@ -58,19 +57,12 @@ def fetch_user_project_ids_for_user_files(
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch user project ids for specified user files"""
|
||||
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
|
||||
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
|
||||
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
|
||||
)
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
user_file_id_to_project_ids: dict[str, list[int]] = {
|
||||
user_file_id: [] for user_file_id in user_file_ids
|
||||
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [project.id for project in user_file.projects]
|
||||
for user_file in results
|
||||
}
|
||||
for user_file_id, project_id in rows:
|
||||
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
|
||||
|
||||
return user_file_id_to_project_ids
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
|
||||
@@ -139,7 +139,7 @@ def generate_final_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
context_files=None,
|
||||
project_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
context_files=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
context_files=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -405,7 +405,6 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
label_ids: list[int] | None = None
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -416,22 +415,14 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
label_ids=persona_share_request.label_ids,
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Per-tenant request counter metric.
|
||||
|
||||
Increments a counter on every request, labelled by tenant, so Grafana can
|
||||
answer "which tenant is generating the most traffic?"
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_requests_by_tenant = Counter(
|
||||
"onyx_api_requests_by_tenant_total",
|
||||
"Total API requests by tenant",
|
||||
["tenant_id", "method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def per_tenant_request_callback(info: Info) -> None:
|
||||
"""Increment per-tenant request counter for every request."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
_requests_by_tenant.labels(
|
||||
tenant_id=tenant_id,
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
@@ -32,7 +32,6 @@ from sqlalchemy.pool import QueuePool
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -73,7 +72,7 @@ _checkout_timeout_total = Counter(
|
||||
_connections_held = Gauge(
|
||||
"onyx_db_connections_held_by_endpoint",
|
||||
"Number of DB connections currently held, by endpoint and engine",
|
||||
["handler", "engine", "tenant_id"],
|
||||
["handler", "engine"],
|
||||
)
|
||||
|
||||
_hold_seconds = Histogram(
|
||||
@@ -164,14 +163,10 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_proxy: PoolProxiedConnection, # noqa: ARG001
|
||||
) -> None:
|
||||
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
conn_record.info["_metrics_endpoint"] = handler
|
||||
conn_record.info["_metrics_tenant_id"] = tenant_id
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic()
|
||||
_checkout_total.labels(engine=label).inc()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).inc()
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def on_checkin(
|
||||
@@ -179,12 +174,9 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_record: ConnectionPoolEntry,
|
||||
) -> None:
|
||||
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
_checkin_total.labels(engine=label).inc()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler, engine=label).observe(
|
||||
time.monotonic() - start
|
||||
@@ -207,12 +199,9 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
# Defensively clean up the held-connections gauge in case checkin
|
||||
# doesn't fire after invalidation (e.g. hard pool shutdown).
|
||||
handler = conn_record.info.pop("_metrics_endpoint", None)
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
if handler:
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
|
||||
time.monotonic() - start
|
||||
|
||||
@@ -11,11 +11,9 @@ SQLAlchemy connection pool metrics are registered separately via
|
||||
"""
|
||||
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
|
||||
from sqlalchemy.exc import TimeoutError as SATimeoutError
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -61,15 +59,6 @@ def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
# Explicitly create the default metrics (http_requests_total,
|
||||
# http_request_duration_seconds, etc.) and add them first. The library
|
||||
# skips creating defaults when ANY custom instrumentations are registered
|
||||
# via .add(), so we must include them ourselves.
|
||||
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
|
||||
if default_callback:
|
||||
instrumentator.add(default_callback)
|
||||
|
||||
instrumentator.add(slow_request_callback)
|
||||
instrumentator.add(per_tenant_request_callback)
|
||||
|
||||
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
|
||||
|
||||
@@ -120,7 +120,7 @@ def generate_intermediate_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
context_files=None,
|
||||
project_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ def run_research_agent_call(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=msg_history,
|
||||
reminder_message=reminder_message,
|
||||
context_files=None,
|
||||
project_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.7.3
|
||||
pypdf==6.6.2
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
@@ -522,46 +521,3 @@ def test_sharepoint_connector_hierarchy_nodes(
|
||||
f"Document {doc.semantic_identifier} should have "
|
||||
"parent_hierarchy_raw_node_id set"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_cert_credentials() -> dict[str, str]:
|
||||
return {
|
||||
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
|
||||
"sp_client_id": os.environ["PERM_SYNC_SHAREPOINT_CLIENT_ID"],
|
||||
"sp_private_key": os.environ["PERM_SYNC_SHAREPOINT_PRIVATE_KEY"],
|
||||
"sp_certificate_password": os.environ[
|
||||
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
|
||||
],
|
||||
"sp_directory_id": os.environ["PERM_SYNC_SHAREPOINT_DIRECTORY_ID"],
|
||||
}
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_site_urls(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain from site URLs
|
||||
without calling the /organization endpoint."""
|
||||
site_url = os.environ["SHAREPOINT_SITE"]
|
||||
connector = SharepointConnector(sites=[site_url])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
# The tenant domain should match the first label of the site URL hostname
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
expected = urlsplit(site_url).hostname.split(".")[0] # type: ignore
|
||||
assert connector.sp_tenant_domain == expected
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_root_site(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain via the root
|
||||
site endpoint when no site URLs are configured."""
|
||||
connector = SharepointConnector(sites=[])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
|
||||
@@ -1,544 +0,0 @@
|
||||
"""
|
||||
External dependency unit tests for persona file sync.
|
||||
|
||||
Validates that:
|
||||
|
||||
1. The check_for_user_file_project_sync beat task picks up UserFiles with
|
||||
needs_persona_sync=True (not just needs_project_sync).
|
||||
|
||||
2. The process_single_user_file_project_sync worker task reads persona
|
||||
associations from the DB, passes persona_ids to the document index via
|
||||
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
|
||||
|
||||
3. upsert_persona correctly marks affected UserFiles with
|
||||
needs_persona_sync=True when file associations change.
|
||||
|
||||
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
|
||||
since we only need to verify the arguments passed to update_single.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_completed_user_file(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
needs_persona_sync: bool = False,
|
||||
needs_project_sync: bool = False,
|
||||
) -> UserFile:
|
||||
"""Insert a UserFile in COMPLETED status."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
needs_project_sync=needs_project_sync,
|
||||
chunk_count=5,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_test_persona(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user_files: list[UserFile] | None = None,
|
||||
) -> Persona:
|
||||
"""Create a minimal Persona via direct model insert."""
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_files=user_files or [],
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _link_file_to_persona(
|
||||
db_session: Session, persona: Persona, user_file: UserFile
|
||||
) -> None:
|
||||
"""Create the join table row between a persona and a user file."""
|
||||
link = Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
_PATCH_QUEUE_DEPTH = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
".get_user_file_project_sync_queue_depth"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on a bound Celery task."""
|
||||
task_instance = task.run.__self__
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: check_for_user_file_project_sync picks up persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSweepIncludesPersonaSync:
|
||||
"""The beat task must pick up files needing persona sync, not just project sync."""
|
||||
|
||||
def test_persona_sync_flag_enqueues_task(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
|
||||
user = create_test_user(db_session, "persona_sweep")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) in enqueued_ids
|
||||
|
||||
def test_neither_flag_does_not_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with both flags False is not enqueued."""
|
||||
user = create_test_user(db_session, "no_sync")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) not in enqueued_ids
|
||||
|
||||
def test_both_flags_enqueues_once(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with BOTH flags True is enqueued exactly once."""
|
||||
user = create_test_user(db_session, "both_flags")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
matching_calls = [
|
||||
call
|
||||
for call in mock_app.send_task.call_args_list
|
||||
if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
|
||||
]
|
||||
assert len(matching_calls) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: process_single_user_file_project_sync passes persona_ids to index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_GET_SETTINGS = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
|
||||
)
|
||||
_PATCH_GET_INDICES = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
|
||||
)
|
||||
_PATCH_HTTPX_INIT = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
|
||||
)
|
||||
_PATCH_DISABLE_VDB = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
|
||||
)
|
||||
|
||||
|
||||
class TestSyncTaskWritesPersonaIds:
|
||||
"""The sync task reads persona associations and sends them to the index."""
|
||||
|
||||
def test_passes_persona_ids_to_update_single(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After linking a file to a persona, sync sends the persona ID."""
|
||||
user = create_test_user(db_session, "sync_persona")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
mock_doc_index.update_single.assert_called_once()
|
||||
call_args = mock_doc_index.update_single.call_args
|
||||
user_fields: VespaDocumentUserFields = call_args.kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert call_args.args[0] == str(uf.id)
|
||||
|
||||
def test_clears_persona_sync_flag(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a successful sync the needs_persona_sync flag is cleared."""
|
||||
user = create_test_user(db_session, "sync_clear")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with patch(_PATCH_DISABLE_VDB, True):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
def test_passes_both_project_and_persona_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to both a project and a persona gets both IDs."""
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserProject
|
||||
|
||||
user = create_test_user(db_session, "sync_both")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
project = UserProject(user_id=user.id, name="test-project", instructions="")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
|
||||
link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert user_fields.user_projects is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert project.id in user_fields.user_projects
|
||||
|
||||
# Both flags should be cleared
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.needs_project_sync is False
|
||||
|
||||
def test_deleted_persona_excluded_from_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
|
||||
user = create_test_user(db_session, "sync_deleted")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
persona.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id not in user_fields.personas
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: upsert_persona marks files for persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpsertPersonaMarksSyncFlag:
|
||||
"""upsert_persona must set needs_persona_sync on affected UserFiles."""
|
||||
|
||||
def test_creating_persona_with_files_marks_sync(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "upsert_create")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
|
||||
def test_updating_persona_files_marks_both_old_and_new(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When file associations change, both the removed and added files are flagged."""
|
||||
user = create_test_user(db_session, "upsert_update")
|
||||
uf_old = _create_completed_user_file(db_session, user)
|
||||
uf_new = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf_old.id],
|
||||
)
|
||||
|
||||
# Clear the flag from creation so we can observe the update
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[uf_new.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf_old)
|
||||
db_session.refresh(uf_new)
|
||||
assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
|
||||
assert uf_new.needs_persona_sync is True, "Added file should be flagged"
|
||||
|
||||
def test_removing_all_files_marks_old_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Removing all files from a persona flags the previously associated files."""
|
||||
user = create_test_user(db_session, "upsert_remove")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
@@ -1,318 +0,0 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
|
||||
Mocks the LLM tokenizer and file store since they are not relevant here.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_user_file(db_session: Session, user: User) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
chunk_count=1,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _create_project(db_session: Session, user: User) -> UserProject:
|
||||
project = UserProject(
|
||||
user_id=user.id,
|
||||
name=f"project-{uuid4().hex[:8]}",
|
||||
instructions="",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
"""Build a minimal IndexChunk whose source document ID matches the UserFile."""
|
||||
doc = Document(
|
||||
id=str(user_file.id),
|
||||
source=DocumentSource.USER_FILE,
|
||||
semantic_identifier=user_file.name,
|
||||
sections=[TextSection(text="test chunk content", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
source_document=doc,
|
||||
chunk_id=0,
|
||||
blurb="test chunk",
|
||||
content="test chunk content",
|
||||
source_links={0: ""},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.0] * 768,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_persona_gets_persona_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_persona")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_project_gets_project_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_project")
|
||||
uf = _create_user_file(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_both_gets_both_ids(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_both")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_with_no_associations_gets_empty_lists(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_empty")
|
||||
uf = _create_user_file(db_session, user)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_multiple_personas_all_appear(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to multiple personas should have all their IDs."""
|
||||
user = create_test_user(db_session, "adapter_multi")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona_a = _create_persona(db_session, user)
|
||||
persona_b = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
|
||||
db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
@@ -76,12 +76,9 @@ class ChatSessionManager:
|
||||
user_performing_action: DATestUser,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
project_id: int | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id,
|
||||
description=description,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id, description=description
|
||||
)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class ScimTokenManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": name},
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
raw_token=data["raw_token"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_active(
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -42,18 +42,6 @@ class DATestPAT(BaseModel):
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestScimToken(BaseModel):
|
||||
"""SCIM bearer token model for testing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
raw_token: str | None = None # Only present on initial creation
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestAPIKey(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
|
||||
@@ -23,8 +23,6 @@ _ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_API_VERSION = "NIGHTLY_LLM_API_VERSION"
|
||||
_ENV_DEPLOYMENT_NAME = "NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
@@ -36,8 +34,6 @@ class NightlyProviderConfig(BaseModel):
|
||||
model_names: list[str]
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
deployment_name: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
strict: bool
|
||||
|
||||
@@ -49,29 +45,17 @@ def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _parse_models_env(env_var: str) -> list[str]:
|
||||
raw_value = os.environ.get(env_var, "").strip()
|
||||
if not raw_value:
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(raw_value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_json = None
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
return [str(model).strip() for model in parsed_json if str(model).strip()]
|
||||
|
||||
return [part.strip() for part in raw_value.split(",") if part.strip()]
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _parse_models_env(_ENV_MODELS)
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
api_version = os.environ.get(_ENV_API_VERSION) or None
|
||||
deployment_name = os.environ.get(_ENV_DEPLOYMENT_NAME) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
|
||||
custom_config: dict[str, str] | None = None
|
||||
@@ -90,8 +74,6 @@ def _load_provider_config() -> NightlyProviderConfig:
|
||||
model_names=model_names,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
custom_config=custom_config,
|
||||
strict=strict,
|
||||
)
|
||||
@@ -113,15 +95,10 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=f"{_ENV_MODELS} must include at least one model",
|
||||
)
|
||||
|
||||
if config.provider != "ollama_chat" and not (
|
||||
config.api_key or config.custom_config
|
||||
):
|
||||
if config.provider != "ollama_chat" and not config.api_key:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_KEY} or {_ENV_CUSTOM_CONFIG_JSON} is required for "
|
||||
f"provider '{config.provider}'"
|
||||
),
|
||||
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "ollama_chat" and not (
|
||||
@@ -132,22 +109,6 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "azure":
|
||||
if not config.api_base:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_BASE} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
if not config.api_version:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_VERSION} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
@@ -186,8 +147,6 @@ def _create_provider_payload(
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
@@ -195,8 +154,6 @@ def _create_provider_payload(
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"deployment_name": deployment_name,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
@@ -298,8 +255,6 @@ def _create_and_test_provider_for_model(
|
||||
model_name=model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=resolved_api_base,
|
||||
api_version=config.api_version,
|
||||
deployment_name=config.deployment_name,
|
||||
custom_config=config.custom_config,
|
||||
)
|
||||
|
||||
@@ -358,21 +313,10 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
failures: list[str] = []
|
||||
for model_name in config.model_names:
|
||||
try:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
|
||||
raise
|
||||
failures.append(
|
||||
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
|
||||
)
|
||||
|
||||
if failures:
|
||||
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
"""
|
||||
Integration tests for the unified persona file context flow.
|
||||
|
||||
End-to-end tests that verify:
|
||||
1. Files can be uploaded and attached to a persona via API.
|
||||
2. The persona correctly reports its attached files.
|
||||
3. A chat session with a file-bearing persona processes without error.
|
||||
4. Precedence: custom persona files take priority over project files when
|
||||
the chat session is inside a project.
|
||||
|
||||
These tests run against a real Onyx deployment (all services running).
|
||||
File processing is asynchronous, so we poll the file status endpoint
|
||||
until files reach COMPLETED before chatting.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FILE_PROCESSING_POLL_INTERVAL = 2
|
||||
|
||||
|
||||
def _poll_file_statuses(
|
||||
user_file_ids: list[str],
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus = UserFileStatus.COMPLETED,
|
||||
timeout: int = MAX_DELAY,
|
||||
) -> None:
|
||||
"""Block until all files reach the target status or timeout expires."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/file/statuses",
|
||||
json={"file_ids": user_file_ids},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
statuses = response.json()
|
||||
if all(f["status"] == target_status.value for f in statuses):
|
||||
return
|
||||
time.sleep(FILE_PROCESSING_POLL_INTERVAL)
|
||||
raise TimeoutError(
|
||||
f"Files {user_file_ids} did not reach {target_status.value} "
|
||||
f"within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_with_files_chat_no_error(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Upload files, attach them to a persona, wait for processing,
|
||||
then send a chat message. Verify no error is returned."""
|
||||
|
||||
# Upload files (creates UserFile records)
|
||||
text_file = create_test_text_file(
|
||||
"The secret project codename is NIGHTINGALE. "
|
||||
"It was started in 2024 by the Advanced Research division."
|
||||
)
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("nightingale_brief.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
assert len(file_descriptors) == 1
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"]
|
||||
assert user_file_id is not None
|
||||
|
||||
# Wait for file processing
|
||||
_poll_file_statuses([user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with the file attached
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Nightingale Agent",
|
||||
description="Agent with secret file",
|
||||
system_prompt="You are a helpful assistant with access to uploaded files.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
# Verify persona has the file
|
||||
persona_snapshots = PersonaManager.get_one(persona.id, admin_user)
|
||||
assert len(persona_snapshots) == 1
|
||||
assert user_file_id in persona_snapshots[0].user_file_ids
|
||||
|
||||
# Chat with the persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test persona file context",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret project codename?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
assert len(response.full_message) > 0, "Response should not be empty"
|
||||
|
||||
|
||||
def test_persona_without_files_still_works(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A persona with no attached files should still chat normally."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Blank Agent",
|
||||
description="No files attached",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test blank persona",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="Hello, how are you?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
def test_persona_files_override_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a custom persona (with its own files) is used inside a project,
|
||||
the persona's files take precedence — the project's files are invisible.
|
||||
|
||||
We verify this by putting different content in project vs persona files
|
||||
and checking which content the model responds with."""
|
||||
|
||||
# Upload persona file
|
||||
persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
|
||||
persona_fds, err1 = FileManager.upload_files(
|
||||
files=[("persona_secret.txt", persona_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not err1
|
||||
persona_user_file_id = persona_fds[0]["user_file_id"]
|
||||
assert persona_user_file_id is not None
|
||||
# Create a project and upload project files
|
||||
project = ProjectManager.create(
|
||||
name="Precedence Test Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_secret.txt", b"The project's secret word is FLAMINGO."),
|
||||
]
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for both persona and project file processing
|
||||
_poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with persona file
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Override Agent",
|
||||
description="Persona with its own files",
|
||||
system_prompt="You are a helpful assistant. Answer using the files.",
|
||||
user_file_ids=[persona_user_file_id],
|
||||
)
|
||||
|
||||
# Create chat session inside the project but using the custom persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret word?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
# The persona's file should be what the model sees, not the project's
|
||||
message_lower = response.full_message.lower()
|
||||
assert "albatross" in message_lower, (
|
||||
"Response should reference the persona file's secret word (ALBATROSS), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_default_persona_in_project_uses_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the default persona (id=0) is used inside a project,
|
||||
the project's files should be used for context."""
|
||||
project = ProjectManager.create(
|
||||
name="Default Persona Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_info.txt", b"The project mascot is a PANGOLIN."),
|
||||
]
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(upload_result.user_files) == 1
|
||||
|
||||
# Wait for project file processing
|
||||
project_file_id = str(upload_result.user_files[0].id)
|
||||
_poll_file_statuses([project_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create chat session inside project using default persona (id=0)
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project mascot?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert "pangolin" in response.full_message.lower(), (
|
||||
"Response should reference the project file content (PANGOLIN), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_custom_persona_no_files_in_project_ignores_project(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A custom persona with NO files, used inside a project with files,
|
||||
should NOT see the project's files. The project is purely organizational.
|
||||
|
||||
We verify by asking about content only in the project file and checking
|
||||
the model does NOT reference it."""
|
||||
|
||||
project = ProjectManager.create(
|
||||
name="Ignored Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("project_only.txt", b"The project secret is CAPYBARA.")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for project file processing
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Custom persona with no files
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="No Files Agent",
|
||||
description="No files, project is irrelevant",
|
||||
system_prompt=(
|
||||
"You are a helpful assistant. If you do not have information "
|
||||
"to answer a question, say 'I do not have that information.'"
|
||||
),
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project secret?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
assert "capybara" not in response.full_message.lower(), (
|
||||
"Response should NOT reference the project file content (CAPYBARA) "
|
||||
"because the custom persona has no files and should not inherit "
|
||||
f"project files, but got: {response.full_message}"
|
||||
)
|
||||
@@ -1,166 +0,0 @@
|
||||
"""Integration tests for SCIM token management.
|
||||
|
||||
Covers the admin token API and SCIM bearer-token authentication:
|
||||
1. Token lifecycle: create, retrieve metadata, use for SCIM requests
|
||||
2. Token rotation: creating a new token revokes previous tokens
|
||||
3. Revoked tokens are rejected by SCIM endpoints
|
||||
4. Non-admin users cannot manage SCIM tokens
|
||||
5. SCIM requests without a token are rejected
|
||||
6. Service discovery endpoints work without authentication
|
||||
7. last_used_at is updated after a SCIM request
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
"""Create token → retrieve metadata → use for SCIM request."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Test SCIM Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert token.raw_token is not None
|
||||
assert token.raw_token.startswith("onyx_scim_")
|
||||
assert token.is_active is True
|
||||
assert "****" in token.token_display
|
||||
|
||||
# GET returns the same metadata but raw_token is None because the
|
||||
# server only reveals the raw token once at creation time (it stores
|
||||
# only the SHA-256 hash).
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
assert body["totalResults"] >= 0
|
||||
|
||||
|
||||
def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
"""Creating a new token automatically revokes the previous one."""
|
||||
first = ScimTokenManager.create(
|
||||
name="First Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
second = ScimTokenManager.create(
|
||||
name="Second Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert second.raw_token is not None
|
||||
|
||||
# Active token should now be the second one
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to create a SCIM token."""
|
||||
basic_user = UserManager.create(name="scim_basic_user")
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": "Should Fail"},
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_non_admin_cannot_get_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to retrieve SCIM token metadata."""
|
||||
basic_user = UserManager.create(name="scim_basic_user2")
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None:
|
||||
"""GET active token returns 404 when no token exists."""
|
||||
# new_admin_user depends on the reset fixture, ensuring a clean DB
|
||||
# with no active SCIM tokens.
|
||||
active = ScimTokenManager.get_active(user_performing_action=new_admin_user)
|
||||
assert active is None
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=new_admin_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_service_discovery_no_auth_required(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
def test_last_used_at_updated_after_scim_request(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""last_used_at timestamp is updated after using the token."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Last Used Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert token.raw_token is not None
|
||||
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active is not None
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active_after is not None
|
||||
assert active_after.last_used_at is not None
|
||||
@@ -1,426 +0,0 @@
|
||||
"""Tests for the unified context file extraction logic (Phase 5).
|
||||
|
||||
Covers:
|
||||
- resolve_context_user_files: precedence rule (custom persona supersedes project)
|
||||
- extract_context_files: all-or-nothing context window fit check
|
||||
- Search filter / search_usage determination in the caller
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.process_message import determine_search_params
|
||||
from onyx.chat.process_message import extract_context_files
|
||||
from onyx.chat.process_message import resolve_context_user_files
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
token_count: int = 100,
|
||||
name: str = "file.txt",
|
||||
file_id: str | None = None,
|
||||
) -> UserFile:
|
||||
file_uuid = UUID(file_id) if file_id else uuid4()
|
||||
return UserFile(
|
||||
id=file_uuid,
|
||||
file_id=str(file_uuid),
|
||||
name=name,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_persona(
|
||||
persona_id: int,
|
||||
user_files: list | None = None,
|
||||
) -> MagicMock:
|
||||
persona = MagicMock()
|
||||
persona.id = persona_id
|
||||
persona.user_files = user_files or []
|
||||
return persona
|
||||
|
||||
|
||||
def _make_in_memory_file(
|
||||
file_id: str,
|
||||
content: str = "hello world",
|
||||
file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
|
||||
filename: str = "file.txt",
|
||||
) -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=content.encode("utf-8"),
|
||||
file_type=file_type,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# resolve_context_user_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestResolveContextUserFiles:
|
||||
"""Precedence rule: custom persona fully supersedes project."""
|
||||
|
||||
def test_custom_persona_with_files_returns_persona_files(self) -> None:
|
||||
persona_files = [_make_user_file(), _make_user_file()]
|
||||
persona = _make_persona(persona_id=42, user_files=persona_files)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == persona_files
|
||||
|
||||
def test_custom_persona_without_files_returns_empty(self) -> None:
|
||||
"""Custom persona with no files should NOT fall through to project."""
|
||||
persona = _make_persona(persona_id=42, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_custom_persona_none_files_returns_empty(self) -> None:
|
||||
"""Custom persona with user_files=None should NOT fall through."""
|
||||
persona = _make_persona(persona_id=42, user_files=None)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_default_persona_in_project_returns_project_files(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
project_files = [_make_user_file(), _make_user_file()]
|
||||
mock_get_files.return_value = project_files
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
user_id = uuid4()
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
assert result == project_files
|
||||
mock_get_files.assert_called_once_with(
|
||||
project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
def test_default_persona_no_project_returns_empty(self) -> None:
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_custom_persona_without_files_ignores_project(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
"""Even with a project_id, custom persona means project is invisible."""
|
||||
persona = _make_persona(persona_id=7, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_get_files.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# extract_context_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtractContextFiles:
|
||||
"""All-or-nothing context window fit check."""
|
||||
|
||||
def test_empty_user_files_returns_empty(self) -> None:
|
||||
db_session = MagicMock()
|
||||
result = extract_context_files(
|
||||
user_files=[],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=db_session,
|
||||
)
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.uncapped_token_count is None
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=100, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=file_id, content="file content")
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == ["file content"]
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.total_token_count == 100
|
||||
assert len(result.file_metadata) == 1
|
||||
assert result.file_metadata[0].file_id == file_id
|
||||
|
||||
def test_files_overflow_context_not_loaded(self) -> None:
|
||||
"""When aggregate tokens exceed 60% of available window, nothing is loaded."""
|
||||
uf = _make_user_file(token_count=7000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.uncapped_token_count == 7000
|
||||
assert result.total_token_count == 0
|
||||
|
||||
def test_overflow_boundary_exact(self) -> None:
|
||||
"""Token count exactly at the 60% boundary should trigger overflow."""
|
||||
# Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
|
||||
uf = _make_user_file(token_count=6000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
|
||||
"""Token count just under the 60% boundary should load files."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=5999, file_id=file_id)
|
||||
mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.file_texts == ["data"]
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
|
||||
"""Multiple small files that individually fit but collectively overflow."""
|
||||
files = [_make_user_file(token_count=2500) for _ in range(3)]
|
||||
# 3 * 2500 = 7500 > 6000 threshold
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=files,
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.file_texts == []
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
|
||||
"""Reserved tokens shrink the available window."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=3000, file_id=file_id)
|
||||
# Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=5000,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=50, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert len(result.image_files) == 1
|
||||
assert result.image_files[0].file_id == file_id
|
||||
assert result.file_texts == []
|
||||
assert result.total_token_count == 50
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled, overflow produces FileToolMetadata."""
|
||||
uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Search filter + search_usage determination
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSearchFilterDetermination:
|
||||
"""Verify that determine_search_params correctly resolves
|
||||
search_project_id, search_persona_id, and search_usage based on
|
||||
the extraction result and the precedence rule.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_context(
|
||||
use_as_search_filter: bool = False,
|
||||
file_texts: list[str] | None = None,
|
||||
uncapped_token_count: int | None = None,
|
||||
) -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts or [],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=uncapped_token_count,
|
||||
)
|
||||
|
||||
def test_custom_persona_files_fit_no_filter(self) -> None:
|
||||
"""Custom persona, files fit → no search filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_files_overflow_persona_filter(self) -> None:
|
||||
"""Custom persona, files overflow → persona_id filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(use_as_search_filter=True),
|
||||
)
|
||||
assert result.search_persona_id == 42
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_no_files_no_project_leak(self) -> None:
|
||||
"""Custom persona (no files) in project → nothing leaks from project."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_files_fit_disables_search(self) -> None:
|
||||
"""Default persona, project files fit → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
|
||||
def test_default_persona_project_files_overflow_enables_search(self) -> None:
|
||||
"""Default persona, project files overflow → ENABLED + project_id filter."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
use_as_search_filter=True,
|
||||
uncapped_token_count=7000,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id == 99
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.ENABLED
|
||||
|
||||
def test_default_persona_no_project_auto(self) -> None:
|
||||
"""Default persona, no project → AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=None,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_no_files_disables_search(self) -> None:
|
||||
"""Default persona in project with no files → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -74,20 +74,20 @@ def create_tool_response(
|
||||
)
|
||||
|
||||
|
||||
def create_context_files(
|
||||
def create_project_files(
|
||||
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
|
||||
) -> ExtractedContextFiles:
|
||||
"""Helper to create ExtractedContextFiles for testing."""
|
||||
file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
file_metadata = [
|
||||
ContextFileMetadata(
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Helper to create ExtractedProjectFiles for testing."""
|
||||
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
project_file_metadata = [
|
||||
ProjectFileMetadata(
|
||||
file_id=f"file_{i}",
|
||||
filename=f"file_{i}.txt",
|
||||
file_content=f"Project file {i} content",
|
||||
)
|
||||
for i in range(num_files)
|
||||
]
|
||||
image_files = [
|
||||
project_image_files = [
|
||||
ChatLoadedFile(
|
||||
file_id=f"image_{i}",
|
||||
content=b"",
|
||||
@@ -98,13 +98,13 @@ def create_context_files(
|
||||
)
|
||||
for i in range(num_images)
|
||||
]
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=False,
|
||||
total_token_count=num_files * tokens_per_file,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=num_files * tokens_per_file,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=num_files * tokens_per_file,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,14 +121,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("How are you?", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -148,14 +148,14 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom instructions", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -167,25 +167,25 @@ class TestConstructMessageHistory:
|
||||
assert result[3] == custom_agent # Before last user message
|
||||
assert result[4] == user_msg2
|
||||
|
||||
def test_with_context_files(self) -> None:
|
||||
def test_with_project_files(self) -> None:
|
||||
"""Test that project files are inserted before the last user message."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg1 = create_message("First message", MessageType.USER, 5)
|
||||
user_msg2 = create_message("Second message", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, context_files_message, user2
|
||||
# Should have: system, user1, project_files_message, user2
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -202,14 +202,14 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Remember to cite sources", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -235,14 +235,14 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -264,18 +264,18 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, custom_agent, context_files, user2, assistant_with_tool
|
||||
# Should have: system, user1, custom_agent, project_files, user2, assistant_with_tool
|
||||
assert len(result) == 6
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -292,14 +292,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Second", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
context_files = create_context_files(num_files=0, num_images=2)
|
||||
project_files = create_project_files(num_files=0, num_images=2)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -332,14 +332,14 @@ class TestConstructMessageHistory:
|
||||
)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
context_files = create_context_files(num_files=0, num_images=1)
|
||||
project_files = create_project_files(num_files=0, num_images=1)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -366,7 +366,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg2,
|
||||
user_msg3,
|
||||
]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
# Budget only allows last 3 messages + system (10 + 20 + 20 + 20 = 70 tokens)
|
||||
result = construct_message_history(
|
||||
@@ -374,7 +374,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -395,7 +395,7 @@ class TestConstructMessageHistory:
|
||||
tool_response = create_tool_response("tc_1", "tool_response", 20)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool, tool_response]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
# Budget only allows last user message and messages after + system
|
||||
# (10 + 20 + 20 + 20 = 70 tokens)
|
||||
@@ -404,7 +404,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -432,7 +432,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg1,
|
||||
user_msg2,
|
||||
]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
# Remaining history budget is 10 tokens (30 total - 10 system - 10 last user):
|
||||
# keeps [tool_response, assistant_msg1] from history_before_last_user,
|
||||
@@ -442,7 +442,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=30,
|
||||
)
|
||||
|
||||
@@ -461,7 +461,7 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Latest question", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_with_tool, tool_response, user_msg2]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
# Remaining history budget is 25 tokens (45 total - 10 system - 10 last user):
|
||||
# keeps both assistant_with_tool and tool_response in history_before_last_user.
|
||||
@@ -470,7 +470,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=45,
|
||||
)
|
||||
|
||||
@@ -487,18 +487,18 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Reminder", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history: list[ChatMessageSimple] = []
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, custom_agent, context_files, reminder
|
||||
# Should have: system, custom_agent, project_files, reminder
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == custom_agent
|
||||
@@ -512,7 +512,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 5)
|
||||
|
||||
simple_chat_history = [assistant_msg, assistant_with_tool]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
with pytest.raises(ValueError, match="No user message found"):
|
||||
construct_message_history(
|
||||
@@ -520,7 +520,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -531,7 +531,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 50)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=100)
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=100)
|
||||
|
||||
# Total required: 50 (system) + 50 (custom) + 100 (project) + 50 (user) = 250
|
||||
# But only 200 available
|
||||
@@ -541,7 +541,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=200,
|
||||
)
|
||||
|
||||
@@ -553,7 +553,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 30)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
context_files = create_context_files()
|
||||
project_files = create_project_files()
|
||||
|
||||
# Budget: 50 tokens
|
||||
# Required: 10 (system) + 30 (user2) + 30 (assistant_with_tool) = 70 tokens
|
||||
@@ -566,7 +566,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=50,
|
||||
)
|
||||
|
||||
@@ -592,20 +592,20 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=20)
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=20)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Expected order:
|
||||
# system, user1, assistant1, user2, assistant2,
|
||||
# custom_agent, context_files, user3, assistant_with_tool, tool_response, reminder
|
||||
# custom_agent, project_files, user3, assistant_with_tool, tool_response, reminder
|
||||
assert len(result) == 11
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -622,20 +622,20 @@ class TestConstructMessageHistory:
|
||||
assert result[9] == tool_response # After last user
|
||||
assert result[10] == reminder # At the very end
|
||||
|
||||
def test_context_files_json_format(self) -> None:
|
||||
def test_project_files_json_format(self) -> None:
|
||||
"""Test that project files are formatted correctly as JSON."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg = create_message("Hello", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
project_files=project_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -692,7 +692,7 @@ class TestForgottenFileMetadata:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
context_files=create_context_files(),
|
||||
project_files=create_project_files(),
|
||||
available_tokens=available_tokens,
|
||||
token_counter=_simple_token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -106,9 +106,6 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_ENDPOINT_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_tenant_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool._connections_held"
|
||||
) as mock_gauge,
|
||||
@@ -117,14 +114,12 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
mock_labels = MagicMock()
|
||||
mock_gauge.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "/api/chat/send-message"
|
||||
mock_tenant_ctx.get.return_value = "tenant_xyz"
|
||||
listeners["checkout"](None, conn_record, None)
|
||||
|
||||
assert conn_record.info["_metrics_endpoint"] == "/api/chat/send-message"
|
||||
assert conn_record.info["_metrics_tenant_id"] == "tenant_xyz"
|
||||
assert "_metrics_checkout_time" in conn_record.info
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/chat/send-message", engine="sync", tenant_id="tenant_xyz"
|
||||
handler="/api/chat/send-message", engine="sync"
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
@@ -149,7 +144,6 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
conn_record = _make_conn_record()
|
||||
conn_record.info["_metrics_endpoint"] = "/api/search"
|
||||
conn_record.info["_metrics_tenant_id"] = "tenant_abc"
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic() - 0.5
|
||||
|
||||
with (
|
||||
@@ -168,9 +162,7 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/search", engine="sync", tenant_id="tenant_abc"
|
||||
)
|
||||
mock_gauge.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_labels.dec.assert_called_once()
|
||||
mock_hist.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_hist_labels.observe.assert_called_once()
|
||||
@@ -180,12 +172,11 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
# conn_record.info should be cleaned up
|
||||
assert "_metrics_endpoint" not in conn_record.info
|
||||
assert "_metrics_tenant_id" not in conn_record.info
|
||||
assert "_metrics_checkout_time" not in conn_record.info
|
||||
|
||||
|
||||
def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
"""Verify checkin gracefully handles missing endpoint and tenant info."""
|
||||
"""Verify checkin gracefully handles missing endpoint info."""
|
||||
engine = MagicMock()
|
||||
engine.pool = MagicMock()
|
||||
listeners: dict[str, Any] = {}
|
||||
@@ -216,9 +207,7 @@ def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="unknown", engine="sync", tenant_id="unknown"
|
||||
)
|
||||
mock_gauge.labels.assert_called_with(handler="unknown", engine="sync")
|
||||
|
||||
|
||||
# --- setup_postgres_connection_pool_metrics tests ---
|
||||
|
||||
@@ -10,7 +10,6 @@ from fastapi.testclient import TestClient
|
||||
from prometheus_client import CollectorRegistry
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -82,7 +81,7 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=["/health", "/metrics", "/openapi.json"],
|
||||
)
|
||||
assert mock_instance.add.call_count == 3
|
||||
mock_instance.add.assert_called_once()
|
||||
mock_instance.instrument.assert_called_once_with(
|
||||
app,
|
||||
latency_lowr_buckets=(
|
||||
@@ -101,56 +100,6 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
mock_instance.expose.assert_called_once_with(app)
|
||||
|
||||
|
||||
def test_per_tenant_callback_increments_with_tenant_id() -> None:
|
||||
"""Verify per-tenant callback reads tenant from contextvar and increments."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "tenant_abc"
|
||||
|
||||
info = _make_info(
|
||||
duration=0.1, method="POST", handler="/api/chat", status="200"
|
||||
)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="tenant_abc",
|
||||
method="POST",
|
||||
handler="/api/chat",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_per_tenant_callback_falls_back_to_unknown() -> None:
|
||||
"""Verify per-tenant callback uses 'unknown' when contextvar is None."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = None
|
||||
|
||||
info = _make_info(duration=0.1)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="unknown",
|
||||
method="GET",
|
||||
handler="/api/test",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_inprogress_gauge_increments_during_request() -> None:
|
||||
"""Verify the in-progress gauge goes up while a request is in flight."""
|
||||
registry = CollectorRegistry()
|
||||
|
||||
@@ -163,16 +163,3 @@ Add clear comments:
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username
|
||||
of the owner of that TODO, or an issue number for an issue referencing that
|
||||
piece of work.
|
||||
- Avoid module-level logic that runs on import, which leads to import-time side
|
||||
effects. Essentially every piece of meaningful logic should exist within some
|
||||
function that has to be explicitly invoked. Acceptable exceptions to this may
|
||||
include loading environment variables or setting up loggers.
|
||||
- If you find yourself needing something like this, you may want that logic to
|
||||
exist in a file dedicated for manual execution (contains `if __name__ ==
|
||||
"__main__":`) which should not be imported by anything else.
|
||||
- Related to the above, do not conflate Python scripts you intend to run from
|
||||
the command line (contains `if __name__ == "__main__":`) with modules you
|
||||
intend to import from elsewhere. If for some unlikely reason they have to be
|
||||
the same file, any logic specific to executing the file (including imports)
|
||||
should be contained in the `if __name__ == "__main__":` block.
|
||||
- Generally these executable files exist in `backend/scripts/`.
|
||||
|
||||
@@ -534,10 +534,9 @@ services:
|
||||
required: false
|
||||
|
||||
# Below is needed for the `docker-out-of-docker` execution mode
|
||||
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
|
||||
user: root
|
||||
volumes:
|
||||
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
|
||||
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
|
||||
# privileged: true
|
||||
|
||||
@@ -92,7 +92,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.7.3",
|
||||
"pypdf==6.6.2",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -4677,7 +4677,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.3" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.6.2" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -5924,11 +5924,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.7.3"
|
||||
version = "6.6.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/53/9b/63e767042fc852384dc71e5ff6f990ee4e1b165b1526cf3f9c23a4eebb47/pypdf-6.7.3.tar.gz", hash = "sha256:eca55c78d0ec7baa06f9288e2be5c4e8242d5cbb62c7a4b94f2716f8e50076d2", size = 5303304, upload-time = "2026-02-24T17:23:11.42Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/90/3308a9b8b46c1424181fdf3f4580d2b423c5471425799e7fc62f92d183f4/pypdf-6.7.3-py3-none-any.whl", hash = "sha256:cd25ac508f20b554a9fafd825186e3ba29591a69b78c156783c5d8a2d63a1c0a", size = 331263, upload-time = "2026-02-24T17:23:09.932Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
import "@opal/core/hoverable/styles.css";
|
||||
import React, { createContext, useContext, useState, useCallback } from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context-per-group registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Lazily-created map of group names to React contexts.
|
||||
*
|
||||
* Each group gets its own `React.Context<boolean | null>` so that a
|
||||
* `Hoverable.Item` only re-renders when its *own* group's hover state
|
||||
* changes — not when any unrelated group changes.
|
||||
*
|
||||
* The default value is `null` (no provider found), which lets
|
||||
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
|
||||
* not hovered" and throw when `group` was explicitly specified.
|
||||
*/
|
||||
const contextMap = new Map<string, React.Context<boolean | null>>();
|
||||
|
||||
function getOrCreateContext(group: string): React.Context<boolean | null> {
|
||||
let ctx = contextMap.get(group);
|
||||
if (!ctx) {
|
||||
ctx = createContext<boolean | null>(null);
|
||||
ctx.displayName = `HoverableContext(${group})`;
|
||||
contextMap.set(group, ctx);
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HoverableRootProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
children: React.ReactNode;
|
||||
group: string;
|
||||
}
|
||||
|
||||
type HoverableItemVariant = "opacity-on-hover";
|
||||
|
||||
interface HoverableItemProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
children: React.ReactNode;
|
||||
group?: string;
|
||||
variant?: HoverableItemVariant;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HoverableRoot
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Hover-tracking container for a named group.
|
||||
*
|
||||
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
|
||||
* provides the hover state via a per-group React context.
|
||||
*
|
||||
* Nesting works because each `Hoverable.Root` creates a **new** context
|
||||
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
|
||||
* reads from the inner provider, not the outer `group="a"` provider.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Hoverable.Root group="card">
|
||||
* <Card>
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Card>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*/
|
||||
function HoverableRoot({
|
||||
group,
|
||||
children,
|
||||
onMouseEnter: consumerMouseEnter,
|
||||
onMouseLeave: consumerMouseLeave,
|
||||
...props
|
||||
}: HoverableRootProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(true);
|
||||
consumerMouseEnter?.(e);
|
||||
},
|
||||
[consumerMouseEnter]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(false);
|
||||
consumerMouseLeave?.(e);
|
||||
},
|
||||
[consumerMouseLeave]
|
||||
);
|
||||
|
||||
const GroupContext = getOrCreateContext(group);
|
||||
|
||||
return (
|
||||
<GroupContext.Provider value={hovered}>
|
||||
<div {...props} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{children}
|
||||
</div>
|
||||
</GroupContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HoverableItem
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* An element whose visibility is controlled by hover state.
|
||||
*
|
||||
* **Local mode** (`group` omitted): the item handles hover on its own
|
||||
* element via CSS `:hover`. This is the core abstraction.
|
||||
*
|
||||
* **Group mode** (`group` provided): visibility is driven by a matching
|
||||
* `Hoverable.Root` ancestor's hover state via React context. If no
|
||||
* matching Root is found, an error is thrown.
|
||||
*
|
||||
* Uses data-attributes for variant styling (see `styles.css`).
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Local mode — hover on the item itself
|
||||
* <Hoverable.Item variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
*
|
||||
* // Group mode — hover on the Root reveals the item
|
||||
* <Hoverable.Root group="card">
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*
|
||||
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
|
||||
*/
|
||||
function HoverableItem({
|
||||
group,
|
||||
variant = "opacity-on-hover",
|
||||
children,
|
||||
...props
|
||||
}: HoverableItemProps) {
|
||||
const contextValue = useContext(
|
||||
group ? getOrCreateContext(group) : NOOP_CONTEXT
|
||||
);
|
||||
|
||||
if (group && contextValue === null) {
|
||||
throw new Error(
|
||||
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
|
||||
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
|
||||
);
|
||||
}
|
||||
|
||||
const isLocal = group === undefined;
|
||||
|
||||
return (
|
||||
<div
|
||||
{...props}
|
||||
className={cn("hoverable-item")}
|
||||
data-hoverable-variant={variant}
|
||||
data-hoverable-active={
|
||||
isLocal ? undefined : contextValue ? "true" : undefined
|
||||
}
|
||||
data-hoverable-local={isLocal ? "true" : undefined}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/** Stable context used when no group is specified (local mode). */
|
||||
const NOOP_CONTEXT = createContext<boolean | null>(null);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compound export
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Hoverable compound component for hover-to-reveal patterns.
|
||||
*
|
||||
* Provides two sub-components:
|
||||
*
|
||||
* - `Hoverable.Root` — A container that tracks hover state for a named group
|
||||
* and provides it via React context.
|
||||
*
|
||||
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
|
||||
* applies local CSS `:hover` for the variant effect. When `group` is
|
||||
* specified, it reads hover state from the nearest matching
|
||||
* `Hoverable.Root` — and throws if no matching Root is found.
|
||||
*
|
||||
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
|
||||
* so each group's items only respond to their own root's hover.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* import { Hoverable } from "@opal/core";
|
||||
*
|
||||
* // Group mode — hovering the card reveals the trash icon
|
||||
* <Hoverable.Root group="card">
|
||||
* <Card>
|
||||
* <span>Card content</span>
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Card>
|
||||
* </Hoverable.Root>
|
||||
*
|
||||
* // Local mode — hovering the item itself reveals it
|
||||
* <Hoverable.Item variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* ```
|
||||
*/
|
||||
const Hoverable = {
|
||||
Root: HoverableRoot,
|
||||
Item: HoverableItem,
|
||||
};
|
||||
|
||||
export {
|
||||
Hoverable,
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
};
|
||||
@@ -1,18 +0,0 @@
|
||||
/* Hoverable — item transitions */
|
||||
.hoverable-item {
|
||||
transition: opacity 200ms ease-in-out;
|
||||
}
|
||||
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Group mode — Root controls visibility via React context */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Local mode — item handles its own :hover */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
@@ -1,11 +1,3 @@
|
||||
/* Hoverable */
|
||||
export {
|
||||
Hoverable,
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
} from "@opal/core/hoverable/components";
|
||||
|
||||
/* Interactive */
|
||||
export {
|
||||
Interactive,
|
||||
|
||||
20
web/lib/opal/src/icons/handle.tsx
Normal file
20
web/lib/opal/src/icons/handle.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgHandle = ({ size = 16, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={Math.round((size * 3) / 17)}
|
||||
height={size}
|
||||
viewBox="0 0 3 17"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M0.5 0.5V16.5M2.5 0.5V16.5"
|
||||
stroke="currentColor"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgHandle;
|
||||
@@ -77,6 +77,7 @@ export { default as SvgFolderPartialOpen } from "@opal/icons/folder-partial-open
|
||||
export { default as SvgFolderPlus } from "@opal/icons/folder-plus";
|
||||
export { default as SvgGemini } from "@opal/icons/gemini";
|
||||
export { default as SvgGlobe } from "@opal/icons/globe";
|
||||
export { default as SvgHandle } from "@opal/icons/handle";
|
||||
export { default as SvgHardDrive } from "@opal/icons/hard-drive";
|
||||
export { default as SvgHashSmall } from "@opal/icons/hash-small";
|
||||
export { default as SvgHash } from "@opal/icons/hash";
|
||||
@@ -143,6 +144,7 @@ export { default as SvgSlack } from "@opal/icons/slack";
|
||||
export { default as SvgSlash } from "@opal/icons/slash";
|
||||
export { default as SvgSliders } from "@opal/icons/sliders";
|
||||
export { default as SvgSlidersSmall } from "@opal/icons/sliders-small";
|
||||
export { default as SvgSort } from "@opal/icons/sort";
|
||||
export { default as SvgSparkle } from "@opal/icons/sparkle";
|
||||
export { default as SvgStar } from "@opal/icons/star";
|
||||
export { default as SvgStep1 } from "@opal/icons/step1";
|
||||
|
||||
27
web/lib/opal/src/icons/sort.tsx
Normal file
27
web/lib/opal/src/icons/sort.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgSort = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 4.5H10M2 8H7M2 11.5H5"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M12 5V12M12 12L14 10M12 12L10 10"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgSort;
|
||||
@@ -10,7 +10,7 @@ export default function Main() {
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgMcp}
|
||||
title="MCP Actions"
|
||||
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your agents."
|
||||
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your assistants."
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -10,7 +10,7 @@ export default function Main() {
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgActions}
|
||||
title="OpenAPI Actions"
|
||||
description="Connect OpenAPI servers to add custom actions and tools for your agents."
|
||||
description="Connect OpenAPI servers to add custom actions and tools for your assistants."
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -170,7 +170,7 @@ export function PersonasTable({
|
||||
{deleteModalOpen && personaToDelete && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgAlertCircle}
|
||||
title="Delete Agent"
|
||||
title="Delete Assistant"
|
||||
onClose={closeDeleteModal}
|
||||
submit={<Button onClick={handleDeletePersona}>Delete</Button>}
|
||||
>
|
||||
@@ -183,15 +183,15 @@ export function PersonasTable({
|
||||
const isDefault = personaToToggleDefault.is_default_persona;
|
||||
|
||||
const title = isDefault
|
||||
? "Remove Featured Agent"
|
||||
: "Set Featured Agent";
|
||||
? "Remove Featured Assistant"
|
||||
: "Set Featured Assistant";
|
||||
const buttonText = isDefault ? "Remove Feature" : "Set as Featured";
|
||||
const text = isDefault
|
||||
? `Are you sure you want to remove the featured status of ${personaToToggleDefault.name}?`
|
||||
: `Are you sure you want to set the featured status of ${personaToToggleDefault.name}?`;
|
||||
const additionalText = isDefault
|
||||
? `Removing "${personaToToggleDefault.name}" as a featured agent will not affect its visibility or accessibility.`
|
||||
: `Setting "${personaToToggleDefault.name}" as a featured agent will make it public and visible to all users. This action cannot be undone.`;
|
||||
? `Removing "${personaToToggleDefault.name}" as a featured assistant will not affect its visibility or accessibility.`
|
||||
: `Setting "${personaToToggleDefault.name}" as a featured assistant will make it public and visible to all users. This action cannot be undone.`;
|
||||
|
||||
return (
|
||||
<ConfirmationModalLayout
|
||||
@@ -217,7 +217,7 @@ export function PersonasTable({
|
||||
"Name",
|
||||
"Description",
|
||||
"Type",
|
||||
"Featured Agent",
|
||||
"Featured Assistant",
|
||||
"Is Visible",
|
||||
"Delete",
|
||||
]}
|
||||
|
||||
@@ -47,8 +47,8 @@ function MainContent({
|
||||
return (
|
||||
<div>
|
||||
<Text className="mb-2">
|
||||
Agents are a way to build custom search/question-answering experiences
|
||||
for different use cases.
|
||||
Assistants are a way to build custom search/question-answering
|
||||
experiences for different use cases.
|
||||
</Text>
|
||||
<Text className="mt-2">They allow you to customize:</Text>
|
||||
<div className="text-sm">
|
||||
@@ -63,21 +63,21 @@ function MainContent({
|
||||
<div>
|
||||
<Separator />
|
||||
|
||||
<Title>Create an Agent</Title>
|
||||
<Title>Create an Assistant</Title>
|
||||
<CreateButton href="/app/agents/create?admin=true">
|
||||
New Agent
|
||||
New Assistant
|
||||
</CreateButton>
|
||||
|
||||
<Separator />
|
||||
|
||||
<Title>Existing Agents</Title>
|
||||
<Title>Existing Assistants</Title>
|
||||
{totalItems > 0 ? (
|
||||
<>
|
||||
<SubLabel>
|
||||
Agents will be displayed as options on the Chat / Search
|
||||
interfaces in the order they are displayed below. Agents marked as
|
||||
hidden will not be displayed. Editable agents are shown at the
|
||||
top.
|
||||
Assistants will be displayed as options on the Chat / Search
|
||||
interfaces in the order they are displayed below. Assistants
|
||||
marked as hidden will not be displayed. Editable assistants are
|
||||
shown at the top.
|
||||
</SubLabel>
|
||||
<PersonasTable
|
||||
personas={customPersonas}
|
||||
@@ -96,21 +96,21 @@ function MainContent({
|
||||
) : (
|
||||
<div className="mt-6 p-8 border border-border rounded-lg bg-background-weak text-center">
|
||||
<Text className="text-lg font-medium mb-2">
|
||||
No custom agents yet
|
||||
No custom assistants yet
|
||||
</Text>
|
||||
<Text className="text-subtle mb-3">
|
||||
Create your first agent to:
|
||||
Create your first assistant to:
|
||||
</Text>
|
||||
<ul className="text-subtle text-sm list-disc text-left inline-block mb-3">
|
||||
<li>Build department-specific knowledge bases</li>
|
||||
<li>Create specialized research agents</li>
|
||||
<li>Create specialized research assistants</li>
|
||||
<li>Set up compliance and policy advisors</li>
|
||||
</ul>
|
||||
<Text className="text-subtle text-sm mb-4">
|
||||
...and so much more!
|
||||
</Text>
|
||||
<CreateButton href="/app/agents/create?admin=true">
|
||||
Create Your First Agent
|
||||
Create Your First Assistant
|
||||
</CreateButton>
|
||||
</div>
|
||||
)}
|
||||
@@ -128,13 +128,13 @@ export default function Page() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgOnyxOctagon} title="Agents" />
|
||||
<AdminPageTitle icon={SvgOnyxOctagon} title="Assistants" />
|
||||
|
||||
{isLoading && <ThreeDotsLoader />}
|
||||
|
||||
{error && (
|
||||
<ErrorCallout
|
||||
errorTitle="Failed to load agents"
|
||||
errorTitle="Failed to load assistants"
|
||||
errorMsg={
|
||||
error?.info?.message ||
|
||||
error?.info?.detail ||
|
||||
|
||||
@@ -156,7 +156,7 @@ export const SlackChannelConfigCreationForm = ({
|
||||
is: "assistant",
|
||||
then: (schema) =>
|
||||
schema.required(
|
||||
"An agent is required when using the 'Agent' knowledge source"
|
||||
"A persona is required when using the'Assistant' knowledge source"
|
||||
),
|
||||
}),
|
||||
standard_answer_categories: Yup.array(),
|
||||
|
||||
@@ -224,14 +224,14 @@ export function SlackChannelConfigFormFields({
|
||||
<RadioGroupItemField
|
||||
value="assistant"
|
||||
id="assistant"
|
||||
label="Search Agent"
|
||||
label="Search Assistant"
|
||||
sublabel="Control both the documents and the prompt to use for answering questions"
|
||||
/>
|
||||
<RadioGroupItemField
|
||||
value="non_search_assistant"
|
||||
id="non_search_assistant"
|
||||
label="Non-Search Agent"
|
||||
sublabel="Chat with an agent that does not use documents"
|
||||
label="Non-Search Assistant"
|
||||
sublabel="Chat with an assistant that does not use documents"
|
||||
/>
|
||||
</RadioGroup>
|
||||
</div>
|
||||
@@ -327,15 +327,15 @@ export function SlackChannelConfigFormFields({
|
||||
<div className="mt-4">
|
||||
<SubLabel>
|
||||
<>
|
||||
Select the search-enabled agent OnyxBot will use while answering
|
||||
questions in Slack.
|
||||
Select the search-enabled assistant OnyxBot will use while
|
||||
answering questions in Slack.
|
||||
{syncEnabledAssistants.length > 0 && (
|
||||
<>
|
||||
<br />
|
||||
<span className="text-sm text-text-dark/80">
|
||||
Note: Some of your agents have auto-synced connectors in
|
||||
their document sets. You cannot select these agents as
|
||||
they will not be able to answer questions in Slack.{" "}
|
||||
Note: Some of your assistants have auto-synced connectors
|
||||
in their document sets. You cannot select these assistants
|
||||
as they will not be able to answer questions in Slack.{" "}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
@@ -349,7 +349,7 @@ export function SlackChannelConfigFormFields({
|
||||
{viewSyncEnabledAssistants
|
||||
? "Hide un-selectable "
|
||||
: "View all "}
|
||||
agents
|
||||
assistants
|
||||
</button>
|
||||
</span>
|
||||
</>
|
||||
@@ -367,7 +367,7 @@ export function SlackChannelConfigFormFields({
|
||||
{viewSyncEnabledAssistants && syncEnabledAssistants.length > 0 && (
|
||||
<div className="mt-4">
|
||||
<p className="text-sm text-text-dark/80">
|
||||
Un-selectable agents:
|
||||
Un-selectable assistants:
|
||||
</p>
|
||||
<div className="mb-3 mt-2 flex gap-2 flex-wrap text-sm">
|
||||
{syncEnabledAssistants.map(
|
||||
@@ -394,15 +394,15 @@ export function SlackChannelConfigFormFields({
|
||||
<div className="mt-4">
|
||||
<SubLabel>
|
||||
<>
|
||||
Select the non-search agent OnyxBot will use while answering
|
||||
Select the non-search assistant OnyxBot will use while answering
|
||||
questions in Slack.
|
||||
{syncEnabledAssistants.length > 0 && (
|
||||
<>
|
||||
<br />
|
||||
<span className="text-sm text-text-dark/80">
|
||||
Note: Some of your agents have auto-synced connectors in
|
||||
their document sets. You cannot select these agents as
|
||||
they will not be able to answer questions in Slack.{" "}
|
||||
Note: Some of your assistants have auto-synced connectors
|
||||
in their document sets. You cannot select these assistants
|
||||
as they will not be able to answer questions in Slack.{" "}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
@@ -416,7 +416,7 @@ export function SlackChannelConfigFormFields({
|
||||
{viewSyncEnabledAssistants
|
||||
? "Hide un-selectable "
|
||||
: "View all "}
|
||||
agents
|
||||
assistants
|
||||
</button>
|
||||
</span>
|
||||
</>
|
||||
@@ -524,7 +524,7 @@ export function SlackChannelConfigFormFields({
|
||||
name="is_ephemeral"
|
||||
label="Respond to user in a private (ephemeral) message"
|
||||
tooltip="If set, OnyxBot will respond only to the user in a private (ephemeral) message. If you also
|
||||
chose 'Search' Agent above, selecting this option will make documents that are private to the user
|
||||
chose 'Search' Assistant above, selecting this option will make documents that are private to the user
|
||||
available for their queries."
|
||||
/>
|
||||
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
|
||||
|
||||
export default function Page() {
|
||||
return <CodeInterpreterPage />;
|
||||
}
|
||||
@@ -39,10 +39,10 @@ export function AdvancedOptions({
|
||||
agents={agents}
|
||||
isLoading={agentsLoading}
|
||||
error={agentsError}
|
||||
label="Agent Whitelist"
|
||||
subtext="Restrict this provider to specific agents."
|
||||
label="Assistant Whitelist"
|
||||
subtext="Restrict this provider to specific assistants."
|
||||
disabled={formikProps.values.is_public}
|
||||
disabledMessage="This LLM Provider is public and available to all agents."
|
||||
disabledMessage="This LLM Provider is public and available to all assistants."
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -299,11 +299,11 @@ export default function Page({ params }: Props) {
|
||||
});
|
||||
refreshGuild();
|
||||
toast.success(
|
||||
personaId ? "Default agent updated" : "Default agent cleared"
|
||||
personaId ? "Default assistant updated" : "Default assistant cleared"
|
||||
);
|
||||
} catch (err) {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to update agent"
|
||||
err instanceof Error ? err.message : "Failed to update assistant"
|
||||
);
|
||||
} finally {
|
||||
setIsUpdating(false);
|
||||
@@ -355,7 +355,7 @@ export default function Page({ params }: Props) {
|
||||
<InputSelect.Trigger placeholder="Select agent" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="default">
|
||||
Default Agent
|
||||
Default Assistant
|
||||
</InputSelect.Item>
|
||||
{personas.map((persona) => (
|
||||
<InputSelect.Item
|
||||
|
||||
@@ -47,8 +47,6 @@ export interface RendererResult {
|
||||
|
||||
// Whether this renderer supports collapsible mode (collapse button shown only when true)
|
||||
supportsCollapsible?: boolean;
|
||||
/** Whether the step should remain collapsible even in single-step timelines */
|
||||
alwaysCollapsible?: boolean;
|
||||
/** Whether the result should be wrapped by timeline UI or rendered as-is */
|
||||
timelineLayout?: TimelineLayout;
|
||||
}
|
||||
|
||||
@@ -50,9 +50,7 @@ export function TimelineStepComposer({
|
||||
header={result.status}
|
||||
isExpanded={result.isExpanded}
|
||||
onToggle={result.onToggle}
|
||||
collapsible={
|
||||
collapsible && (!isSingleStep || !!result.alwaysCollapsible)
|
||||
}
|
||||
collapsible={collapsible && !isSingleStep}
|
||||
supportsCollapsible={result.supportsCollapsible}
|
||||
isLastStep={index === results.length - 1 && isLastStep}
|
||||
isFirstStep={index === 0 && isFirstStep}
|
||||
|
||||
@@ -54,7 +54,7 @@ export function TimelineRow({
|
||||
isHover={isHover}
|
||||
/>
|
||||
)}
|
||||
<div className="flex-1 min-w-0">{children}</div>
|
||||
<div className="flex-1">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
{stdout && (
|
||||
<div className="rounded-md bg-background-neutral-02 p-3">
|
||||
<div className="text-xs font-semibold mb-1 text-text-03">Output:</div>
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-text-01 overflow-x-auto">
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-text-01">
|
||||
{stdout}
|
||||
</pre>
|
||||
</div>
|
||||
@@ -150,7 +150,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
<div className="text-xs font-semibold mb-1 text-status-error-05">
|
||||
Error:
|
||||
</div>
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-status-error-05 overflow-x-auto">
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-status-error-05">
|
||||
{stderr}
|
||||
</pre>
|
||||
</div>
|
||||
@@ -181,7 +181,6 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
status,
|
||||
content,
|
||||
supportsCollapsible: true,
|
||||
alwaysCollapsible: true,
|
||||
},
|
||||
]);
|
||||
}
|
||||
@@ -192,7 +191,6 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
icon: SvgTerminal,
|
||||
status,
|
||||
supportsCollapsible: true,
|
||||
alwaysCollapsible: true,
|
||||
content: (
|
||||
<FadingEdgeContainer
|
||||
direction="bottom"
|
||||
|
||||
@@ -427,7 +427,7 @@ export const GroupDisplay = ({
|
||||
|
||||
<Separator />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Agents</h2>
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Assistants</h2>
|
||||
|
||||
<div>
|
||||
{userGroup.document_sets.length > 0 ? (
|
||||
@@ -445,7 +445,7 @@ export const GroupDisplay = ({
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<Text>No Agents in this group...</Text>
|
||||
<Text>No Assistants in this group...</Text>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -152,14 +152,14 @@ export function PersonaMessagesChart({
|
||||
} else if (selectedPersonaId === undefined) {
|
||||
content = (
|
||||
<div className="h-80 text-text-500 flex flex-col">
|
||||
<p className="m-auto">Select an agent to view analytics</p>
|
||||
<p className="m-auto">Select an assistant to view analytics</p>
|
||||
</div>
|
||||
);
|
||||
} else if (!personaMessagesData?.length) {
|
||||
content = (
|
||||
<div className="h-80 text-text-500 flex flex-col">
|
||||
<p className="m-auto">
|
||||
No data found for selected agent in the specified time range
|
||||
No data found for selected assistant in the specified time range
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
@@ -178,9 +178,11 @@ export function PersonaMessagesChart({
|
||||
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
<Title>Agent Analytics</Title>
|
||||
<Title>Assistant Analytics</Title>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text>Messages and unique users per day for the selected agent</Text>
|
||||
<Text>
|
||||
Messages and unique users per day for the selected assistant
|
||||
</Text>
|
||||
<div className="flex items-center gap-4">
|
||||
<Select
|
||||
value={selectedPersonaId?.toString() ?? ""}
|
||||
@@ -189,14 +191,14 @@ export function PersonaMessagesChart({
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="flex w-full max-w-xs">
|
||||
<SelectValue placeholder="Select an agent to display" />
|
||||
<SelectValue placeholder="Select an assistant to display" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<div className="flex items-center px-2 pb-2 sticky top-0 bg-background border-b">
|
||||
<Search className="h-4 w-4 mr-2 shrink-0 opacity-50" />
|
||||
<input
|
||||
className="flex h-8 w-full rounded-sm bg-transparent py-3 text-sm outline-none placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50"
|
||||
placeholder="Search agents..."
|
||||
placeholder="Search assistants..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
|
||||
@@ -146,7 +146,7 @@ export function AssistantStats({ assistantId }: { assistantId: number }) {
|
||||
return (
|
||||
<Card className="w-full">
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<p className="text-base font-normal text-2xl">Agent Analytics</p>
|
||||
<p className="text-base font-normal text-2xl">Assistant Analytics</p>
|
||||
<AdminDateRangeSelector
|
||||
value={dateRange}
|
||||
onValueChange={setDateRange}
|
||||
|
||||
@@ -72,7 +72,6 @@ export function ClientLayout({
|
||||
enableEnterpriseSS={enableEnterprise}
|
||||
/>
|
||||
<div
|
||||
data-main-container
|
||||
className={cn(
|
||||
"flex flex-1 flex-col min-w-0 min-h-0 overflow-y-auto",
|
||||
!hasOwnLayout && "py-10 px-4 md:px-12"
|
||||
|
||||
@@ -41,14 +41,8 @@ export default function AccessRestricted() {
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const { data: license } = useLicense();
|
||||
|
||||
const hadPreviousLicense = license?.has_license === true;
|
||||
const showRenewalMessage = NEXT_PUBLIC_CLOUD_ENABLED || hadPreviousLicense;
|
||||
|
||||
const initialModalMessage = showRenewalMessage
|
||||
? NEXT_PUBLIC_CLOUD_ENABLED
|
||||
? "Your access to Onyx has been temporarily suspended due to a lapse in your subscription."
|
||||
: "Your access to Onyx has been temporarily suspended due to a lapse in your license."
|
||||
: "An Enterprise license is required to use Onyx. Your data is protected and will be available once a license is activated.";
|
||||
// Distinguish between "never had a license" vs "license lapsed"
|
||||
const hasLicenseLapsed = license?.has_license === true;
|
||||
|
||||
const handleResubscribe = async () => {
|
||||
setIsLoading(true);
|
||||
@@ -78,7 +72,11 @@ export default function AccessRestricted() {
|
||||
<SvgLock className="stroke-status-error-05 w-[1.5rem] h-[1.5rem]" />
|
||||
</div>
|
||||
|
||||
<Text text03>{initialModalMessage}</Text>
|
||||
<Text text03>
|
||||
{hasLicenseLapsed
|
||||
? "Your access to Onyx has been temporarily suspended due to a lapse in your subscription."
|
||||
: "An Enterprise license is required to use Onyx. Your data is protected and will be available once a license is activated."}
|
||||
</Text>
|
||||
|
||||
{NEXT_PUBLIC_CLOUD_ENABLED ? (
|
||||
<>
|
||||
@@ -113,7 +111,7 @@ export default function AccessRestricted() {
|
||||
) : (
|
||||
<>
|
||||
<Text text03>
|
||||
{hadPreviousLicense
|
||||
{hasLicenseLapsed
|
||||
? "To reinstate your access and continue using Onyx, please contact your system administrator to renew your license."
|
||||
: "To get started, please contact your system administrator to obtain an Enterprise license."}
|
||||
</Text>
|
||||
@@ -123,8 +121,8 @@ export default function AccessRestricted() {
|
||||
<Link className={linkClassName} href="/admin/billing">
|
||||
Admin Billing
|
||||
</Link>{" "}
|
||||
page to {hadPreviousLicense ? "renew" : "activate"} your license,
|
||||
sign up through Stripe or reach out to{" "}
|
||||
page to {hasLicenseLapsed ? "renew" : "activate"} your license, sign
|
||||
up through Stripe or reach out to{" "}
|
||||
<a className={linkClassName} href="mailto:support@onyx.app">
|
||||
support@onyx.app
|
||||
</a>
|
||||
|
||||
@@ -12,17 +12,17 @@ export default function NoAssistantModal() {
|
||||
return (
|
||||
<Modal open>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header icon={SvgUser} title="No Agent Available" />
|
||||
<Modal.Header icon={SvgUser} title="No Assistant Available" />
|
||||
<Modal.Body>
|
||||
<Text as="p">
|
||||
You currently have no agent configured. To use this feature, you
|
||||
You currently have no assistant configured. To use this feature, you
|
||||
need to take action.
|
||||
</Text>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<Text as="p">
|
||||
As an administrator, you can create a new agent by visiting the
|
||||
admin panel.
|
||||
As an administrator, you can create a new assistant by visiting
|
||||
the admin panel.
|
||||
</Text>
|
||||
<Button className="w-full" href="/admin/assistants">
|
||||
Go to Admin Panel
|
||||
@@ -30,7 +30,8 @@ export default function NoAssistantModal() {
|
||||
</>
|
||||
) : (
|
||||
<Text as="p">
|
||||
Please contact your administrator to configure an agent for you.
|
||||
Please contact your administrator to configure an assistant for
|
||||
you.
|
||||
</Text>
|
||||
)}
|
||||
</Modal.Body>
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
const HEALTH_ENDPOINT = "/api/admin/code-interpreter/health";
|
||||
const STATUS_ENDPOINT = "/api/admin/code-interpreter";
|
||||
|
||||
interface CodeInterpreterHealth {
|
||||
healthy: boolean;
|
||||
}
|
||||
|
||||
interface CodeInterpreterStatus {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export default function useCodeInterpreter() {
|
||||
const {
|
||||
data: healthData,
|
||||
error: healthError,
|
||||
isLoading: isHealthLoading,
|
||||
mutate: refetchHealth,
|
||||
} = useSWR<CodeInterpreterHealth>(HEALTH_ENDPOINT, errorHandlingFetcher, {
|
||||
refreshInterval: 30000,
|
||||
});
|
||||
|
||||
const {
|
||||
data: statusData,
|
||||
error: statusError,
|
||||
isLoading: isStatusLoading,
|
||||
mutate: refetchStatus,
|
||||
} = useSWR<CodeInterpreterStatus>(STATUS_ENDPOINT, errorHandlingFetcher);
|
||||
|
||||
function refetch() {
|
||||
refetchHealth();
|
||||
refetchStatus();
|
||||
}
|
||||
|
||||
return {
|
||||
isHealthy: healthData?.healthy ?? false,
|
||||
isEnabled: statusData?.enabled ?? false,
|
||||
isLoading: isHealthLoading || isStatusLoading,
|
||||
error: healthError || statusError,
|
||||
refetch,
|
||||
};
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
const UPDATE_ENDPOINT = "/api/admin/code-interpreter";
|
||||
|
||||
interface CodeInterpreterUpdateRequest {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export async function updateCodeInterpreter(
|
||||
request: CodeInterpreterUpdateRequest
|
||||
): Promise<Response> {
|
||||
return fetch(UPDATE_ENDPOINT, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
}
|
||||
@@ -131,8 +131,7 @@ export async function updateAgentSharedStatus(
|
||||
userIds: string[],
|
||||
groupIds: number[],
|
||||
isPublic: boolean | undefined,
|
||||
isPaidEnterpriseFeaturesEnabled: boolean,
|
||||
labelIds?: number[]
|
||||
isPaidEnterpriseFeaturesEnabled: boolean
|
||||
): Promise<null | string> {
|
||||
// MIT versions should not send group_ids - warn if caller provided non-empty groups
|
||||
if (!isPaidEnterpriseFeaturesEnabled && groupIds.length > 0) {
|
||||
@@ -153,7 +152,6 @@ export async function updateAgentSharedStatus(
|
||||
// Only include group_ids for enterprise versions
|
||||
group_ids: isPaidEnterpriseFeaturesEnabled ? groupIds : undefined,
|
||||
is_public: isPublic,
|
||||
label_ids: labelIds,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -168,63 +166,3 @@ export async function updateAgentSharedStatus(
|
||||
return "Network error. Please check your connection and try again.";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the labels assigned to an agent via the share endpoint.
|
||||
*
|
||||
* @param agentId - The ID of the agent to update
|
||||
* @param labelIds - Array of label IDs to assign to the agent
|
||||
* @returns null on success, or an error message string on failure
|
||||
*/
|
||||
export async function updateAgentLabels(
|
||||
agentId: number,
|
||||
labelIds: number[]
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/persona/${agentId}/share`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ label_ids: labelIds }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const errorMessage = (await response.json()).detail || "Unknown error";
|
||||
return errorMessage;
|
||||
} catch (error) {
|
||||
console.error("updateAgentLabels: Network error", error);
|
||||
return "Network error. Please check your connection and try again.";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the featured (default) status of an agent.
|
||||
*
|
||||
* @param agentId - The ID of the agent to update
|
||||
* @param isFeatured - Whether the agent should be featured
|
||||
* @returns null on success, or an error message string on failure
|
||||
*/
|
||||
export async function updateAgentFeaturedStatus(
|
||||
agentId: number,
|
||||
isFeatured: boolean
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/admin/persona/${agentId}/default`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ is_default_persona: isFeatured }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const errorMessage = (await response.json()).detail || "Unknown error";
|
||||
return errorMessage;
|
||||
} catch (error) {
|
||||
console.error("updateAgentFeaturedStatus: Network error", error);
|
||||
return "Network error. Please check your connection and try again.";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,27 +257,19 @@ export const useLabels = () => {
|
||||
return mutate("/api/persona/labels");
|
||||
};
|
||||
|
||||
const createLabel = async (name: string): Promise<PersonaLabel | null> => {
|
||||
const createLabel = async (name: string) => {
|
||||
const response = await fetch("/api/persona/labels", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ name }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
if (response.ok) {
|
||||
const newLabel = await response.json();
|
||||
mutate("/api/persona/labels", [...(labels || []), newLabel], false);
|
||||
}
|
||||
|
||||
const newLabel: PersonaLabel = await response.json();
|
||||
mutate(
|
||||
"/api/persona/labels",
|
||||
(currentLabels: PersonaLabel[] | undefined) => [
|
||||
...(currentLabels || []),
|
||||
newLabel,
|
||||
],
|
||||
false
|
||||
);
|
||||
return newLabel;
|
||||
return response;
|
||||
};
|
||||
|
||||
const updateLabel = async (id: number, name: string) => {
|
||||
|
||||
@@ -1,51 +1,29 @@
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
export interface ChipProps {
|
||||
children?: string;
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
onRemove?: () => void;
|
||||
smallLabel?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* A simple chip/tag component for displaying metadata.
|
||||
* Supports an optional remove button via the `onRemove` prop.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Chip>Tag Name</Chip>
|
||||
* <Chip icon={SvgUser}>John Doe</Chip>
|
||||
* <Chip onRemove={() => removeTag(id)}>Removable</Chip>
|
||||
* ```
|
||||
*/
|
||||
export default function Chip({
|
||||
children,
|
||||
icon: Icon,
|
||||
onRemove,
|
||||
smallLabel = true,
|
||||
}: ChipProps) {
|
||||
export default function Chip({ children, icon: Icon }: ChipProps) {
|
||||
return (
|
||||
<div className="flex items-center gap-1 px-1.5 py-0.5 rounded-08 bg-background-tint-02">
|
||||
{Icon && <Icon size={12} className="text-text-03" />}
|
||||
{children && (
|
||||
<Text figureSmallLabel={smallLabel} text03>
|
||||
<Text figureSmallLabel text03>
|
||||
{children}
|
||||
</Text>
|
||||
)}
|
||||
{onRemove && (
|
||||
<Button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove();
|
||||
}}
|
||||
prominence="tertiary"
|
||||
icon={SvgX}
|
||||
size="xs"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -39,9 +39,9 @@ const useTabsContext = () => {
|
||||
*
|
||||
* Contained (default):
|
||||
* ┌─────────────────────────────────────────────────┐
|
||||
* │ ┌──────────┐ ╔══════════╗ ┌──────────┐ │
|
||||
* │ │ Tab 1 │ ║ Tab 2 ║ │ Tab 3 │ │ ← gray background
|
||||
* │ └──────────┘ ╚══════════╝ └──────────┘ │
|
||||
* │ ┌──────────┐ ╔══════════╗ ┌──────────┐ │
|
||||
* │ │ Tab 1 │ ║ Tab 2 ║ │ Tab 3 │ │ ← gray background
|
||||
* │ └──────────┘ ╚══════════╝ └──────────┘ │
|
||||
* └─────────────────────────────────────────────────┘
|
||||
* ↑ active tab (white bg, shadow)
|
||||
*
|
||||
@@ -49,7 +49,7 @@ const useTabsContext = () => {
|
||||
* Tab 1 Tab 2 Tab 3 [Action]
|
||||
* ╔═════╗
|
||||
* ║ ║ ↑ optional rightContent
|
||||
* ─────────────╨═════╨─────────────────────────────
|
||||
* ────────────╨═════╨─────────────────────────────
|
||||
* ↑ sliding indicator under active tab
|
||||
*
|
||||
* @example
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Chip from "@/refresh-components/Chip";
|
||||
import {
|
||||
innerClasses,
|
||||
textClasses,
|
||||
Variants,
|
||||
wrapperClasses,
|
||||
} from "@/refresh-components/inputs/styles";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
export interface ChipItem {
|
||||
id: string;
|
||||
label: string;
|
||||
}
|
||||
|
||||
export interface InputChipFieldProps {
|
||||
chips: ChipItem[];
|
||||
onRemoveChip: (id: string) => void;
|
||||
onAdd: (value: string) => void;
|
||||
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
|
||||
placeholder?: string;
|
||||
disabled?: boolean;
|
||||
variant?: Variants;
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A tag/chip input field that renders chips inline alongside a text input.
|
||||
*
|
||||
* Pressing Enter adds a chip via `onAdd`. Pressing Backspace on an empty
|
||||
* input removes the last chip. Each chip has a remove button.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <InputChipField
|
||||
* chips={[{ id: "1", label: "Search" }]}
|
||||
* onRemoveChip={(id) => remove(id)}
|
||||
* onAdd={(value) => add(value)}
|
||||
* value={inputValue}
|
||||
* onChange={setInputValue}
|
||||
* placeholder="Add labels..."
|
||||
* icon={SvgTag}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
function InputChipField({
|
||||
chips,
|
||||
onRemoveChip,
|
||||
onAdd,
|
||||
value,
|
||||
onChange,
|
||||
placeholder,
|
||||
disabled = false,
|
||||
variant = "primary",
|
||||
icon: Icon,
|
||||
className,
|
||||
}: InputChipFieldProps) {
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
function handleKeyDown(e: React.KeyboardEvent<HTMLInputElement>) {
|
||||
if (disabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
const trimmed = value.trim();
|
||||
if (trimmed) {
|
||||
onAdd(trimmed);
|
||||
}
|
||||
}
|
||||
if (e.key === "Backspace" && value === "") {
|
||||
const lastChip = chips[chips.length - 1];
|
||||
if (lastChip) {
|
||||
onRemoveChip(lastChip.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-row items-center flex-wrap gap-1 p-1.5 rounded-08 cursor-text w-full",
|
||||
wrapperClasses[variant],
|
||||
className
|
||||
)}
|
||||
onClick={() => inputRef.current?.focus()}
|
||||
>
|
||||
{Icon && <Icon size={16} className="text-text-04 shrink-0" />}
|
||||
{chips.map((chip) => (
|
||||
<Chip
|
||||
key={chip.id}
|
||||
onRemove={disabled ? undefined : () => onRemoveChip(chip.id)}
|
||||
smallLabel={false}
|
||||
>
|
||||
{chip.label}
|
||||
</Chip>
|
||||
))}
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
disabled={disabled}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={chips.length === 0 ? placeholder : undefined}
|
||||
className={cn(
|
||||
"flex-1 min-w-[80px] h-[1.5rem] bg-transparent p-0.5 focus:outline-none",
|
||||
innerClasses[variant],
|
||||
textClasses[variant]
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default InputChipField;
|
||||
@@ -12,8 +12,6 @@ import {
|
||||
SvgX,
|
||||
SvgXOctagon,
|
||||
} from "@opal/icons";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
|
||||
const containerClasses = {
|
||||
flash: {
|
||||
default: {
|
||||
@@ -220,7 +218,6 @@ export interface MessageProps extends React.HTMLAttributes<HTMLDivElement> {
|
||||
|
||||
// Features:
|
||||
icon?: boolean;
|
||||
iconComponent?: IconFunctionComponent;
|
||||
actions?: boolean | string;
|
||||
close?: boolean;
|
||||
|
||||
@@ -247,7 +244,6 @@ function MessageInner(
|
||||
description,
|
||||
|
||||
icon = true,
|
||||
iconComponent,
|
||||
actions,
|
||||
close = true,
|
||||
|
||||
@@ -283,9 +279,8 @@ function MessageInner(
|
||||
const textClass = useMemo(() => textClasses[type].text, [type]);
|
||||
const descriptionClass = useMemo(() => textClasses[type].description, [type]);
|
||||
|
||||
const IconComponent = iconComponent
|
||||
? iconComponent
|
||||
: level === "success"
|
||||
const IconComponent =
|
||||
level === "success"
|
||||
? SvgCheckCircle
|
||||
: level === "warning"
|
||||
? SvgAlertTriangle
|
||||
|
||||
@@ -179,7 +179,7 @@ export default function ActionLineItem({
|
||||
)}
|
||||
|
||||
{isSearchToolAndNotInProject && (
|
||||
<Button
|
||||
<IconButton
|
||||
icon={
|
||||
isSearchToolWithNoConnectors ? SvgSettings : SvgChevronRight
|
||||
}
|
||||
@@ -188,8 +188,11 @@ export default function ActionLineItem({
|
||||
router.push("/admin/add-connector");
|
||||
else onSourceManagementOpen?.();
|
||||
})}
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
internal
|
||||
className={cn(
|
||||
isSearchToolWithNoConnectors &&
|
||||
"invisible group-hover/LineItem:visible"
|
||||
)}
|
||||
tooltip={
|
||||
isSearchToolWithNoConnectors
|
||||
? "Add Connectors"
|
||||
|
||||
361
web/src/refresh-components/table/Pagination.tsx
Normal file
361
web/src/refresh-components/table/Pagination.tsx
Normal file
@@ -0,0 +1,361 @@
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgChevronLeft, SvgChevronRight } from "@opal/icons";
|
||||
|
||||
type PaginationSize = "lg" | "md" | "sm";
|
||||
|
||||
/**
|
||||
* Minimal page navigation showing `currentPage / totalPages` with prev/next arrows.
|
||||
* Use when you only need simple forward/backward navigation.
|
||||
*/
|
||||
interface SimplePaginationProps {
|
||||
type: "simple";
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** When `true`, displays the word "pages" after the page indicator. */
|
||||
showUnits?: boolean;
|
||||
/** When `false`, hides the page indicator between the prev/next arrows. Defaults to `true`. */
|
||||
showPageIndicator?: boolean;
|
||||
/** Controls button and text sizing. Defaults to `"lg"`. */
|
||||
size?: PaginationSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Item-count pagination showing `currentItems of totalItems` with optional page
|
||||
* controls and a "Go to" button. Use inside table footers that need to communicate
|
||||
* how many items the user is viewing.
|
||||
*/
|
||||
interface CountPaginationProps {
|
||||
type: "count";
|
||||
/** Number of items displayed per page. Used to compute the visible range. */
|
||||
pageSize: number;
|
||||
/** Total number of items across all pages. */
|
||||
totalItems: number;
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** When `false`, hides the page number between the prev/next arrows (arrows still visible). Defaults to `true`. */
|
||||
showPageIndicator?: boolean;
|
||||
/** When `true`, renders a "Go to" button. Requires `onGoTo`. */
|
||||
showGoTo?: boolean;
|
||||
/** Callback invoked when the "Go to" button is clicked. */
|
||||
onGoTo?: () => void;
|
||||
/** When `true`, displays the word "items" after the total count. */
|
||||
showUnits?: boolean;
|
||||
/** Controls button and text sizing. Defaults to `"lg"`. */
|
||||
size?: PaginationSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Numbered page-list pagination with clickable page buttons and ellipsis
|
||||
* truncation for large page counts. Does not support `"sm"` size.
|
||||
*/
|
||||
interface ListPaginationProps {
|
||||
type: "list";
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** When `false`, hides the page buttons between the prev/next arrows. Defaults to `true`. */
|
||||
showPageIndicator?: boolean;
|
||||
/** Controls button and text sizing. Defaults to `"lg"`. Only `"lg"` and `"md"` are supported. */
|
||||
size?: Exclude<PaginationSize, "sm">;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discriminated union of all pagination variants.
|
||||
* Use the `type` prop to select between `"simple"`, `"count"`, and `"list"`.
|
||||
*/
|
||||
export type PaginationProps =
|
||||
| SimplePaginationProps
|
||||
| CountPaginationProps
|
||||
| ListPaginationProps;
|
||||
|
||||
function getPageNumbers(currentPage: number, totalPages: number) {
|
||||
const pages: (number | string)[] = [];
|
||||
const maxPagesToShow = 7;
|
||||
|
||||
if (totalPages <= maxPagesToShow) {
|
||||
for (let i = 1; i <= totalPages; i++) {
|
||||
pages.push(i);
|
||||
}
|
||||
} else {
|
||||
pages.push(1);
|
||||
|
||||
let startPage = Math.max(2, currentPage - 1);
|
||||
let endPage = Math.min(totalPages - 1, currentPage + 1);
|
||||
|
||||
if (currentPage <= 3) {
|
||||
endPage = 5;
|
||||
} else if (currentPage >= totalPages - 2) {
|
||||
startPage = totalPages - 4;
|
||||
}
|
||||
|
||||
if (startPage > 2) {
|
||||
pages.push("start-ellipsis");
|
||||
}
|
||||
|
||||
for (let i = startPage; i <= endPage; i++) {
|
||||
pages.push(i);
|
||||
}
|
||||
|
||||
if (endPage < totalPages - 1) {
|
||||
pages.push("end-ellipsis");
|
||||
}
|
||||
|
||||
pages.push(totalPages);
|
||||
}
|
||||
|
||||
return pages;
|
||||
}
|
||||
|
||||
function sizedTextProps(isSmall: boolean, variant: "mono" | "muted") {
|
||||
if (variant === "mono") {
|
||||
return isSmall ? { secondaryMono: true } : { mainUiMono: true };
|
||||
}
|
||||
return isSmall ? { secondaryBody: true } : { mainUiMuted: true };
|
||||
}
|
||||
|
||||
interface NavButtonsProps {
|
||||
currentPage: number;
|
||||
totalPages: number;
|
||||
onPageChange: (page: number) => void;
|
||||
size: PaginationSize;
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
|
||||
function NavButtons({
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
size,
|
||||
children,
|
||||
}: NavButtonsProps) {
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
icon={SvgChevronLeft}
|
||||
onClick={() => onPageChange(currentPage - 1)}
|
||||
disabled={currentPage <= 1}
|
||||
size={size}
|
||||
prominence="tertiary"
|
||||
/>
|
||||
{children}
|
||||
<Button
|
||||
icon={SvgChevronRight}
|
||||
onClick={() => onPageChange(currentPage + 1)}
|
||||
disabled={currentPage >= totalPages}
|
||||
size={size}
|
||||
prominence="tertiary"
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Table pagination component with three variants: `simple`, `count`, and `list`.
|
||||
* Pass the `type` prop to select the variant, and the component will render the
|
||||
* appropriate UI.
|
||||
*/
|
||||
export default function Pagination(props: PaginationProps) {
|
||||
switch (props.type) {
|
||||
case "simple":
|
||||
return <SimplePaginationInner {...props} />;
|
||||
case "count":
|
||||
return <CountPaginationInner {...props} />;
|
||||
case "list":
|
||||
return <ListPaginationInner {...props} />;
|
||||
}
|
||||
}
|
||||
|
||||
function SimplePaginationInner({
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
showUnits,
|
||||
showPageIndicator = true,
|
||||
size = "lg",
|
||||
className,
|
||||
}: SimplePaginationProps) {
|
||||
const isSmall = size === "sm";
|
||||
|
||||
return (
|
||||
<div className={cn("flex items-center gap-1", className)}>
|
||||
<NavButtons
|
||||
currentPage={currentPage}
|
||||
totalPages={totalPages}
|
||||
onPageChange={onPageChange}
|
||||
size={size}
|
||||
>
|
||||
{showPageIndicator && (
|
||||
<>
|
||||
<Text {...sizedTextProps(isSmall, "mono")} text03>
|
||||
{currentPage}
|
||||
<Text as="span" {...sizedTextProps(isSmall, "muted")} text03>
|
||||
/
|
||||
</Text>
|
||||
{totalPages}
|
||||
</Text>
|
||||
{showUnits && (
|
||||
<Text {...sizedTextProps(isSmall, "muted")} text03>
|
||||
pages
|
||||
</Text>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</NavButtons>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function CountPaginationInner({
|
||||
pageSize,
|
||||
totalItems,
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
showPageIndicator = true,
|
||||
showGoTo,
|
||||
onGoTo,
|
||||
showUnits,
|
||||
size = "lg",
|
||||
className,
|
||||
}: CountPaginationProps) {
|
||||
const isSmall = size === "sm";
|
||||
const rangeStart = (currentPage - 1) * pageSize + 1;
|
||||
const rangeEnd = Math.min(currentPage * pageSize, totalItems);
|
||||
const currentItems = `${rangeStart}~${rangeEnd}`;
|
||||
|
||||
return (
|
||||
<div className={cn("flex items-center gap-1", className)}>
|
||||
<Text {...sizedTextProps(isSmall, "mono")} text03>
|
||||
{currentItems}
|
||||
</Text>
|
||||
<Text {...sizedTextProps(isSmall, "muted")} text03>
|
||||
of
|
||||
</Text>
|
||||
<Text {...sizedTextProps(isSmall, "mono")} text03>
|
||||
{totalItems}
|
||||
</Text>
|
||||
{showUnits && (
|
||||
<Text {...sizedTextProps(isSmall, "muted")} text03>
|
||||
items
|
||||
</Text>
|
||||
)}
|
||||
|
||||
<NavButtons
|
||||
currentPage={currentPage}
|
||||
totalPages={totalPages}
|
||||
onPageChange={onPageChange}
|
||||
size={size}
|
||||
>
|
||||
{showPageIndicator && (
|
||||
<Text {...sizedTextProps(isSmall, "mono")} text03>
|
||||
{currentPage}
|
||||
</Text>
|
||||
)}
|
||||
</NavButtons>
|
||||
|
||||
{showGoTo && onGoTo && (
|
||||
<Button onClick={onGoTo} size={size} prominence="tertiary">
|
||||
Go to
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ListPaginationInner({
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
showPageIndicator = true,
|
||||
size = "lg",
|
||||
className,
|
||||
}: ListPaginationProps) {
|
||||
const pageNumbers = getPageNumbers(currentPage, totalPages);
|
||||
|
||||
return (
|
||||
<div className={cn("flex items-center gap-1", className)}>
|
||||
<NavButtons
|
||||
currentPage={currentPage}
|
||||
totalPages={totalPages}
|
||||
onPageChange={onPageChange}
|
||||
size={size}
|
||||
>
|
||||
{showPageIndicator && (
|
||||
<div className="flex items-center">
|
||||
{pageNumbers.map((page) => {
|
||||
if (typeof page === "string") {
|
||||
return (
|
||||
<Text
|
||||
key={page}
|
||||
mainUiMuted={size === "lg"}
|
||||
secondaryBody={size === "md"}
|
||||
text03
|
||||
>
|
||||
...
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
const pageNum = page as number;
|
||||
const isActive = pageNum === currentPage;
|
||||
|
||||
return (
|
||||
<Button
|
||||
key={pageNum}
|
||||
onClick={() => onPageChange(pageNum)}
|
||||
size={size}
|
||||
prominence="tertiary"
|
||||
transient={isActive}
|
||||
icon={({ className: iconClassName }) => (
|
||||
<div
|
||||
className={cn(
|
||||
iconClassName,
|
||||
"flex flex-col justify-center"
|
||||
)}
|
||||
>
|
||||
{size === "lg" ? (
|
||||
<Text
|
||||
mainUiBody={isActive}
|
||||
mainUiMuted={!isActive}
|
||||
text04={isActive}
|
||||
text02={!isActive}
|
||||
>
|
||||
{pageNum}
|
||||
</Text>
|
||||
) : (
|
||||
<Text
|
||||
secondaryAction={isActive}
|
||||
secondaryBody={!isActive}
|
||||
text04={isActive}
|
||||
text02={!isActive}
|
||||
>
|
||||
{pageNum}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</NavButtons>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
116
web/src/refresh-components/table/TableHead.tsx
Normal file
116
web/src/refresh-components/table/TableHead.tsx
Normal file
@@ -0,0 +1,116 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgHandle } from "@opal/icons";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
|
||||
type SortDirection = "none" | "ascending" | "descending";
|
||||
|
||||
/**
|
||||
* A table header cell with optional sort controls and a resize handle indicator.
|
||||
* Renders as a `<th>` element with Figma-matched typography and spacing.
|
||||
*/
|
||||
interface TableHeadProps {
|
||||
/** Header label content. */
|
||||
children: React.ReactNode;
|
||||
/** Current sort state. When omitted, no sort button is shown. */
|
||||
sorted?: SortDirection;
|
||||
/** Called when the sort button is clicked. Required to show the sort button. */
|
||||
onSort?: () => void;
|
||||
/** When `true`, renders a thin resize handle on the right edge. */
|
||||
resizable?: boolean;
|
||||
/** Override the sort icon for this column. Receives the current sort state and
|
||||
* returns the icon component to render. Falls back to the built-in icons. */
|
||||
icon?: (sorted: SortDirection) => IconFunctionComponent;
|
||||
/** Text alignment for the column. Defaults to `"left"`. */
|
||||
alignment?: "left" | "center" | "right";
|
||||
/** Cell density. `"small"` uses tighter padding for denser layouts. */
|
||||
size?: "regular" | "small";
|
||||
/** Additional classes on the outer `<th>` element. */
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Table header cell primitive. Displays a column label with optional sort
|
||||
* functionality and a resize handle indicator.
|
||||
*/
|
||||
const alignmentThClass = {
|
||||
left: "text-left",
|
||||
center: "text-center",
|
||||
right: "text-right",
|
||||
} as const;
|
||||
|
||||
const alignmentFlexClass = {
|
||||
left: "justify-start",
|
||||
center: "justify-center",
|
||||
right: "justify-end",
|
||||
} as const;
|
||||
|
||||
export default function TableHead({
|
||||
children,
|
||||
sorted,
|
||||
onSort,
|
||||
icon: iconFn,
|
||||
resizable,
|
||||
alignment = "left",
|
||||
size = "regular",
|
||||
className,
|
||||
}: TableHeadProps) {
|
||||
const isSmall = size === "small";
|
||||
const resolvedIcon = iconFn;
|
||||
return (
|
||||
<th
|
||||
className={cn(
|
||||
"group relative",
|
||||
alignmentThClass[alignment],
|
||||
isSmall ? "p-1.5" : "px-2 py-1",
|
||||
"border-b border-transparent group-hover:border-border-03",
|
||||
className
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn("flex items-center gap-1", alignmentFlexClass[alignment])}
|
||||
>
|
||||
<div className={isSmall ? "py-1" : "py-2"}>
|
||||
<Text
|
||||
mainUiAction={!isSmall}
|
||||
secondaryAction={isSmall}
|
||||
text04
|
||||
className="truncate"
|
||||
>
|
||||
{children}
|
||||
</Text>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
!isSmall && "py-1.5",
|
||||
"opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
)}
|
||||
>
|
||||
{onSort && resolvedIcon && (
|
||||
<Button
|
||||
icon={resolvedIcon(sorted ?? "none")}
|
||||
onClick={onSort}
|
||||
tooltip="Sort"
|
||||
tooltipSide="top"
|
||||
prominence="internal"
|
||||
size="sm"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{resizable && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute right-0 top-0 flex h-full items-center",
|
||||
"text-border-02",
|
||||
"opacity-0 group-hover:opacity-100",
|
||||
"cursor-col-resize"
|
||||
)}
|
||||
>
|
||||
<SvgHandle size={22} className="stroke-border-02" />
|
||||
</div>
|
||||
)}
|
||||
</th>
|
||||
);
|
||||
}
|
||||
295
web/src/refresh-components/table/footer/Footer.tsx
Normal file
295
web/src/refresh-components/table/footer/Footer.tsx
Normal file
@@ -0,0 +1,295 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Pagination from "@/refresh-components/table/Pagination";
|
||||
import { SvgEye, SvgXCircle } from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
type FooterSize = "regular" | "small";
|
||||
type SelectionState = "none" | "partial" | "all";
|
||||
|
||||
/**
|
||||
* Footer mode for tables with selectable rows.
|
||||
* Displays a selection message on the left (with optional view/clear actions)
|
||||
* and a `count`-type pagination on the right.
|
||||
*/
|
||||
interface FooterSelectionModeProps {
|
||||
mode: "selection";
|
||||
/** Whether the table supports selecting multiple rows. */
|
||||
multiSelect: boolean;
|
||||
/** Current selection state: `"none"`, `"partial"`, or `"all"`. */
|
||||
selectionState: SelectionState;
|
||||
/** Number of currently selected items. */
|
||||
selectedCount: number;
|
||||
/** When `true`, renders a qualifier checkbox on the far left. */
|
||||
showQualifier?: boolean;
|
||||
/** Controlled checked state for the qualifier checkbox. */
|
||||
qualifierChecked?: boolean;
|
||||
/** Called when the qualifier checkbox value changes. */
|
||||
onQualifierChange?: (checked: boolean) => void;
|
||||
/** If provided, renders a "View" icon button when items are selected. */
|
||||
onView?: () => void;
|
||||
/** If provided, renders a "Clear" icon button when items are selected. */
|
||||
onClear?: () => void;
|
||||
/** Number of items displayed per page. */
|
||||
pageSize: number;
|
||||
/** First item number in the current page (e.g. `1`). */
|
||||
rangeStart: number;
|
||||
/** Last item number in the current page (e.g. `25`). */
|
||||
rangeEnd: number;
|
||||
/** Total number of items across all pages. */
|
||||
totalItems: number;
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
|
||||
size?: FooterSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Footer mode for read-only tables (no row selection).
|
||||
* Displays "Showing X~Y of Z" on the left and a `list`-type pagination
|
||||
* on the right.
|
||||
*/
|
||||
interface FooterSummaryModeProps {
|
||||
mode: "summary";
|
||||
/** Number of items displayed per page. */
|
||||
pageSize: number;
|
||||
/** First item number in the current page (e.g. `1`). */
|
||||
rangeStart: number;
|
||||
/** Last item number in the current page (e.g. `25`). */
|
||||
rangeEnd: number;
|
||||
/** Total number of items across all pages. */
|
||||
totalItems: number;
|
||||
/** When `true`, renders a qualifier checkbox on the far left. */
|
||||
showQualifier?: boolean;
|
||||
/** Controlled checked state for the qualifier checkbox. */
|
||||
qualifierChecked?: boolean;
|
||||
/** Called when the qualifier checkbox value changes. */
|
||||
onQualifierChange?: (checked: boolean) => void;
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
|
||||
size?: FooterSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discriminated union of footer modes.
|
||||
* Use `mode: "selection"` for tables with selectable rows, or
|
||||
* `mode: "summary"` for read-only tables.
|
||||
*/
|
||||
export type FooterProps = FooterSelectionModeProps | FooterSummaryModeProps;
|
||||
|
||||
function getSelectionMessage(
|
||||
state: SelectionState,
|
||||
multi: boolean,
|
||||
count: number
|
||||
): string {
|
||||
if (state === "none") {
|
||||
return multi ? "Select items to continue" : "Select an item to continue";
|
||||
}
|
||||
if (!multi) return "Item selected";
|
||||
return `${count} items selected`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Table footer combining status information on the left with pagination on the
|
||||
* right. Use `mode: "selection"` for tables with selectable rows, or
|
||||
* `mode: "summary"` for read-only tables.
|
||||
*/
|
||||
export default function Footer(props: FooterProps) {
|
||||
const { size = "regular", className } = props;
|
||||
const isSmall = size === "small";
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-full items-center justify-between border-t border-border-01",
|
||||
isSmall ? "min-h-[2.25rem]" : "min-h-[2.75rem]",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{/* Left side */}
|
||||
<div className="flex items-center gap-1 px-1">
|
||||
{props.showQualifier && (
|
||||
<div className="flex items-center px-1">
|
||||
<Checkbox
|
||||
checked={props.qualifierChecked}
|
||||
indeterminate={
|
||||
props.mode === "selection" && props.selectionState === "partial"
|
||||
}
|
||||
onCheckedChange={props.onQualifierChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{props.mode === "selection" ? (
|
||||
<SelectionLeft
|
||||
selectionState={props.selectionState}
|
||||
multiSelect={props.multiSelect}
|
||||
selectedCount={props.selectedCount}
|
||||
onView={props.onView}
|
||||
onClear={props.onClear}
|
||||
isSmall={isSmall}
|
||||
/>
|
||||
) : (
|
||||
<SummaryLeft
|
||||
rangeStart={props.rangeStart}
|
||||
rangeEnd={props.rangeEnd}
|
||||
totalItems={props.totalItems}
|
||||
isSmall={isSmall}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right side */}
|
||||
<div className="flex items-center gap-2 px-1 py-2">
|
||||
{props.mode === "selection" ? (
|
||||
<Pagination
|
||||
type="count"
|
||||
pageSize={props.pageSize}
|
||||
totalItems={props.totalItems}
|
||||
currentPage={props.currentPage}
|
||||
totalPages={props.totalPages}
|
||||
onPageChange={props.onPageChange}
|
||||
showUnits
|
||||
size={isSmall ? "sm" : "md"}
|
||||
/>
|
||||
) : (
|
||||
<Pagination
|
||||
type="list"
|
||||
currentPage={props.currentPage}
|
||||
totalPages={props.totalPages}
|
||||
onPageChange={props.onPageChange}
|
||||
size={isSmall ? "md" : "lg"}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface SelectionLeftProps {
|
||||
selectionState: SelectionState;
|
||||
multiSelect: boolean;
|
||||
selectedCount: number;
|
||||
onView?: () => void;
|
||||
onClear?: () => void;
|
||||
isSmall: boolean;
|
||||
}
|
||||
|
||||
function SelectionLeft({
|
||||
selectionState,
|
||||
multiSelect,
|
||||
selectedCount,
|
||||
onView,
|
||||
onClear,
|
||||
isSmall,
|
||||
}: SelectionLeftProps) {
|
||||
const message = getSelectionMessage(
|
||||
selectionState,
|
||||
multiSelect,
|
||||
selectedCount
|
||||
);
|
||||
const hasSelection = selectionState !== "none";
|
||||
|
||||
return (
|
||||
<div className="flex flex-row gap-1 items-center justify-center w-fit flex-shrink-0 h-fit px-1">
|
||||
{isSmall ? (
|
||||
<Text
|
||||
secondaryAction={hasSelection}
|
||||
secondaryBody={!hasSelection}
|
||||
text03
|
||||
>
|
||||
{message}
|
||||
</Text>
|
||||
) : (
|
||||
<Text mainUiBody={hasSelection} mainUiMuted={!hasSelection} text03>
|
||||
{message}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{hasSelection && (
|
||||
<div className="flex flex-row items-center w-fit flex-shrink-0 h-fit">
|
||||
{onView && (
|
||||
<Button
|
||||
icon={SvgEye}
|
||||
onClick={onView}
|
||||
tooltip="View"
|
||||
size="md"
|
||||
prominence="tertiary"
|
||||
/>
|
||||
)}
|
||||
{onClear && (
|
||||
<Button
|
||||
icon={SvgXCircle}
|
||||
onClick={onClear}
|
||||
tooltip="Clear selection"
|
||||
size="md"
|
||||
prominence="tertiary"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface SummaryLeftProps {
|
||||
rangeStart: number;
|
||||
rangeEnd: number;
|
||||
totalItems: number;
|
||||
isSmall: boolean;
|
||||
}
|
||||
|
||||
function SummaryLeft({
|
||||
rangeStart,
|
||||
rangeEnd,
|
||||
totalItems,
|
||||
isSmall,
|
||||
}: SummaryLeftProps) {
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
gap={0.25}
|
||||
alignItems="center"
|
||||
width="fit"
|
||||
height="fit"
|
||||
>
|
||||
{isSmall ? (
|
||||
<Text secondaryBody text03>
|
||||
Showing{" "}
|
||||
<Text as="span" secondaryMono text03>
|
||||
{rangeStart}~{rangeEnd}
|
||||
</Text>{" "}
|
||||
of{" "}
|
||||
<Text as="span" secondaryMono text03>
|
||||
{totalItems}
|
||||
</Text>
|
||||
</Text>
|
||||
) : (
|
||||
<Text mainUiMuted text03>
|
||||
Showing{" "}
|
||||
<Text as="span" mainUiMono text03>
|
||||
{rangeStart}~{rangeEnd}
|
||||
</Text>{" "}
|
||||
of{" "}
|
||||
<Text as="span" mainUiMono text03>
|
||||
{totalItems}
|
||||
</Text>
|
||||
</Text>
|
||||
)}
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
|
||||
import InputTypeInElementField from "@/refresh-components/form/InputTypeInElementField";
|
||||
import InputDatePickerField from "@/refresh-components/form/InputDatePickerField";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { useFormikContext } from "formik";
|
||||
@@ -57,7 +56,6 @@ import {
|
||||
SvgLock,
|
||||
SvgOnyxOctagon,
|
||||
SvgSliders,
|
||||
SvgUsers,
|
||||
SvgTrash,
|
||||
} from "@opal/icons";
|
||||
import CustomAgentAvatar, {
|
||||
@@ -88,7 +86,6 @@ import ShareAgentModal from "@/sections/modals/ShareAgentModal";
|
||||
import AgentKnowledgePane from "@/sections/knowledge/AgentKnowledgePane";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
|
||||
interface AgentIconEditorProps {
|
||||
existingAgent?: FullPersona | null;
|
||||
@@ -453,8 +450,6 @@ export default function AgentEditorPage({
|
||||
const shareAgentModal = useCreateModal();
|
||||
const deleteAgentModal = useCreateModal();
|
||||
const settings = useSettingsContext();
|
||||
const { isAdmin, isCurator } = useUser();
|
||||
const canUpdateFeaturedStatus = isAdmin || isCurator;
|
||||
const vectorDbEnabled = settings?.settings.vector_db_enabled !== false;
|
||||
|
||||
// LLM Model Selection
|
||||
@@ -656,8 +651,6 @@ export default function AgentEditorPage({
|
||||
shared_user_ids: existingAgent?.users?.map((user) => user.id) ?? [],
|
||||
shared_group_ids: existingAgent?.groups ?? [],
|
||||
is_public: existingAgent?.is_public ?? true,
|
||||
label_ids: existingAgent?.labels?.map((l) => l.id) ?? [],
|
||||
is_default_persona: existingAgent?.is_default_persona ?? false,
|
||||
};
|
||||
|
||||
const validationSchema = Yup.object().shape({
|
||||
@@ -819,8 +812,8 @@ export default function AgentEditorPage({
|
||||
uploaded_image_id: values.uploaded_image_id,
|
||||
icon_name: values.icon_name,
|
||||
search_start_date: values.knowledge_cutoff_date || null,
|
||||
label_ids: values.label_ids,
|
||||
is_default_persona: values.is_default_persona,
|
||||
label_ids: null,
|
||||
is_default_persona: false,
|
||||
// display_priority: ...,
|
||||
|
||||
user_file_ids: values.enable_knowledge ? values.user_file_ids : [],
|
||||
@@ -1002,10 +995,6 @@ export default function AgentEditorPage({
|
||||
(fileId: string) =>
|
||||
fileStatusMap.get(fileId) === UserFileStatus.PROCESSING
|
||||
);
|
||||
const isShared =
|
||||
values.is_public ||
|
||||
values.shared_user_ids.length > 0 ||
|
||||
values.shared_group_ids.length > 0;
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -1065,20 +1054,10 @@ export default function AgentEditorPage({
|
||||
userIds={values.shared_user_ids}
|
||||
groupIds={values.shared_group_ids}
|
||||
isPublic={values.is_public}
|
||||
isFeatured={values.is_default_persona}
|
||||
labelIds={values.label_ids}
|
||||
onShare={(
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
isFeatured,
|
||||
labelIds
|
||||
) => {
|
||||
onShare={(userIds, groupIds, isPublic) => {
|
||||
setFieldValue("shared_user_ids", userIds);
|
||||
setFieldValue("shared_group_ids", groupIds);
|
||||
setFieldValue("is_public", isPublic);
|
||||
setFieldValue("is_default_persona", isFeatured);
|
||||
setFieldValue("label_ids", labelIds);
|
||||
shareAgentModal.toggle(false);
|
||||
}}
|
||||
/>
|
||||
@@ -1381,36 +1360,17 @@ export default function AgentEditorPage({
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Share This Agent"
|
||||
description="with other users, groups, or everyone in your organization."
|
||||
description="Share this agent with other users, groups, or everyone in your organization."
|
||||
center
|
||||
>
|
||||
<Button
|
||||
secondary
|
||||
leftIcon={isShared ? SvgUsers : SvgLock}
|
||||
leftIcon={SvgLock}
|
||||
onClick={() => shareAgentModal.toggle(true)}
|
||||
>
|
||||
Share
|
||||
</Button>
|
||||
</InputLayouts.Horizontal>
|
||||
{canUpdateFeaturedStatus && (
|
||||
<>
|
||||
<InputLayouts.Horizontal
|
||||
name="is_default_persona"
|
||||
title="Feature This Agent"
|
||||
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
|
||||
>
|
||||
<SwitchField name="is_default_persona" />
|
||||
</InputLayouts.Horizontal>
|
||||
{values.is_default_persona && !isShared && (
|
||||
<Message
|
||||
static
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="This agent is private to you and will only be featured for yourself."
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
|
||||
@@ -425,7 +425,7 @@ export default function AgentsNavigationPage() {
|
||||
>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgOnyxOctagon}
|
||||
title="Agents"
|
||||
title="Agents & Assistants"
|
||||
description="Customize AI behavior and knowledge for you and your team's use cases."
|
||||
rightChildren={
|
||||
<Button
|
||||
|
||||
@@ -1,241 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState } from "react";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Card, type CardProps } from "@/refresh-components/cards";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgCheckCircle,
|
||||
SvgRefreshCw,
|
||||
SvgTerminal,
|
||||
SvgUnplug,
|
||||
SvgXOctagon,
|
||||
} from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import useCodeInterpreter from "@/hooks/useCodeInterpreter";
|
||||
import { updateCodeInterpreter } from "@/lib/admin/code-interpreter/svc";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
interface CodeInterpreterCardProps {
|
||||
variant?: CardProps["variant"];
|
||||
title: string;
|
||||
middleText?: string;
|
||||
strikethrough?: boolean;
|
||||
rightContent: React.ReactNode;
|
||||
}
|
||||
|
||||
function CodeInterpreterCard({
|
||||
variant,
|
||||
title,
|
||||
middleText,
|
||||
strikethrough,
|
||||
rightContent,
|
||||
}: CodeInterpreterCardProps) {
|
||||
return (
|
||||
// TODO (@raunakab): Allow Content to accept strikethrough and middleText
|
||||
<Card variant={variant} padding={0.5}>
|
||||
<ContentAction
|
||||
icon={SvgTerminal}
|
||||
title={middleText ? `${title} ${middleText}` : title}
|
||||
description="Built-in Python runtime"
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
rightChildren={rightContent}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
function CheckingStatus() {
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
alignItems="center"
|
||||
gap={0.25}
|
||||
padding={0.5}
|
||||
>
|
||||
<Text mainUiAction text03>
|
||||
Checking...
|
||||
</Text>
|
||||
<SimpleLoader />
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
interface ConnectionStatusProps {
|
||||
healthy: boolean;
|
||||
isLoading: boolean;
|
||||
}
|
||||
|
||||
function ConnectionStatus({ healthy, isLoading }: ConnectionStatusProps) {
|
||||
if (isLoading) {
|
||||
return <CheckingStatus />;
|
||||
}
|
||||
|
||||
const label = healthy ? "Connected" : "Connection Lost";
|
||||
const Icon = healthy ? SvgCheckCircle : SvgXOctagon;
|
||||
const iconColor = healthy ? "text-status-success-05" : "text-status-error-05";
|
||||
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
alignItems="center"
|
||||
gap={0.25}
|
||||
padding={0.5}
|
||||
>
|
||||
<Text mainUiAction text03>
|
||||
{label}
|
||||
</Text>
|
||||
<Icon size={16} className={iconColor} />
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
interface ActionButtonsProps {
|
||||
onDisconnect: () => void;
|
||||
onRefresh: () => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
function ActionButtons({
|
||||
onDisconnect,
|
||||
onRefresh,
|
||||
disabled,
|
||||
}: ActionButtonsProps) {
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
alignItems="center"
|
||||
gap={0.25}
|
||||
padding={0.25}
|
||||
>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgUnplug}
|
||||
onClick={onDisconnect}
|
||||
tooltip="Disconnect"
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgRefreshCw}
|
||||
onClick={onRefresh}
|
||||
tooltip="Refresh"
|
||||
disabled={disabled}
|
||||
/>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
export default function CodeInterpreterPage() {
|
||||
const { isHealthy, isEnabled, isLoading, refetch } = useCodeInterpreter();
|
||||
const [showDisconnectModal, setShowDisconnectModal] = useState(false);
|
||||
const [isReconnecting, setIsReconnecting] = useState(false);
|
||||
|
||||
async function handleToggle(enabled: boolean) {
|
||||
const action = enabled ? "reconnect" : "disconnect";
|
||||
setIsReconnecting(enabled);
|
||||
try {
|
||||
const response = await updateCodeInterpreter({ enabled });
|
||||
if (!response.ok) {
|
||||
toast.error(`Failed to ${action} Code Interpreter`);
|
||||
return;
|
||||
}
|
||||
setShowDisconnectModal(false);
|
||||
refetch();
|
||||
} finally {
|
||||
setIsReconnecting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgTerminal}
|
||||
title="Code Interpreter"
|
||||
description="Safe and sandboxed Python runtime available to your LLM. See docs for more details."
|
||||
separator
|
||||
/>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
{isEnabled || isLoading ? (
|
||||
<CodeInterpreterCard
|
||||
title="Code Interpreter"
|
||||
variant={isHealthy ? "primary" : "secondary"}
|
||||
strikethrough={!isHealthy}
|
||||
rightContent={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
justifyContent="center"
|
||||
alignItems="end"
|
||||
gap={0}
|
||||
padding={0}
|
||||
>
|
||||
<ConnectionStatus healthy={isHealthy} isLoading={isLoading} />
|
||||
<ActionButtons
|
||||
onDisconnect={() => setShowDisconnectModal(true)}
|
||||
onRefresh={refetch}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<CodeInterpreterCard
|
||||
variant="secondary"
|
||||
title="Code Interpreter"
|
||||
middleText="(Disconnected)"
|
||||
strikethrough={true}
|
||||
rightContent={
|
||||
<Section flexDirection="row" alignItems="center" padding={0.5}>
|
||||
{isReconnecting ? (
|
||||
<CheckingStatus />
|
||||
) : (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowExchange}
|
||||
onClick={() => handleToggle(true)}
|
||||
>
|
||||
Reconnect
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</SettingsLayouts.Body>
|
||||
|
||||
{showDisconnectModal && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title="Disconnect Code Interpreter"
|
||||
onClose={() => setShowDisconnectModal(false)}
|
||||
submit={
|
||||
<Button variant="danger" onClick={() => handleToggle(false)}>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Text as="p" text03>
|
||||
All running sessions connected to{" "}
|
||||
<Text as="span" mainContentEmphasis text03>
|
||||
Code Interpreter
|
||||
</Text>{" "}
|
||||
will stop working. Note that this will not remove any data from your
|
||||
runtime. You can reconnect to this runtime later if needed.
|
||||
</Text>
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
@@ -11,11 +11,7 @@ import { cn, noProp } from "@/lib/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import {
|
||||
checkUserOwnsAssistant,
|
||||
updateAgentSharedStatus,
|
||||
updateAgentFeaturedStatus,
|
||||
} from "@/lib/agents";
|
||||
import { checkUserOwnsAssistant, updateAgentSharedStatus } from "@/lib/agents";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import {
|
||||
SvgActions,
|
||||
@@ -47,9 +43,8 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
() => pinnedAgents.some((pinnedAgent) => pinnedAgent.id === agent.id),
|
||||
[agent.id, pinnedAgents]
|
||||
);
|
||||
const { user, isAdmin, isCurator } = useUser();
|
||||
const { user } = useUser();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const canUpdateFeaturedStatus = isAdmin || isCurator;
|
||||
const isOwnedByUser = checkUserOwnsAssistant(user, agent);
|
||||
const shareAgentModal = useCreateModal();
|
||||
const agentViewerModal = useCreateModal();
|
||||
@@ -63,49 +58,26 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
route({ agentId: agent.id });
|
||||
}, [pinned, togglePinnedAgent, agent, route]);
|
||||
|
||||
// Handle sharing agent
|
||||
const handleShare = useCallback(
|
||||
async (
|
||||
userIds: string[],
|
||||
groupIds: number[],
|
||||
isPublic: boolean,
|
||||
isFeatured: boolean,
|
||||
labelIds: number[]
|
||||
) => {
|
||||
const shareError = await updateAgentSharedStatus(
|
||||
async (userIds: string[], groupIds: number[], isPublic: boolean) => {
|
||||
const error = await updateAgentSharedStatus(
|
||||
agent.id,
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
isPaidEnterpriseFeaturesEnabled,
|
||||
labelIds
|
||||
isPaidEnterpriseFeaturesEnabled
|
||||
);
|
||||
|
||||
if (shareError) {
|
||||
toast.error(`Failed to share agent: ${shareError}`);
|
||||
return;
|
||||
if (error) {
|
||||
toast.error(`Failed to share agent: ${error}`);
|
||||
} else {
|
||||
// Revalidate the agent data to reflect the changes
|
||||
refreshAgent();
|
||||
shareAgentModal.toggle(false);
|
||||
}
|
||||
|
||||
if (canUpdateFeaturedStatus) {
|
||||
const featuredError = await updateAgentFeaturedStatus(
|
||||
agent.id,
|
||||
isFeatured
|
||||
);
|
||||
if (featuredError) {
|
||||
toast.error(`Failed to update featured status: ${featuredError}`);
|
||||
refreshAgent();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
refreshAgent();
|
||||
shareAgentModal.toggle(false);
|
||||
},
|
||||
[
|
||||
agent.id,
|
||||
canUpdateFeaturedStatus,
|
||||
isPaidEnterpriseFeaturesEnabled,
|
||||
refreshAgent,
|
||||
]
|
||||
[agent.id, isPaidEnterpriseFeaturesEnabled, refreshAgent]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -116,8 +88,6 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
userIds={fullAgent?.users?.map((u) => u.id) ?? []}
|
||||
groupIds={fullAgent?.groups ?? []}
|
||||
isPublic={fullAgent?.is_public ?? false}
|
||||
isFeatured={fullAgent?.is_default_persona ?? false}
|
||||
labelIds={fullAgent?.labels?.map((l) => l.id) ?? []}
|
||||
onShare={handleShare}
|
||||
/>
|
||||
</shareAgentModal.Provider>
|
||||
|
||||
@@ -119,7 +119,7 @@ export default function NewTenantModal({
|
||||
: `Your request to join ${tenantInfo.number_of_users} other users of ${APP_DOMAIN} has been approved.`;
|
||||
|
||||
const description = isInvite
|
||||
? `By accepting this invitation, you will join the existing ${APP_DOMAIN} team and lose access to your current team. Note: you will lose access to your current agents, prompts, chats, and connected sources.`
|
||||
? `By accepting this invitation, you will join the existing ${APP_DOMAIN} team and lose access to your current team. Note: you will lose access to your current assistants, prompts, chats, and connected sources.`
|
||||
: `To finish joining your team, please reauthenticate with ${user?.email}.`;
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useMemo, useState } from "react";
|
||||
import { useMemo } from "react";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import {
|
||||
SvgLink,
|
||||
SvgOrganization,
|
||||
SvgShare,
|
||||
SvgTag,
|
||||
SvgUsers,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import InputChipField from "@/refresh-components/inputs/InputChipField";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { Card } from "@/refresh-components/cards";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
@@ -29,8 +26,6 @@ import { useUser } from "@/providers/UserProvider";
|
||||
import { Formik, useFormikContext } from "formik";
|
||||
import { useAgent } from "@/hooks/useAgents";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { useLabels } from "@/lib/hooks";
|
||||
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
|
||||
|
||||
const YOUR_ORGANIZATION_TAB = "Your Organization";
|
||||
const USERS_AND_GROUPS_TAB = "Users & Groups";
|
||||
@@ -43,8 +38,6 @@ interface ShareAgentFormValues {
|
||||
selectedUserIds: string[];
|
||||
selectedGroupIds: number[];
|
||||
isPublic: boolean;
|
||||
isFeatured: boolean;
|
||||
labelIds: number[];
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -60,15 +53,12 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
useFormikContext<ShareAgentFormValues>();
|
||||
const { data: usersData } = useShareableUsers({ includeApiKeys: true });
|
||||
const { data: groupsData } = useShareableGroups();
|
||||
const { user: currentUser, isAdmin, isCurator } = useUser();
|
||||
const { user: currentUser } = useUser();
|
||||
const { agent: fullAgent } = useAgent(agentId ?? null);
|
||||
const shareAgentModal = useModal();
|
||||
const { labels: allLabels, createLabel } = useLabels();
|
||||
const [labelInputValue, setLabelInputValue] = useState("");
|
||||
|
||||
const acceptedUsers = usersData ?? [];
|
||||
const groups = groupsData ?? [];
|
||||
const canUpdateFeaturedStatus = isAdmin || isCurator;
|
||||
|
||||
// Create options for InputComboBox from all accepted users and groups
|
||||
const comboBoxOptions = useMemo(() => {
|
||||
@@ -147,50 +137,6 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
);
|
||||
}
|
||||
|
||||
const selectedLabels: PersonaLabel[] = useMemo(() => {
|
||||
if (!allLabels) return [];
|
||||
return allLabels.filter((label) => values.labelIds.includes(label.id));
|
||||
}, [allLabels, values.labelIds]);
|
||||
|
||||
function handleRemoveLabel(labelId: number) {
|
||||
setFieldValue(
|
||||
"labelIds",
|
||||
values.labelIds.filter((id) => id !== labelId)
|
||||
);
|
||||
}
|
||||
|
||||
const addLabel = useCallback(
|
||||
async (name: string) => {
|
||||
const trimmed = name.trim();
|
||||
if (!trimmed) return;
|
||||
|
||||
const existing = allLabels?.find(
|
||||
(l) => l.name.toLowerCase() === trimmed.toLowerCase()
|
||||
);
|
||||
if (existing) {
|
||||
if (!values.labelIds.includes(existing.id)) {
|
||||
setFieldValue("labelIds", [...values.labelIds, existing.id]);
|
||||
}
|
||||
} else {
|
||||
const newLabel = await createLabel(trimmed);
|
||||
if (newLabel) {
|
||||
setFieldValue("labelIds", [...values.labelIds, newLabel.id]);
|
||||
}
|
||||
}
|
||||
setLabelInputValue("");
|
||||
},
|
||||
[allLabels, values.labelIds, setFieldValue, createLabel]
|
||||
);
|
||||
|
||||
const chipItems = useMemo(
|
||||
() =>
|
||||
selectedLabels.map((label) => ({
|
||||
id: String(label.id),
|
||||
label: label.name,
|
||||
})),
|
||||
[selectedLabels]
|
||||
);
|
||||
|
||||
return (
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header icon={SvgShare} title="Share Agent" onClose={handleClose} />
|
||||
@@ -284,56 +230,15 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
</Section>
|
||||
)}
|
||||
</Section>
|
||||
{values.isPublic && (
|
||||
<Section>
|
||||
<Message
|
||||
iconComponent={SvgOrganization}
|
||||
close={false}
|
||||
static
|
||||
className="w-full"
|
||||
text="This agent is public to your organization."
|
||||
description="Everyone in your organization has access to this agent."
|
||||
/>
|
||||
</Section>
|
||||
)}
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={YOUR_ORGANIZATION_TAB} padding={0.5}>
|
||||
<Section gap={1} alignItems="stretch">
|
||||
<InputLayouts.Horizontal
|
||||
title="Publish This Agent"
|
||||
description="Make this agent available to everyone in your organization."
|
||||
>
|
||||
<SwitchField name="isPublic" />
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
{canUpdateFeaturedStatus && (
|
||||
<>
|
||||
<div className="border-t border-border-02" />
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Feature This Agent"
|
||||
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
|
||||
>
|
||||
<SwitchField name="isFeatured" />
|
||||
</InputLayouts.Horizontal>
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputChipField
|
||||
chips={chipItems}
|
||||
onRemoveChip={(id) => handleRemoveLabel(Number(id))}
|
||||
onAdd={addLabel}
|
||||
value={labelInputValue}
|
||||
onChange={setLabelInputValue}
|
||||
placeholder="Add labels..."
|
||||
icon={SvgTag}
|
||||
/>
|
||||
<Text secondaryBody text04>
|
||||
Add labels and categories to help people better discover this
|
||||
agent.
|
||||
</Text>
|
||||
</Section>
|
||||
<InputLayouts.Horizontal
|
||||
title="Publish This Agent"
|
||||
description="Make this agent available to everyone in your organization."
|
||||
>
|
||||
<SwitchField name="isPublic" />
|
||||
</InputLayouts.Horizontal>
|
||||
</Tabs.Content>
|
||||
</Tabs>
|
||||
</Card>
|
||||
@@ -373,15 +278,7 @@ export interface ShareAgentModalProps {
|
||||
userIds: string[];
|
||||
groupIds: number[];
|
||||
isPublic: boolean;
|
||||
isFeatured: boolean;
|
||||
labelIds: number[];
|
||||
onShare?: (
|
||||
userIds: string[],
|
||||
groupIds: number[],
|
||||
isPublic: boolean,
|
||||
isFeatured: boolean,
|
||||
labelIds: number[]
|
||||
) => void;
|
||||
onShare?: (userIds: string[], groupIds: number[], isPublic: boolean) => void;
|
||||
}
|
||||
|
||||
export default function ShareAgentModal({
|
||||
@@ -389,8 +286,6 @@ export default function ShareAgentModal({
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
isFeatured,
|
||||
labelIds,
|
||||
onShare,
|
||||
}: ShareAgentModalProps) {
|
||||
const shareAgentModal = useModal();
|
||||
@@ -399,18 +294,10 @@ export default function ShareAgentModal({
|
||||
selectedUserIds: userIds,
|
||||
selectedGroupIds: groupIds,
|
||||
isPublic: isPublic,
|
||||
isFeatured: isFeatured,
|
||||
labelIds: labelIds,
|
||||
};
|
||||
|
||||
function handleSubmit(values: ShareAgentFormValues) {
|
||||
onShare?.(
|
||||
values.selectedUserIds,
|
||||
values.selectedGroupIds,
|
||||
values.isPublic,
|
||||
values.isFeatured,
|
||||
values.labelIds
|
||||
);
|
||||
onShare?.(values.selectedUserIds, values.selectedGroupIds, values.isPublic);
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -50,7 +50,6 @@ import {
|
||||
SvgPaintBrush,
|
||||
SvgDiscordMono,
|
||||
SvgWallet,
|
||||
SvgTerminal,
|
||||
} from "@opal/icons";
|
||||
import SvgMcp from "@opal/icons/mcp";
|
||||
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
|
||||
@@ -92,7 +91,7 @@ const custom_assistants_items = (
|
||||
) => {
|
||||
const items = [
|
||||
{
|
||||
name: "Agents",
|
||||
name: "Assistants",
|
||||
icon: SvgOnyxOctagon,
|
||||
link: "/admin/assistants",
|
||||
},
|
||||
@@ -166,7 +165,7 @@ const collections = (
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: "Custom Agents",
|
||||
name: "Custom Assistants",
|
||||
items: custom_assistants_items(isCurator, enableEnterprise),
|
||||
},
|
||||
...(isCurator && enableEnterprise
|
||||
@@ -208,11 +207,6 @@ const collections = (
|
||||
icon: SvgImage,
|
||||
link: "/admin/configuration/image-generation",
|
||||
},
|
||||
{
|
||||
name: "Code Interpreter",
|
||||
icon: SvgTerminal,
|
||||
link: "/admin/configuration/code-interpreter",
|
||||
},
|
||||
...(!enableCloud && vectorDbEnabled
|
||||
? [
|
||||
{
|
||||
|
||||
@@ -29,12 +29,12 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
pageTitle: "Add Connector",
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - Agents",
|
||||
name: "Custom Assistants - Assistants",
|
||||
path: "assistants",
|
||||
pageTitle: "Agents",
|
||||
pageTitle: "Assistants",
|
||||
options: {
|
||||
paragraphText:
|
||||
"Agents are a way to build custom search/question-answering experiences for different use cases.",
|
||||
"Assistants are a way to build custom search/question-answering experiences for different use cases.",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -52,7 +52,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - Slack Bots",
|
||||
name: "Custom Assistants - Slack Bots",
|
||||
path: "bots",
|
||||
pageTitle: "Slack Bots",
|
||||
options: {
|
||||
@@ -61,7 +61,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - Standard Answers",
|
||||
name: "Custom Assistants - Standard Answers",
|
||||
path: "standard-answer",
|
||||
pageTitle: "Standard Answers",
|
||||
},
|
||||
@@ -101,12 +101,12 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
pageTitle: "Search Settings",
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - MCP Actions",
|
||||
name: "Custom Assistants - MCP Actions",
|
||||
path: "actions/mcp",
|
||||
pageTitle: "MCP Actions",
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - OpenAPI Actions",
|
||||
name: "Custom Assistants - OpenAPI Actions",
|
||||
path: "actions/open-api",
|
||||
pageTitle: "OpenAPI Actions",
|
||||
},
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import type { Page } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
|
||||
const CODE_INTERPRETER_URL = "/admin/configuration/code-interpreter";
|
||||
const API_STATUS_URL = "**/api/admin/code-interpreter";
|
||||
const API_HEALTH_URL = "**/api/admin/code-interpreter/health";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Intercept the status (GET /) and health (GET /health) endpoints with the
|
||||
* given values so the page renders deterministically.
|
||||
*
|
||||
* Also handles PUT requests — by default they succeed (200). Pass
|
||||
* `putStatus` to simulate failures.
|
||||
*/
|
||||
async function mockCodeInterpreterApi(
|
||||
page: Page,
|
||||
opts: { enabled: boolean; healthy: boolean; putStatus?: number }
|
||||
) {
|
||||
const putStatus = opts.putStatus ?? 200;
|
||||
|
||||
await page.route(API_HEALTH_URL, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ healthy: opts.healthy }),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(API_STATUS_URL, async (route) => {
|
||||
if (route.request().method() === "PUT") {
|
||||
await route.fulfill({
|
||||
status: putStatus,
|
||||
contentType: "application/json",
|
||||
body:
|
||||
putStatus >= 400
|
||||
? JSON.stringify({ detail: "Server Error" })
|
||||
: JSON.stringify(null),
|
||||
});
|
||||
} else {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ enabled: opts.enabled }),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* The disconnect icon button is an icon-only opal Button whose tooltip text
|
||||
* is not exposed as an accessible name. Locate it by finding the first
|
||||
* icon-only button (no label span) inside the card area.
|
||||
*/
|
||||
function getDisconnectIconButton(page: Page) {
|
||||
return page
|
||||
.locator("button:has(.opal-button):not(:has(.opal-button-label))")
|
||||
.first();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Code Interpreter Admin Page", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test("page loads with header and description", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
/^Code Interpreter/,
|
||||
{ timeout: 10000 }
|
||||
);
|
||||
|
||||
await expect(page.getByText("Built-in Python runtime")).toBeVisible();
|
||||
});
|
||||
|
||||
test("shows Connected status when enabled and healthy", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
});
|
||||
|
||||
test("shows Connection Lost when enabled but unhealthy", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connection Lost")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
});
|
||||
|
||||
test("shows Reconnect button when disabled", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await expect(page.getByText("(Disconnected)")).toBeVisible();
|
||||
});
|
||||
|
||||
test("disconnect flow opens modal and sends PUT request", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Click the disconnect icon button
|
||||
await getDisconnectIconButton(page).click();
|
||||
|
||||
// Modal should appear
|
||||
await expect(page.getByText("Disconnect Code Interpreter")).toBeVisible();
|
||||
await expect(
|
||||
page.getByText("All running sessions connected to")
|
||||
).toBeVisible();
|
||||
|
||||
// Click the danger Disconnect button in the modal
|
||||
const modal = page.getByRole("dialog");
|
||||
await modal.getByRole("button", { name: "Disconnect" }).click();
|
||||
|
||||
// Modal should close after successful disconnect
|
||||
await expect(page.getByText("Disconnect Code Interpreter")).not.toBeVisible(
|
||||
{ timeout: 5000 }
|
||||
);
|
||||
});
|
||||
|
||||
test("disconnect modal can be closed without disconnecting", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Open modal
|
||||
await getDisconnectIconButton(page).click();
|
||||
await expect(page.getByText("Disconnect Code Interpreter")).toBeVisible();
|
||||
|
||||
// Close modal via Cancel button
|
||||
const modal = page.getByRole("dialog");
|
||||
await modal.getByRole("button", { name: "Cancel" }).click();
|
||||
|
||||
// Modal should be gone, page still shows Connected
|
||||
await expect(
|
||||
page.getByText("Disconnect Code Interpreter")
|
||||
).not.toBeVisible();
|
||||
await expect(page.getByText("Connected")).toBeVisible();
|
||||
});
|
||||
|
||||
test("reconnect flow sends PUT with enabled=true", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
// Intercept the PUT and verify the payload
|
||||
const putPromise = page.waitForRequest(
|
||||
(req) =>
|
||||
req.url().includes("/api/admin/code-interpreter") &&
|
||||
req.method() === "PUT"
|
||||
);
|
||||
|
||||
await page.getByRole("button", { name: "Reconnect" }).click();
|
||||
|
||||
const putReq = await putPromise;
|
||||
expect(putReq.postDataJSON()).toEqual({ enabled: true });
|
||||
});
|
||||
|
||||
test("shows Checking... while reconnect is in progress", async ({ page }) => {
|
||||
// Use a single route handler that delays PUT responses
|
||||
await page.route(API_HEALTH_URL, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ healthy: false }),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(API_STATUS_URL, async (route) => {
|
||||
if (route.request().method() === "PUT") {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify(null),
|
||||
});
|
||||
} else {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ enabled: false }),
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await page.getByRole("button", { name: "Reconnect" }).click();
|
||||
|
||||
// Should show Checking... while the request is in flight
|
||||
await expect(page.getByText("Checking...")).toBeVisible({ timeout: 3000 });
|
||||
});
|
||||
|
||||
test("shows error toast when disconnect fails", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: true,
|
||||
healthy: true,
|
||||
putStatus: 500,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Open modal and click disconnect
|
||||
await getDisconnectIconButton(page).click();
|
||||
const modal = page.getByRole("dialog");
|
||||
await modal.getByRole("button", { name: "Disconnect" }).click();
|
||||
|
||||
// Error toast should appear
|
||||
await expect(
|
||||
page.getByText("Failed to disconnect Code Interpreter")
|
||||
).toBeVisible({ timeout: 5000 });
|
||||
});
|
||||
|
||||
test("shows error toast when reconnect fails", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: false,
|
||||
healthy: false,
|
||||
putStatus: 500,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await page.getByRole("button", { name: "Reconnect" }).click();
|
||||
|
||||
// Error toast should appear
|
||||
await expect(
|
||||
page.getByText("Failed to reconnect Code Interpreter")
|
||||
).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Reconnect button should reappear (not stuck in Checking...)
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -46,7 +46,7 @@ test.skip("User changes password and logs in with new password", async ({
|
||||
|
||||
// Verify successful login
|
||||
await expect(page).toHaveURL("http://localhost:3000/app");
|
||||
await expect(page.getByText("Explore Agents")).toBeVisible();
|
||||
await expect(page.getByText("Explore Assistants")).toBeVisible();
|
||||
});
|
||||
|
||||
test.use({ storageState: "admin2_auth.json" });
|
||||
@@ -115,5 +115,5 @@ test.skip("Admin resets own password and logs in with new password", async ({
|
||||
|
||||
// Verify successful login
|
||||
await expect(page).toHaveURL("http://localhost:3000/app");
|
||||
await expect(page.getByText("Explore Agents")).toBeVisible();
|
||||
await expect(page.getByText("Explore Assistants")).toBeVisible();
|
||||
});
|
||||
|
||||
@@ -16,7 +16,7 @@ import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
|
||||
// Tool-related test selectors now imported from shared utils
|
||||
|
||||
test.describe("Default Agent Tests", () => {
|
||||
test.describe("Default Assistant Tests", () => {
|
||||
let imageGenConfigId: string | null = null;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
@@ -69,7 +69,7 @@ test.describe("Default Agent Tests", () => {
|
||||
});
|
||||
|
||||
test.describe("Greeting Message Display", () => {
|
||||
test("should display greeting message when opening new chat with default agent", async ({
|
||||
test("should display greeting message when opening new chat with default assistant", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Look for greeting message - should be one from the predefined list
|
||||
@@ -95,21 +95,23 @@ test.describe("Default Agent Tests", () => {
|
||||
expect(GREETING_MESSAGES).toContain(greetingAfterReload?.trim());
|
||||
});
|
||||
|
||||
test("greeting should only appear for default agent", async ({ page }) => {
|
||||
// First verify greeting appears for default agent
|
||||
test("greeting should only appear for default assistant", async ({
|
||||
page,
|
||||
}) => {
|
||||
// First verify greeting appears for default assistant
|
||||
const greetingElement = await page.waitForSelector(
|
||||
'[data-testid="onyx-logo"]',
|
||||
{ timeout: 5000 }
|
||||
);
|
||||
expect(greetingElement).toBeTruthy();
|
||||
|
||||
// Create a custom agent to test non-default behavior
|
||||
// Create a custom assistant to test non-default behavior
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Custom Test Agent");
|
||||
await page.locator('input[name="name"]').fill("Custom Test Assistant");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -118,17 +120,17 @@ test.describe("Default Agent Tests", () => {
|
||||
.fill("Test Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Wait for agent to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Test Agent");
|
||||
// Wait for assistant to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Test Assistant");
|
||||
|
||||
// Greeting should NOT appear for custom agent
|
||||
// Greeting should NOT appear for custom assistant
|
||||
const customGreeting = await page.$('[data-testid="onyx-logo"]');
|
||||
expect(customGreeting).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Default Agent Branding", () => {
|
||||
test("should display Onyx logo for default agent", async ({ page }) => {
|
||||
test.describe("Default Assistant Branding", () => {
|
||||
test("should display Onyx logo for default assistant", async ({ page }) => {
|
||||
// Look for Onyx logo
|
||||
const logoElement = await page.waitForSelector(
|
||||
'[data-testid="onyx-logo"]',
|
||||
@@ -136,23 +138,23 @@ test.describe("Default Agent Tests", () => {
|
||||
);
|
||||
expect(logoElement).toBeTruthy();
|
||||
|
||||
// Should NOT show agent name for default agent
|
||||
// Should NOT show assistant name for default assistant
|
||||
const assistantNameElement = await page.$(
|
||||
'[data-testid="assistant-name-display"]'
|
||||
);
|
||||
expect(assistantNameElement).toBeNull();
|
||||
});
|
||||
|
||||
test("custom agents should show name and icon instead of logo", async ({
|
||||
test("custom assistants should show name and icon instead of logo", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Create a custom agent
|
||||
// Create a custom assistant
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Custom Agent");
|
||||
await page.locator('input[name="name"]').fill("Custom Assistant");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -161,16 +163,16 @@ test.describe("Default Agent Tests", () => {
|
||||
.fill("Test Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Wait for agent to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Agent");
|
||||
// Wait for assistant to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Assistant");
|
||||
|
||||
// Should show agent name and icon, not Onyx logo
|
||||
// Should show assistant name and icon, not Onyx logo
|
||||
const assistantNameElement = await page.waitForSelector(
|
||||
'[data-testid="assistant-name-display"]',
|
||||
{ timeout: 5000 }
|
||||
);
|
||||
const nameText = await assistantNameElement.textContent();
|
||||
expect(nameText).toContain("Custom Agent");
|
||||
expect(nameText).toContain("Custom Assistant");
|
||||
|
||||
// Onyx logo should NOT be shown
|
||||
const logoElement = await page.$('[data-testid="onyx-logo"]');
|
||||
@@ -179,8 +181,10 @@ test.describe("Default Agent Tests", () => {
|
||||
});
|
||||
|
||||
test.describe("Starter Messages", () => {
|
||||
test("default agent should NOT have starter messages", async ({ page }) => {
|
||||
// Check that starter messages container does not exist for default agent
|
||||
test("default assistant should NOT have starter messages", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Check that starter messages container does not exist for default assistant
|
||||
const starterMessagesContainer = await page.$(
|
||||
'[data-testid="starter-messages"]'
|
||||
);
|
||||
@@ -191,14 +195,18 @@ test.describe("Default Agent Tests", () => {
|
||||
expect(starterButtons.length).toBe(0);
|
||||
});
|
||||
|
||||
test("custom agents should display starter messages", async ({ page }) => {
|
||||
// Create a custom agent with starter messages
|
||||
test("custom assistants should display starter messages", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Create a custom assistant with starter messages
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Test Agent with Starters");
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.fill("Test Assistant with Starters");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -211,9 +219,9 @@ test.describe("Default Agent Tests", () => {
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Wait for assistant to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Test Agent with Starters");
|
||||
await verifyAssistantIsChosen(page, "Test Assistant with Starters");
|
||||
|
||||
// Starter messages container might exist but be empty for custom agents
|
||||
// Starter messages container might exist but be empty for custom assistants
|
||||
const starterMessagesContainer = await page.$(
|
||||
'[data-testid="starter-messages"]'
|
||||
);
|
||||
@@ -222,22 +230,24 @@ test.describe("Default Agent Tests", () => {
|
||||
const starterButtons = await page.$$(
|
||||
'[data-testid^="starter-message-"]'
|
||||
);
|
||||
// Custom agent without configured starter messages should have none
|
||||
// Custom assistant without configured starter messages should have none
|
||||
expect(starterButtons.length).toBe(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Agent Selection", () => {
|
||||
test("default agent should be selected for new chats", async ({ page }) => {
|
||||
// Verify the input placeholder indicates default agent (Onyx)
|
||||
test.describe("Assistant Selection", () => {
|
||||
test("default assistant should be selected for new chats", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Verify the input placeholder indicates default assistant (Onyx)
|
||||
await verifyDefaultAssistantIsChosen(page);
|
||||
});
|
||||
|
||||
test("default agent should NOT appear in agent selector", async ({
|
||||
test("default assistant should NOT appear in assistant selector", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Open agent selector
|
||||
// Open assistant selector
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
|
||||
// Wait for modal or assistant list to appear
|
||||
@@ -246,13 +256,13 @@ test.describe("Default Agent Tests", () => {
|
||||
.getByLabel("AgentsPage/new-agent-button")
|
||||
.waitFor({ state: "visible", timeout: 5000 });
|
||||
|
||||
// Look for default agent by name - it should NOT be there
|
||||
// Look for default assistant by name - it should NOT be there
|
||||
const assistantElements = await page.$$('[data-testid^="assistant-"]');
|
||||
const assistantTexts = await Promise.all(
|
||||
assistantElements.map((el) => el.textContent())
|
||||
);
|
||||
|
||||
// Check that the default agent is not in the list
|
||||
// Check that "Assistant" (the default assistant name) is not in the list
|
||||
const hasDefaultAssistant = assistantTexts.some(
|
||||
(text) =>
|
||||
text?.includes("Assistant") &&
|
||||
@@ -265,16 +275,16 @@ test.describe("Default Agent Tests", () => {
|
||||
await page.keyboard.press("Escape");
|
||||
});
|
||||
|
||||
test("should be able to switch from default to custom agent", async ({
|
||||
test("should be able to switch from default to custom assistant", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Create a custom agent
|
||||
// Create a custom assistant
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Switch Test Agent");
|
||||
await page.locator('input[name="name"]').fill("Switch Test Assistant");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -283,13 +293,13 @@ test.describe("Default Agent Tests", () => {
|
||||
.fill("Test Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Verify switched to custom agent
|
||||
await verifyAssistantIsChosen(page, "Switch Test Agent");
|
||||
// Verify switched to custom assistant
|
||||
await verifyAssistantIsChosen(page, "Switch Test Assistant");
|
||||
|
||||
// Start new chat to go back to default
|
||||
await startNewChat(page);
|
||||
|
||||
// Should be back to default agent
|
||||
// Should be back to default assistant
|
||||
await verifyDefaultAssistantIsChosen(page);
|
||||
});
|
||||
});
|
||||
@@ -369,7 +379,7 @@ test.describe("Default Agent Tests", () => {
|
||||
);
|
||||
}
|
||||
|
||||
// Enable the tools in default agent config via API
|
||||
// Enable the tools in default assistant config via API
|
||||
// Get current tools to find their IDs
|
||||
const toolsListResp = await page.request.get(
|
||||
"http://localhost:3000/api/tool"
|
||||
@@ -532,7 +542,7 @@ test.describe("Default Agent Tests", () => {
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("End-to-End Default Agent Flow", () => {
|
||||
test.describe("End-to-End Default Assistant Flow", () => {
|
||||
let imageGenConfigId: string | null = null;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
@@ -574,7 +584,7 @@ test.describe("End-to-End Default Agent Flow", () => {
|
||||
}
|
||||
});
|
||||
|
||||
test("complete user journey with default agent", async ({ page }) => {
|
||||
test("complete user journey with default assistant", async ({ page }) => {
|
||||
// Clear cookies and log in as a random user
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
@@ -601,7 +611,7 @@ test.describe("End-to-End Default Agent Flow", () => {
|
||||
// Start a new chat
|
||||
await startNewChat(page);
|
||||
|
||||
// Verify we're back to default agent with greeting
|
||||
// Verify we're back to default assistant with greeting
|
||||
await expect(page.locator('[data-testid="onyx-logo"]')).toBeVisible();
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user