mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 06:35:49 +00:00
Compare commits
61 Commits
v3.0.0-clo
...
worktree-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf80211eae | ||
|
|
135385e57b | ||
|
|
f06630bc1b | ||
|
|
4495df98cf | ||
|
|
0124937aa8 | ||
|
|
aec2d24706 | ||
|
|
16ebb55362 | ||
|
|
ab6c11319e | ||
|
|
05f5b96964 | ||
|
|
f525aa175b | ||
|
|
4ba6e5f735 | ||
|
|
992ad3b8d4 | ||
|
|
a6404f8b3e | ||
|
|
efc49c9f6b | ||
|
|
5e447440ea | ||
|
|
78c6ca39b8 | ||
|
|
71a7cf09b3 | ||
|
|
91d30a0156 | ||
|
|
7b30752767 | ||
|
|
4450ecf07c | ||
|
|
0e6b766996 | ||
|
|
12c8cd338b | ||
|
|
ad5688bf65 | ||
|
|
d2deefd1f1 | ||
|
|
18b90d405d | ||
|
|
8394e8837b | ||
|
|
f06df891c4 | ||
|
|
d6d5e72c18 | ||
|
|
449f5d62f9 | ||
|
|
4d256c5666 | ||
|
|
2e53496f46 | ||
|
|
63a206706a | ||
|
|
28427b3e5f | ||
|
|
3cafcd8a5e | ||
|
|
f2c50b7bb5 | ||
|
|
6b28c6bbfc | ||
|
|
226e801665 | ||
|
|
be13aa1310 | ||
|
|
45d38c4906 | ||
|
|
8aab518532 | ||
|
|
da6ce10e86 | ||
|
|
aaf8253520 | ||
|
|
7c7f81b164 | ||
|
|
2d4a3c72e9 | ||
|
|
7c51712018 | ||
|
|
aa5614695d | ||
|
|
8d7255d3c4 | ||
|
|
d403498f48 | ||
|
|
9ef3095c17 | ||
|
|
a39e93a0cb | ||
|
|
46d73cdfee | ||
|
|
1e04ce78e0 | ||
|
|
f9b81c1725 | ||
|
|
3bc1b89fee | ||
|
|
01743d99d4 | ||
|
|
092c1db7e0 | ||
|
|
40ac0d859a | ||
|
|
929e58361f | ||
|
|
6d472df7c5 | ||
|
|
cfa7acd904 | ||
|
|
5c5a6f943b |
@@ -9,7 +9,8 @@ inputs:
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
required: false
|
||||
default: ""
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
@@ -17,6 +18,14 @@ inputs:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
api-version:
|
||||
description: "Optional NIGHTLY_LLM_API_VERSION"
|
||||
required: false
|
||||
default: ""
|
||||
deployment-name:
|
||||
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
@@ -59,6 +68,7 @@ runs:
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AWS_REGION_NAME=us-west-2
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
@@ -82,6 +92,8 @@ runs:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
|
||||
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
@@ -91,11 +103,6 @@ runs:
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -110,10 +117,13 @@ runs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AWS_REGION_NAME=us-west-2 \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
|
||||
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
56
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
|
||||
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
|
||||
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
|
||||
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
2
.github/workflows/pr-playwright-tests.yml
vendored
2
.github/workflows/pr-playwright-tests.yml
vendored
@@ -603,7 +603,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
@@ -89,6 +89,10 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
@@ -3,33 +3,66 @@ name: Reusable Nightly LLM Provider Chat Tests
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
bedrock_models:
|
||||
description: "Comma-separated models for bedrock"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
vertex_ai_models:
|
||||
description: "Comma-separated models for vertex_ai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_models:
|
||||
description: "Comma-separated models for azure"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
ollama_models:
|
||||
description: "Comma-separated models for ollama_chat"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
openrouter_models:
|
||||
description: "Comma-separated models for openrouter"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
azure_api_base:
|
||||
description: "API base for azure provider"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
strict:
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
@@ -38,29 +71,8 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -90,7 +102,6 @@ jobs:
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -119,7 +130,6 @@ jobs:
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -149,11 +159,75 @@ jobs:
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -167,12 +241,14 @@ jobs:
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
@@ -194,7 +270,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""add python tool on default
|
||||
|
||||
Revision ID: 57122d037335
|
||||
Revises: c0c937d5c9e5
|
||||
Create Date: 2026-02-27 10:10:40.124925
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "57122d037335"
|
||||
down_revision = "c0c937d5c9e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
PYTHON_TOOL_NAME = "python"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up the PythonTool id
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
tool_id = result[0]
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE name = :name"),
|
||||
{"name": PYTHON_TOOL_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE persona_id = 0 AND tool_id = :tool_id
|
||||
"""
|
||||
),
|
||||
{"tool_id": result[0]},
|
||||
)
|
||||
@@ -0,0 +1,70 @@
|
||||
"""llm provider deprecate fields
|
||||
|
||||
Revision ID: c0c937d5c9e5
|
||||
Revises: 8ffcc2bcfc11
|
||||
Create Date: 2026-02-25 17:35:46.125102
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0c937d5c9e5"
|
||||
down_revision = "8ffcc2bcfc11"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make default_model_name nullable (was NOT NULL)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
|
||||
op.drop_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
type_="unique",
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
op.execute(
|
||||
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore unique constraint on is_default_provider
|
||||
op.create_unique_constraint(
|
||||
"llm_provider_is_default_provider_key",
|
||||
"llm_provider",
|
||||
["is_default_provider"],
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=sa.false(),
|
||||
)
|
||||
@@ -322,6 +322,7 @@ def list_users(
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -365,6 +366,7 @@ def get_user(
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
@@ -721,6 +723,7 @@ def list_groups(
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
@@ -757,6 +760,7 @@ def get_group(
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
|
||||
@@ -58,16 +58,27 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -115,15 +126,26 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -141,8 +163,13 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -161,6 +188,12 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
|
||||
@@ -76,7 +76,7 @@ def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
@@ -764,7 +764,7 @@ def process_single_user_file_project_sync(
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
@@ -163,13 +162,11 @@ class ChatStateContainer:
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
*args: Additional positional arguments for func
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
arg1, arg2, kwarg1=value1
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
# Ensure state_container is passed explicitly, removing it from kwargs if present
|
||||
kwargs_with_state = {**kwargs, "state_container": state_container}
|
||||
func(emitter, *args, **kwargs_with_state)
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
|
||||
@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -541,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
|
||||
@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -203,17 +203,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for project files.
|
||||
"""Build citation mapping for context files.
|
||||
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
project_file_metadata: List of project file metadata
|
||||
file_metadata: List of context file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -221,8 +221,7 @@ def _build_project_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -242,29 +241,28 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
if context_files.file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
_create_context_files_message(context_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
if context_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
context_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
@@ -275,7 +273,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -289,7 +287,7 @@ def construct_message_history(
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages = _build_project_message(context_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
@@ -445,17 +443,17 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
if context_files and context_files.image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
image_files=existing_images + context_files.image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [system], [history_before_last_user], [custom_agent], [context_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
@@ -466,14 +464,14 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
# 3. Add context files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
# 5. Add last user message (with context images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
@@ -547,11 +545,11 @@ def _create_file_tool_metadata_message(
|
||||
)
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
def _create_context_files_message(
|
||||
context_files: ExtractedContextFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -559,7 +557,7 @@ def _create_project_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
@@ -570,10 +568,10 @@ def _create_project_files_message(
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
# Use pre-calculated token count from project_files
|
||||
# Use pre-calculated token count from context_files
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=project_files.total_token_count,
|
||||
token_count=context_files.total_token_count,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
@@ -584,7 +582,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
context_files: ExtractedContextFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -627,9 +625,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
if context_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
context_files.file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -647,7 +645,7 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
context_files.use_as_search_filter or context_files.file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
@@ -788,7 +786,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt=custom_agent_prompt_msg,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -167,20 +160,28 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
|
||||
@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
for uf in user_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
def resolve_context_user_files(
|
||||
persona: Persona,
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
def determine_search_params(
|
||||
persona_id: int,
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
extracted_context_files: ExtractedContextFiles,
|
||||
) -> SearchParams:
|
||||
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
A custom persona fully supersedes the project — project files are never
|
||||
searchable and the search tool config is entirely controlled by the
|
||||
persona. The project_id filter is only set for the default persona.
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
For the default persona inside a project:
|
||||
- Files overflow → ENABLED (vector DB scopes to these files)
|
||||
- Files fit → DISABLED (content already in prompt)
|
||||
- No files at all → DISABLED (nothing to search)
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -879,46 +864,54 @@ def handle_stream_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
|
||||
# are just passed in immediately anyways, but the abstraction is cleaner this way.
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
lambda emitter, state_container: run_deep_research_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
lambda emitter, state_container: run_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -294,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -23,7 +23,6 @@ from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.intune.organizations.organization import Organization # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.site import Site # type: ignore[import-untyped]
|
||||
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped]
|
||||
@@ -872,6 +871,56 @@ class SharepointConnector(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
def _extract_tenant_domain_from_sites(self) -> str | None:
|
||||
"""Extract the tenant domain from configured site URLs.
|
||||
|
||||
Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
|
||||
tenant domain is the first label of the hostname.
|
||||
"""
|
||||
for site_url in self.sites:
|
||||
try:
|
||||
hostname = urlsplit(site_url.strip()).hostname
|
||||
except ValueError:
|
||||
continue
|
||||
if not hostname:
|
||||
continue
|
||||
tenant = hostname.split(".")[0]
|
||||
if tenant:
|
||||
return tenant
|
||||
logger.warning(f"No tenant domain found from {len(self.sites)} sites")
|
||||
return None
|
||||
|
||||
def _resolve_tenant_domain_from_root_site(self) -> str:
|
||||
"""Resolve tenant domain via GET /v1.0/sites/root which only requires
|
||||
Sites.Read.All (a permission the connector already needs)."""
|
||||
root_site = self.graph_client.sites.root.get().execute_query()
|
||||
hostname = root_site.site_collection.hostname
|
||||
if not hostname:
|
||||
raise ConnectorValidationError(
|
||||
"Could not determine tenant domain from root site"
|
||||
)
|
||||
tenant_domain = hostname.split(".")[0]
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from root site hostname '%s'",
|
||||
tenant_domain,
|
||||
hostname,
|
||||
)
|
||||
return tenant_domain
|
||||
|
||||
def _resolve_tenant_domain(self) -> str:
|
||||
"""Determine the tenant domain, preferring site URLs over a Graph API
|
||||
call to avoid needing extra permissions."""
|
||||
from_sites = self._extract_tenant_domain_from_sites()
|
||||
if from_sites:
|
||||
logger.info(
|
||||
"Resolved tenant domain '%s' from site URLs",
|
||||
from_sites,
|
||||
)
|
||||
return from_sites
|
||||
|
||||
logger.info("No site URLs available; resolving tenant domain from root site")
|
||||
return self._resolve_tenant_domain_from_root_site()
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
@@ -1589,6 +1638,11 @@ class SharepointConnector(
|
||||
sp_private_key = credentials.get("sp_private_key")
|
||||
sp_certificate_password = credentials.get("sp_certificate_password")
|
||||
|
||||
if not sp_client_id:
|
||||
raise ConnectorValidationError("Client ID is required")
|
||||
if not sp_directory_id:
|
||||
raise ConnectorValidationError("Directory (tenant) ID is required")
|
||||
|
||||
authority_url = f"{self.authority_host}/{sp_directory_id}"
|
||||
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
@@ -1641,21 +1695,7 @@ class SharepointConnector(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
raise ConnectorValidationError("No organization found")
|
||||
|
||||
tenant_info: Organization = org[
|
||||
0
|
||||
] # Access first item directly from collection
|
||||
if not tenant_info.verified_domains:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
|
||||
sp_tenant_domain = tenant_info.verified_domains[0].name
|
||||
if not sp_tenant_domain:
|
||||
raise ConnectorValidationError("No verified domains found for tenant")
|
||||
# remove the .onmicrosoft.com part
|
||||
self.sp_tenant_domain = sp_tenant_domain.split(".")[0]
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.db.engine.iam_auth import ssl_context
|
||||
from onyx.db.engine.sql_engine import ASYNC_DB_API
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
connect_args["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
cparams["ssl"] = create_ssl_context_if_iam()
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any
|
||||
@@ -48,11 +49,9 @@ def provide_iam_token(
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
@@ -619,7 +619,7 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
provider.default_model_name,
|
||||
provider.default_model_name, # type: ignore[arg-type]
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
|
||||
@@ -2822,13 +2822,17 @@ class LLMProvider(Base):
|
||||
custom_config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# Deprecated: use LLMModelFlow with CHAT flow type instead
|
||||
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
@@ -2879,6 +2883,7 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Deprecated: use LLMModelFlow with VISION flow type instead
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
|
||||
@@ -256,9 +256,6 @@ def create_update_persona(
|
||||
try:
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
@@ -335,6 +332,7 @@ def update_persona_shared(
|
||||
db_session: Session,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -344,9 +342,7 @@ def update_persona_shared(
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
@@ -360,6 +356,15 @@ def update_persona_shared(
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
persona.labels.clear()
|
||||
persona.labels = labels
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -965,6 +970,8 @@ def upsert_persona(
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
|
||||
# Fetch and attach hierarchy_nodes by IDs
|
||||
hierarchy_nodes = None
|
||||
@@ -1161,9 +1168,6 @@ def update_persona_is_default(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if not persona.is_public:
|
||||
persona.is_public = True
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
|
||||
@@ -57,12 +58,19 @@ def fetch_user_project_ids_for_user_files(
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch user project ids for specified user files"""
|
||||
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [project.id for project in user_file.projects]
|
||||
for user_file in results
|
||||
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
|
||||
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
|
||||
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
|
||||
)
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
user_file_id_to_project_ids: dict[str, list[int]] = {
|
||||
user_file_id: [] for user_file_id in user_file_ids
|
||||
}
|
||||
for user_file_id, project_id in rows:
|
||||
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
|
||||
|
||||
return user_file_id_to_project_ids
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
|
||||
@@ -139,7 +139,7 @@ def generate_final_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -49,8 +50,11 @@ def get_default_document_index(
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
@@ -118,8 +122,11 @@ def get_all_document_indices(
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
return new_body
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
class OpenSearchClient(AbstractContextManager):
|
||||
"""Client for interacting with OpenSearch for cluster-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
Args:
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
@@ -107,9 +113,8 @@ class OpenSearchClient:
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -125,6 +130,142 @@ class OpenSearchClient:
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
|
||||
class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""Client for interacting with OpenSearch for index-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to interact with.
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -192,6 +333,38 @@ class OpenSearchClient:
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_mapping(self, mappings: dict[str, Any]) -> None:
|
||||
"""Updates the index mapping in an idempotent manner.
|
||||
|
||||
- Existing fields with the same definition: No-op (succeeds silently).
|
||||
- New fields: Added to the index.
|
||||
- Existing fields with different types: Raises exception (requires
|
||||
reindex).
|
||||
|
||||
See the OpenSearch documentation for more information:
|
||||
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
|
||||
|
||||
Args:
|
||||
mappings: The complete mapping definition to apply. This will be
|
||||
merged with existing mappings in the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the mappings, such as
|
||||
attempting to change the type of an existing field.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Putting mappings for index {self._index_name} with mappings {mappings}."
|
||||
)
|
||||
response = self._client.indices.put_mapping(
|
||||
index=self._index_name, body=mappings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to put the mapping update for index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Successfully put mappings for index {self._index_name}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
@@ -610,43 +783,6 @@ class OpenSearchClient:
|
||||
)
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
@@ -807,48 +943,6 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
with nullcontext(client) if client else OpenSearchClient() as client:
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
|
||||
|
||||
@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
@@ -93,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
|
||||
logger.error(
|
||||
"Failed to put cluster settings. If the settings have never been set before, "
|
||||
"this may cause unexpected index creation when indexing documents into an "
|
||||
"index that does not exist, or may cause expected logs to not appear. If this "
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -248,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
@@ -258,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
if multitenant:
|
||||
raise ValueError(
|
||||
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
|
||||
)
|
||||
if multitenant != MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
|
||||
@@ -269,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -279,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
# TODO(andrei): Implement.
|
||||
raise NotImplementedError(
|
||||
"Multitenant index registration is not yet implemented for OpenSearch."
|
||||
"Bug: Multitenant index registration is not supported for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
@@ -471,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
for an OpenSearch search engine instance. It handles the complete lifecycle
|
||||
of document chunks within a specific OpenSearch index/schema.
|
||||
|
||||
Although not yet used in this way in the codebase, each kind of embedding
|
||||
used should correspond to a different instance of this class, and therefore
|
||||
a different index in OpenSearch.
|
||||
Each kind of embedding used should correspond to a different instance of
|
||||
this class, and therefore a different index in OpenSearch.
|
||||
|
||||
If in a multitenant environment and
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
|
||||
if necessary on initialization. This is because there is no logic which runs
|
||||
on cluster restart which scans through all search settings over all tenants
|
||||
and creates the relevant indices.
|
||||
|
||||
Args:
|
||||
tenant_state: The tenant state of the caller.
|
||||
index_name: The name of the index to interact with.
|
||||
embedding_dim: The dimensionality of the embeddings used for the index.
|
||||
embedding_precision: The precision of the embeddings used for the index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
self._index_name: str = index_name
|
||||
self._tenant_state: TenantState = tenant_state
|
||||
self._os_client = OpenSearchClient(index_name=self._index_name)
|
||||
self._client = OpenSearchIndexClient(index_name=self._index_name)
|
||||
|
||||
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
|
||||
self.verify_and_create_index_if_necessary(
|
||||
embedding_dim=embedding_dim, embedding_precision=embedding_precision
|
||||
)
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
@@ -492,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired cluster settings.
|
||||
Also puts the desired cluster settings if not in a multitenant
|
||||
environment.
|
||||
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
Also puts the desired search pipeline state if not in a multitenant
|
||||
environment, creating the pipelines if they do not exist and updating
|
||||
them otherwise.
|
||||
|
||||
In a multitenant environment, the above steps happen explicitly on
|
||||
setup.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
@@ -508,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
|
||||
f"necessary, with embedding dimension {embedding_dim}."
|
||||
)
|
||||
|
||||
if not self._tenant_state.multitenant:
|
||||
set_cluster_state(self._client)
|
||||
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
if not self._os_client.put_cluster_settings(
|
||||
settings=OPENSEARCH_CLUSTER_SETTINGS
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
|
||||
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
|
||||
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
|
||||
"these settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
|
||||
if not self._client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._os_client.create_index(
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
)
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
|
||||
)
|
||||
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
else:
|
||||
# Ensure schema is up to date by applying the current mappings.
|
||||
try:
|
||||
self._client.put_mapping(expected_mappings)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update mappings for index {self._index_name}. This likely means a "
|
||||
f"field type was changed which requires reindexing. Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def index(
|
||||
self,
|
||||
@@ -620,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
@@ -660,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
return self._os_client.delete_by_query(query_body)
|
||||
return self._client.delete_by_query(query_body)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -760,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document_id=doc_id,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
self._os_client.update_document(
|
||||
self._client.update_document(
|
||||
document_chunk_id=document_chunk_id,
|
||||
properties_to_update=properties_to_update,
|
||||
)
|
||||
@@ -799,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
)
|
||||
search_hits = self._os_client.search(
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -849,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
@@ -881,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -909,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
# OpenSearch transition period.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
|
||||
)
|
||||
|
||||
@@ -405,6 +405,7 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
label_ids: list[int] | None = None
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -415,14 +416,22 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
try:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
label_ids=persona_share_request.label_ids,
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -105,7 +105,9 @@ class LLMProviderDescriptor(BaseModel):
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=llm_provider_model.name,
|
||||
@@ -184,7 +186,9 @@ class LLMProviderView(LLMProvider):
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
|
||||
27
backend/onyx/server/metrics/per_tenant.py
Normal file
27
backend/onyx/server/metrics/per_tenant.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Per-tenant request counter metric.
|
||||
|
||||
Increments a counter on every request, labelled by tenant, so Grafana can
|
||||
answer "which tenant is generating the most traffic?"
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_requests_by_tenant = Counter(
|
||||
"onyx_api_requests_by_tenant_total",
|
||||
"Total API requests by tenant",
|
||||
["tenant_id", "method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def per_tenant_request_callback(info: Info) -> None:
|
||||
"""Increment per-tenant request counter for every request."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
_requests_by_tenant.labels(
|
||||
tenant_id=tenant_id,
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
@@ -32,6 +32,7 @@ from sqlalchemy.pool import QueuePool
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -72,7 +73,7 @@ _checkout_timeout_total = Counter(
|
||||
_connections_held = Gauge(
|
||||
"onyx_db_connections_held_by_endpoint",
|
||||
"Number of DB connections currently held, by endpoint and engine",
|
||||
["handler", "engine"],
|
||||
["handler", "engine", "tenant_id"],
|
||||
)
|
||||
|
||||
_hold_seconds = Histogram(
|
||||
@@ -163,10 +164,14 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_proxy: PoolProxiedConnection, # noqa: ARG001
|
||||
) -> None:
|
||||
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or "unknown"
|
||||
conn_record.info["_metrics_endpoint"] = handler
|
||||
conn_record.info["_metrics_tenant_id"] = tenant_id
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic()
|
||||
_checkout_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).inc()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).inc()
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def on_checkin(
|
||||
@@ -174,9 +179,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
conn_record: ConnectionPoolEntry,
|
||||
) -> None:
|
||||
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
_checkin_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler, engine=label).observe(
|
||||
time.monotonic() - start
|
||||
@@ -199,9 +207,12 @@ def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
# Defensively clean up the held-connections gauge in case checkin
|
||||
# doesn't fire after invalidation (e.g. hard pool shutdown).
|
||||
handler = conn_record.info.pop("_metrics_endpoint", None)
|
||||
tenant_id = conn_record.info.pop("_metrics_tenant_id", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
if handler:
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
_connections_held.labels(
|
||||
handler=handler, engine=label, tenant_id=tenant_id
|
||||
).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
|
||||
time.monotonic() - start
|
||||
|
||||
@@ -11,9 +11,11 @@ SQLAlchemy connection pool metrics are registered separately via
|
||||
"""
|
||||
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
|
||||
from sqlalchemy.exc import TimeoutError as SATimeoutError
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -59,6 +61,15 @@ def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
# Explicitly create the default metrics (http_requests_total,
|
||||
# http_request_duration_seconds, etc.) and add them first. The library
|
||||
# skips creating defaults when ANY custom instrumentations are registered
|
||||
# via .add(), so we must include them ourselves.
|
||||
default_callback = default_metrics(latency_lowr_buckets=_LATENCY_BUCKETS)
|
||||
if default_callback:
|
||||
instrumentator.add(default_callback)
|
||||
|
||||
instrumentator.add(slow_request_callback)
|
||||
instrumentator.add(per_tenant_request_callback)
|
||||
|
||||
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
@@ -32,6 +33,9 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -311,7 +315,14 @@ def setup_multitenant_onyx() -> None:
|
||||
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
|
||||
return
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_client = OpenSearchClient()
|
||||
if not wait_for_opensearch_with_timeout(client=opensearch_client):
|
||||
raise RuntimeError("Failed to connect to OpenSearch.")
|
||||
set_cluster_state(opensearch_client)
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ def generate_intermediate_report(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ def run_research_agent_call(
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=msg_history,
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
context_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# fastapi-users
|
||||
@@ -809,7 +809,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.6.2
|
||||
pypdf==6.7.3
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
@@ -1216,7 +1217,7 @@ websockets==15.0.1
|
||||
# via
|
||||
# fastmcp
|
||||
# google-genai
|
||||
werkzeug==3.1.5
|
||||
werkzeug==3.1.6
|
||||
# via sendgrid
|
||||
wrapt==1.17.3
|
||||
# via
|
||||
|
||||
@@ -125,7 +125,7 @@ executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -90,7 +90,7 @@ docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -108,7 +108,7 @@ durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# sentry-sdk
|
||||
@@ -525,6 +525,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
@@ -521,3 +522,46 @@ def test_sharepoint_connector_hierarchy_nodes(
|
||||
f"Document {doc.semantic_identifier} should have "
|
||||
"parent_hierarchy_raw_node_id set"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_cert_credentials() -> dict[str, str]:
|
||||
return {
|
||||
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
|
||||
"sp_client_id": os.environ["PERM_SYNC_SHAREPOINT_CLIENT_ID"],
|
||||
"sp_private_key": os.environ["PERM_SYNC_SHAREPOINT_PRIVATE_KEY"],
|
||||
"sp_certificate_password": os.environ[
|
||||
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
|
||||
],
|
||||
"sp_directory_id": os.environ["PERM_SYNC_SHAREPOINT_DIRECTORY_ID"],
|
||||
}
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_site_urls(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain from site URLs
|
||||
without calling the /organization endpoint."""
|
||||
site_url = os.environ["SHAREPOINT_SITE"]
|
||||
connector = SharepointConnector(sites=[site_url])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
# The tenant domain should match the first label of the site URL hostname
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
expected = urlsplit(site_url).hostname.split(".")[0] # type: ignore
|
||||
assert connector.sp_tenant_domain == expected
|
||||
|
||||
|
||||
def test_resolve_tenant_domain_from_root_site(
|
||||
sharepoint_cert_credentials: dict[str, str],
|
||||
) -> None:
|
||||
"""Verify that certificate auth resolves the tenant domain via the root
|
||||
site endpoint when no site URLs are configured."""
|
||||
connector = SharepointConnector(sites=[])
|
||||
connector.load_credentials(sharepoint_cert_credentials)
|
||||
|
||||
assert connector.sp_tenant_domain is not None
|
||||
assert len(connector.sp_tenant_domain) > 0
|
||||
|
||||
@@ -0,0 +1,544 @@
|
||||
"""
|
||||
External dependency unit tests for persona file sync.
|
||||
|
||||
Validates that:
|
||||
|
||||
1. The check_for_user_file_project_sync beat task picks up UserFiles with
|
||||
needs_persona_sync=True (not just needs_project_sync).
|
||||
|
||||
2. The process_single_user_file_project_sync worker task reads persona
|
||||
associations from the DB, passes persona_ids to the document index via
|
||||
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
|
||||
|
||||
3. upsert_persona correctly marks affected UserFiles with
|
||||
needs_persona_sync=True when file associations change.
|
||||
|
||||
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
|
||||
since we only need to verify the arguments passed to update_single.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_completed_user_file(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
needs_persona_sync: bool = False,
|
||||
needs_project_sync: bool = False,
|
||||
) -> UserFile:
|
||||
"""Insert a UserFile in COMPLETED status."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
needs_project_sync=needs_project_sync,
|
||||
chunk_count=5,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_test_persona(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user_files: list[UserFile] | None = None,
|
||||
) -> Persona:
|
||||
"""Create a minimal Persona via direct model insert."""
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_files=user_files or [],
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _link_file_to_persona(
|
||||
db_session: Session, persona: Persona, user_file: UserFile
|
||||
) -> None:
|
||||
"""Create the join table row between a persona and a user file."""
|
||||
link = Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
_PATCH_QUEUE_DEPTH = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
".get_user_file_project_sync_queue_depth"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on a bound Celery task."""
|
||||
task_instance = task.run.__self__
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: check_for_user_file_project_sync picks up persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSweepIncludesPersonaSync:
|
||||
"""The beat task must pick up files needing persona sync, not just project sync."""
|
||||
|
||||
def test_persona_sync_flag_enqueues_task(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
|
||||
user = create_test_user(db_session, "persona_sweep")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) in enqueued_ids
|
||||
|
||||
def test_neither_flag_does_not_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with both flags False is not enqueued."""
|
||||
user = create_test_user(db_session, "no_sync")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) not in enqueued_ids
|
||||
|
||||
def test_both_flags_enqueues_once(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with BOTH flags True is enqueued exactly once."""
|
||||
user = create_test_user(db_session, "both_flags")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
matching_calls = [
|
||||
call
|
||||
for call in mock_app.send_task.call_args_list
|
||||
if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
|
||||
]
|
||||
assert len(matching_calls) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: process_single_user_file_project_sync passes persona_ids to index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_GET_SETTINGS = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
|
||||
)
|
||||
_PATCH_GET_INDICES = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
|
||||
)
|
||||
_PATCH_HTTPX_INIT = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
|
||||
)
|
||||
_PATCH_DISABLE_VDB = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
|
||||
)
|
||||
|
||||
|
||||
class TestSyncTaskWritesPersonaIds:
|
||||
"""The sync task reads persona associations and sends them to the index."""
|
||||
|
||||
def test_passes_persona_ids_to_update_single(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After linking a file to a persona, sync sends the persona ID."""
|
||||
user = create_test_user(db_session, "sync_persona")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
mock_doc_index.update_single.assert_called_once()
|
||||
call_args = mock_doc_index.update_single.call_args
|
||||
user_fields: VespaDocumentUserFields = call_args.kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert call_args.args[0] == str(uf.id)
|
||||
|
||||
def test_clears_persona_sync_flag(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a successful sync the needs_persona_sync flag is cleared."""
|
||||
user = create_test_user(db_session, "sync_clear")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with patch(_PATCH_DISABLE_VDB, True):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
def test_passes_both_project_and_persona_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to both a project and a persona gets both IDs."""
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserProject
|
||||
|
||||
user = create_test_user(db_session, "sync_both")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
project = UserProject(user_id=user.id, name="test-project", instructions="")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
|
||||
link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert user_fields.user_projects is not None
|
||||
assert persona.id in user_fields.personas
|
||||
assert project.id in user_fields.user_projects
|
||||
|
||||
# Both flags should be cleared
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.needs_project_sync is False
|
||||
|
||||
def test_deleted_persona_excluded_from_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
|
||||
user = create_test_user(db_session, "sync_deleted")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
persona.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert user_fields.personas is not None
|
||||
assert persona.id not in user_fields.personas
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: upsert_persona marks files for persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpsertPersonaMarksSyncFlag:
|
||||
"""upsert_persona must set needs_persona_sync on affected UserFiles."""
|
||||
|
||||
def test_creating_persona_with_files_marks_sync(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "upsert_create")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
|
||||
def test_updating_persona_files_marks_both_old_and_new(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When file associations change, both the removed and added files are flagged."""
|
||||
user = create_test_user(db_session, "upsert_update")
|
||||
uf_old = _create_completed_user_file(db_session, user)
|
||||
uf_new = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf_old.id],
|
||||
)
|
||||
|
||||
# Clear the flag from creation so we can observe the update
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[uf_new.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf_old)
|
||||
db_session.refresh(uf_new)
|
||||
assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
|
||||
assert uf_new.needs_persona_sync is True, "Added file should be flagged"
|
||||
|
||||
def test_removing_all_files_marks_old_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Removing all files from a persona flags the previously associated files."""
|
||||
user = create_test_user(db_session, "upsert_remove")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
|
||||
Mocks the LLM tokenizer and file store since they are not relevant here.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_user_file(db_session: Session, user: User) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
chunk_count=1,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _create_project(db_session: Session, user: User) -> UserProject:
|
||||
project = UserProject(
|
||||
user_id=user.id,
|
||||
name=f"project-{uuid4().hex[:8]}",
|
||||
instructions="",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
"""Build a minimal IndexChunk whose source document ID matches the UserFile."""
|
||||
doc = Document(
|
||||
id=str(user_file.id),
|
||||
source=DocumentSource.USER_FILE,
|
||||
semantic_identifier=user_file.name,
|
||||
sections=[TextSection(text="test chunk content", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
source_document=doc,
|
||||
chunk_id=0,
|
||||
blurb="test chunk",
|
||||
content="test chunk content",
|
||||
source_links={0: ""},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.0] * 768,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_persona_gets_persona_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_persona")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_project_gets_project_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_project")
|
||||
uf = _create_user_file(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_both_gets_both_ids(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_both")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_with_no_associations_gets_empty_lists(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_empty")
|
||||
uf = _create_user_file(db_session, user)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_multiple_personas_all_appear(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to multiple personas should have all their IDs."""
|
||||
user = create_test_user(db_session, "adapter_multi")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona_a = _create_persona(db_session, user)
|
||||
persona_b = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
|
||||
db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""External dependency unit tests for OpenSearchClient.
|
||||
"""External dependency unit tests for OpenSearchIndexClient.
|
||||
|
||||
These tests assume OpenSearch is running and test all implemented methods
|
||||
using real schemas, pipelines, and search queries from the codebase.
|
||||
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for testing with automatic cleanup."""
|
||||
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
|
||||
client = OpenSearchClient(index_name=test_index_name)
|
||||
client = OpenSearchIndexClient(index_name=test_index_name)
|
||||
|
||||
yield client # Test runs here.
|
||||
|
||||
@@ -142,7 +142,7 @@ def test_client(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
|
||||
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
|
||||
"""Creates a search pipeline for testing with automatic cleanup."""
|
||||
test_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None
|
||||
|
||||
|
||||
class TestOpenSearchClient:
|
||||
"""Tests for OpenSearchClient."""
|
||||
"""Tests for OpenSearchIndexClient."""
|
||||
|
||||
def test_create_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index with a real schema."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
|
||||
# Verify index exists.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting an existing index returns True."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
|
||||
assert result is True
|
||||
assert test_client.validate_index(expected_mappings=mappings) is False
|
||||
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting a nonexistent index returns False."""
|
||||
# Under test.
|
||||
# Don't create index, just try to delete.
|
||||
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert result is False
|
||||
|
||||
def test_index_exists(self, test_client: OpenSearchClient) -> None:
|
||||
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests checking if an index exists."""
|
||||
# Precondition.
|
||||
# Index should not exist before creation.
|
||||
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
|
||||
# Index should exist after creation.
|
||||
assert test_client.index_exists() is True
|
||||
|
||||
def test_validate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests validating an index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -239,7 +239,120 @@ class TestOpenSearchClient:
|
||||
# Should return True after creation.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests put_mapping with same schema is idempotent."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Applying the same mappings again should succeed.
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Index should still be valid.
|
||||
assert test_client.validate_index(expected_mappings=mappings)
|
||||
|
||||
def test_put_mapping_adds_new_field(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping successfully adds new fields to existing index."""
|
||||
# Precondition.
|
||||
# Create index with minimal schema (just required fields).
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Add a new field using put_mapping.
|
||||
updated_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
# New field
|
||||
"new_test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
# Should not raise.
|
||||
test_client.put_mapping(updated_mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Validate the new schema includes the new field.
|
||||
assert test_client.validate_index(expected_mappings=updated_mappings)
|
||||
|
||||
def test_put_mapping_fails_on_type_change(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping fails when trying to change existing field type."""
|
||||
# Precondition.
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
# Try to change test_field type from keyword to text.
|
||||
conflicting_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "text"}, # Changed from keyword to text
|
||||
},
|
||||
}
|
||||
# Should raise because field type cannot be changed.
|
||||
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
|
||||
test_client.put_mapping(conflicting_mappings)
|
||||
|
||||
def test_put_mapping_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping on non-existent index raises an error."""
|
||||
# Precondition.
|
||||
# Index does not exist yet.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index twice raises an error."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -254,14 +367,14 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
def test_update_settings(self, test_client: OpenSearchClient) -> None:
|
||||
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests that update_settings raises NotImplementedError."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_settings(settings={})
|
||||
|
||||
def test_create_and_delete_search_pipeline(
|
||||
self, test_client: OpenSearchClient
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests creating and deleting a search pipeline."""
|
||||
# Under test and postcondition.
|
||||
@@ -278,7 +391,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a document."""
|
||||
# Precondition.
|
||||
@@ -306,7 +419,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_bulk_index_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests bulk indexing documents."""
|
||||
# Precondition.
|
||||
@@ -337,7 +450,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_duplicate_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a duplicate document raises an error."""
|
||||
# Precondition.
|
||||
@@ -365,7 +478,7 @@ class TestOpenSearchClient:
|
||||
test_client.index_document(document=doc, tenant_state=tenant_state)
|
||||
|
||||
def test_get_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a document."""
|
||||
# Precondition.
|
||||
@@ -401,7 +514,7 @@ class TestOpenSearchClient:
|
||||
assert retrieved_doc == original_doc
|
||||
|
||||
def test_get_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -419,7 +532,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_delete_existing_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting an existing document returns True."""
|
||||
# Precondition.
|
||||
@@ -455,7 +568,7 @@ class TestOpenSearchClient:
|
||||
test_client.get_document(document_chunk_id=doc_chunk_id)
|
||||
|
||||
def test_delete_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting a nonexistent document returns False."""
|
||||
# Precondition.
|
||||
@@ -476,7 +589,7 @@ class TestOpenSearchClient:
|
||||
assert result is False
|
||||
|
||||
def test_delete_by_query(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting documents by query."""
|
||||
# Precondition.
|
||||
@@ -552,7 +665,7 @@ class TestOpenSearchClient:
|
||||
assert len(keep_ids) == 1
|
||||
|
||||
def test_update_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a document's properties."""
|
||||
# Precondition.
|
||||
@@ -601,7 +714,7 @@ class TestOpenSearchClient:
|
||||
assert updated_doc.public == doc.public
|
||||
|
||||
def test_update_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -623,7 +736,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -704,7 +817,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_search_empty_index(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -743,7 +856,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -863,7 +976,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -993,7 +1106,7 @@ class TestOpenSearchClient:
|
||||
previous_score = current_score
|
||||
|
||||
def test_delete_by_query_multitenant_isolation(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
|
||||
@@ -1087,7 +1200,7 @@ class TestOpenSearchClient:
|
||||
assert set(remaining_y_ids) == expected_y_ids
|
||||
|
||||
def test_delete_by_query_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query for non-existent document returns 0 deleted.
|
||||
@@ -1116,7 +1229,7 @@ class TestOpenSearchClient:
|
||||
assert num_deleted == 0
|
||||
|
||||
def test_search_for_document_ids(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search_for_document_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
@@ -1181,7 +1294,7 @@ class TestOpenSearchClient:
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
@@ -1259,7 +1372,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -1352,7 +1465,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_random_search(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests the random search query works."""
|
||||
# Precondition.
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -74,7 +75,7 @@ CHUNK_COUNT = 5
|
||||
|
||||
|
||||
def _get_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> list[DocumentChunk]:
|
||||
opensearch_client.refresh_index()
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
|
||||
@@ -95,7 +96,7 @@ def _get_document_chunks_from_opensearch(
|
||||
|
||||
|
||||
def _delete_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> None:
|
||||
opensearch_client.refresh_index()
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
@@ -283,10 +284,10 @@ def vespa_document_index(
|
||||
def opensearch_client(
|
||||
db_session: Session,
|
||||
full_deployment_setup: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for the test tenant."""
|
||||
active = get_active_search_settings(db_session)
|
||||
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
|
||||
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -330,7 +331,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
|
||||
def test_documents(
|
||||
db_session: Session,
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
patch_get_vespa_chunks_page_size: int,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""
|
||||
@@ -411,7 +412,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -480,7 +481,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -618,7 +619,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -712,7 +713,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.oauth_config import create_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -491,3 +492,19 @@ class TestOAuthTokenManagerURLBuilding:
|
||||
# Should use & instead of ? since URL already has query params
|
||||
assert "foo=bar&" in url or "?foo=bar" in url
|
||||
assert "client_id=custom_client_id" in url
|
||||
|
||||
|
||||
class TestUnwrapSensitiveStr:
|
||||
"""Tests for _unwrap_sensitive_str static method"""
|
||||
|
||||
def test_unwrap_sensitive_str(self) -> None:
|
||||
"""Test that both SensitiveValue and plain str inputs are handled"""
|
||||
# SensitiveValue input
|
||||
sensitive = SensitiveValue[str](
|
||||
encrypted_bytes=b"test_client_id",
|
||||
decrypt_fn=lambda b: b.decode(),
|
||||
)
|
||||
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
|
||||
|
||||
# Plain str input
|
||||
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"
|
||||
|
||||
@@ -76,9 +76,12 @@ class ChatSessionManager:
|
||||
user_performing_action: DATestUser,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
project_id: int | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
persona_id=persona_id,
|
||||
description=description,
|
||||
project_id=project_id,
|
||||
)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class ScimTokenManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": name},
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
raw_token=data["raw_token"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_active(
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestScimToken | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=user_performing_action.headers,
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestScimToken(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
token_display=data["token_display"],
|
||||
is_active=data["is_active"],
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -42,6 +42,18 @@ class DATestPAT(BaseModel):
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestScimToken(BaseModel):
|
||||
"""SCIM bearer token model for testing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
raw_token: str | None = None # Only present on initial creation
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
|
||||
|
||||
class DATestAPIKey(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
|
||||
@@ -23,6 +23,8 @@ _ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_API_VERSION = "NIGHTLY_LLM_API_VERSION"
|
||||
_ENV_DEPLOYMENT_NAME = "NIGHTLY_LLM_DEPLOYMENT_NAME"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
@@ -34,6 +36,8 @@ class NightlyProviderConfig(BaseModel):
|
||||
model_names: list[str]
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
deployment_name: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
strict: bool
|
||||
|
||||
@@ -45,17 +49,29 @@ def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
def _parse_models_env(env_var: str) -> list[str]:
|
||||
raw_value = os.environ.get(env_var, "").strip()
|
||||
if not raw_value:
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(raw_value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_json = None
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
return [str(model).strip() for model in parsed_json if str(model).strip()]
|
||||
|
||||
return [part.strip() for part in raw_value.split(",") if part.strip()]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
model_names = _parse_models_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
api_version = os.environ.get(_ENV_API_VERSION) or None
|
||||
deployment_name = os.environ.get(_ENV_DEPLOYMENT_NAME) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
|
||||
custom_config: dict[str, str] | None = None
|
||||
@@ -74,6 +90,8 @@ def _load_provider_config() -> NightlyProviderConfig:
|
||||
model_names=model_names,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
custom_config=custom_config,
|
||||
strict=strict,
|
||||
)
|
||||
@@ -95,10 +113,15 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=f"{_ENV_MODELS} must include at least one model",
|
||||
)
|
||||
|
||||
if config.provider != "ollama_chat" and not config.api_key:
|
||||
if config.provider != "ollama_chat" and not (
|
||||
config.api_key or config.custom_config
|
||||
):
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
|
||||
message=(
|
||||
f"{_ENV_API_KEY} or {_ENV_CUSTOM_CONFIG_JSON} is required for "
|
||||
f"provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
if config.provider == "ollama_chat" and not (
|
||||
@@ -109,6 +132,22 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "azure":
|
||||
if not config.api_base:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_BASE} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
if not config.api_version:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_API_VERSION} is required for provider '{config.provider}'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
@@ -147,6 +186,8 @@ def _create_provider_payload(
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
@@ -154,6 +195,8 @@ def _create_provider_payload(
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"deployment_name": deployment_name,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
@@ -255,6 +298,8 @@ def _create_and_test_provider_for_model(
|
||||
model_name=model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=resolved_api_base,
|
||||
api_version=config.api_version,
|
||||
deployment_name=config.deployment_name,
|
||||
custom_config=config.custom_config,
|
||||
)
|
||||
|
||||
@@ -313,10 +358,21 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
failures: list[str] = []
|
||||
for model_name in config.model_names:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
try:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
|
||||
raise
|
||||
failures.append(
|
||||
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
|
||||
)
|
||||
|
||||
if failures:
|
||||
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
|
||||
|
||||
@@ -72,6 +72,9 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert (
|
||||
"read_file" in tool_names
|
||||
), "Default assistant should have FileReaderTool attached"
|
||||
assert (
|
||||
"python" in tool_names
|
||||
), "Default assistant should have PythonTool attached"
|
||||
|
||||
# Also verify by display names for clarity
|
||||
assert (
|
||||
@@ -86,8 +89,11 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert (
|
||||
"File Reader" in tool_display_names
|
||||
), "Default assistant should have File Reader tool"
|
||||
|
||||
# Should have exactly 5 tools
|
||||
assert (
|
||||
len(tool_associations) == 5
|
||||
), f"Default assistant should have exactly 5 tools attached, got {len(tool_associations)}"
|
||||
"Code Interpreter" in tool_display_names
|
||||
), "Default assistant should have Code Interpreter tool"
|
||||
|
||||
# Should have exactly 6 tools
|
||||
assert (
|
||||
len(tool_associations) == 6
|
||||
), f"Default assistant should have exactly 6 tools attached, got {len(tool_associations)}"
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Integration tests for the unified persona file context flow.
|
||||
|
||||
End-to-end tests that verify:
|
||||
1. Files can be uploaded and attached to a persona via API.
|
||||
2. The persona correctly reports its attached files.
|
||||
3. A chat session with a file-bearing persona processes without error.
|
||||
4. Precedence: custom persona files take priority over project files when
|
||||
the chat session is inside a project.
|
||||
|
||||
These tests run against a real Onyx deployment (all services running).
|
||||
File processing is asynchronous, so we poll the file status endpoint
|
||||
until files reach COMPLETED before chatting.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FILE_PROCESSING_POLL_INTERVAL = 2
|
||||
|
||||
|
||||
def _poll_file_statuses(
|
||||
user_file_ids: list[str],
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus = UserFileStatus.COMPLETED,
|
||||
timeout: int = MAX_DELAY,
|
||||
) -> None:
|
||||
"""Block until all files reach the target status or timeout expires."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/file/statuses",
|
||||
json={"file_ids": user_file_ids},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
statuses = response.json()
|
||||
if all(f["status"] == target_status.value for f in statuses):
|
||||
return
|
||||
time.sleep(FILE_PROCESSING_POLL_INTERVAL)
|
||||
raise TimeoutError(
|
||||
f"Files {user_file_ids} did not reach {target_status.value} "
|
||||
f"within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_with_files_chat_no_error(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Upload files, attach them to a persona, wait for processing,
|
||||
then send a chat message. Verify no error is returned."""
|
||||
|
||||
# Upload files (creates UserFile records)
|
||||
text_file = create_test_text_file(
|
||||
"The secret project codename is NIGHTINGALE. "
|
||||
"It was started in 2024 by the Advanced Research division."
|
||||
)
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("nightingale_brief.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
assert len(file_descriptors) == 1
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"]
|
||||
assert user_file_id is not None
|
||||
|
||||
# Wait for file processing
|
||||
_poll_file_statuses([user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with the file attached
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Nightingale Agent",
|
||||
description="Agent with secret file",
|
||||
system_prompt="You are a helpful assistant with access to uploaded files.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
# Verify persona has the file
|
||||
persona_snapshots = PersonaManager.get_one(persona.id, admin_user)
|
||||
assert len(persona_snapshots) == 1
|
||||
assert user_file_id in persona_snapshots[0].user_file_ids
|
||||
|
||||
# Chat with the persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test persona file context",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret project codename?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
assert len(response.full_message) > 0, "Response should not be empty"
|
||||
|
||||
|
||||
def test_persona_without_files_still_works(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A persona with no attached files should still chat normally."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Blank Agent",
|
||||
description="No files attached",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test blank persona",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="Hello, how are you?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
def test_persona_files_override_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a custom persona (with its own files) is used inside a project,
|
||||
the persona's files take precedence — the project's files are invisible.
|
||||
|
||||
We verify this by putting different content in project vs persona files
|
||||
and checking which content the model responds with."""
|
||||
|
||||
# Upload persona file
|
||||
persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
|
||||
persona_fds, err1 = FileManager.upload_files(
|
||||
files=[("persona_secret.txt", persona_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not err1
|
||||
persona_user_file_id = persona_fds[0]["user_file_id"]
|
||||
assert persona_user_file_id is not None
|
||||
# Create a project and upload project files
|
||||
project = ProjectManager.create(
|
||||
name="Precedence Test Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_secret.txt", b"The project's secret word is FLAMINGO."),
|
||||
]
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for both persona and project file processing
|
||||
_poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with persona file
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Override Agent",
|
||||
description="Persona with its own files",
|
||||
system_prompt="You are a helpful assistant. Answer using the files.",
|
||||
user_file_ids=[persona_user_file_id],
|
||||
)
|
||||
|
||||
# Create chat session inside the project but using the custom persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret word?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
# The persona's file should be what the model sees, not the project's
|
||||
message_lower = response.full_message.lower()
|
||||
assert "albatross" in message_lower, (
|
||||
"Response should reference the persona file's secret word (ALBATROSS), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_default_persona_in_project_uses_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the default persona (id=0) is used inside a project,
|
||||
the project's files should be used for context."""
|
||||
project = ProjectManager.create(
|
||||
name="Default Persona Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_info.txt", b"The project mascot is a PANGOLIN."),
|
||||
]
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(upload_result.user_files) == 1
|
||||
|
||||
# Wait for project file processing
|
||||
project_file_id = str(upload_result.user_files[0].id)
|
||||
_poll_file_statuses([project_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create chat session inside project using default persona (id=0)
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project mascot?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert "pangolin" in response.full_message.lower(), (
|
||||
"Response should reference the project file content (PANGOLIN), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_custom_persona_no_files_in_project_ignores_project(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A custom persona with NO files, used inside a project with files,
|
||||
should NOT see the project's files. The project is purely organizational.
|
||||
|
||||
We verify by asking about content only in the project file and checking
|
||||
the model does NOT reference it."""
|
||||
|
||||
project = ProjectManager.create(
|
||||
name="Ignored Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("project_only.txt", b"The project secret is CAPYBARA.")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(project_upload_result.user_files) == 1
|
||||
project_user_file_id = str(project_upload_result.user_files[0].id)
|
||||
|
||||
# Wait for project file processing
|
||||
_poll_file_statuses([project_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Custom persona with no files
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="No Files Agent",
|
||||
description="No files, project is irrelevant",
|
||||
system_prompt=(
|
||||
"You are a helpful assistant. If you do not have information "
|
||||
"to answer a question, say 'I do not have that information.'"
|
||||
),
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the project secret?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
assert "capybara" not in response.full_message.lower(), (
|
||||
"Response should NOT reference the project file content (CAPYBARA) "
|
||||
"because the custom persona has no files and should not inherit "
|
||||
f"project files, but got: {response.full_message}"
|
||||
)
|
||||
166
backend/tests/integration/tests/scim/test_scim_tokens.py
Normal file
166
backend/tests/integration/tests/scim/test_scim_tokens.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Integration tests for SCIM token management.
|
||||
|
||||
Covers the admin token API and SCIM bearer-token authentication:
|
||||
1. Token lifecycle: create, retrieve metadata, use for SCIM requests
|
||||
2. Token rotation: creating a new token revokes previous tokens
|
||||
3. Revoked tokens are rejected by SCIM endpoints
|
||||
4. Non-admin users cannot manage SCIM tokens
|
||||
5. SCIM requests without a token are rejected
|
||||
6. Service discovery endpoints work without authentication
|
||||
7. last_used_at is updated after a SCIM request
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
"""Create token → retrieve metadata → use for SCIM request."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Test SCIM Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert token.raw_token is not None
|
||||
assert token.raw_token.startswith("onyx_scim_")
|
||||
assert token.is_active is True
|
||||
assert "****" in token.token_display
|
||||
|
||||
# GET returns the same metadata but raw_token is None because the
|
||||
# server only reveals the raw token once at creation time (it stores
|
||||
# only the SHA-256 hash).
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
assert body["totalResults"] >= 0
|
||||
|
||||
|
||||
def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
"""Creating a new token automatically revokes the previous one."""
|
||||
first = ScimTokenManager.create(
|
||||
name="First Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
second = ScimTokenManager.create(
|
||||
name="Second Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert second.raw_token is not None
|
||||
|
||||
# Active token should now be the second one
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to create a SCIM token."""
|
||||
basic_user = UserManager.create(name="scim_basic_user")
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
json={"name": "Should Fail"},
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_non_admin_cannot_get_token(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Non-admin users get 403 when trying to retrieve SCIM token metadata."""
|
||||
basic_user = UserManager.create(name="scim_basic_user2")
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=basic_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None:
|
||||
"""GET active token returns 404 when no token exists."""
|
||||
# new_admin_user depends on the reset fixture, ensuring a clean DB
|
||||
# with no active SCIM tokens.
|
||||
active = ScimTokenManager.get_active(user_performing_action=new_admin_user)
|
||||
assert active is None
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/enterprise-settings/scim/token",
|
||||
headers=new_admin_user.headers,
|
||||
timeout=60,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_service_discovery_no_auth_required(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
def test_last_used_at_updated_after_scim_request(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""last_used_at timestamp is updated after using the token."""
|
||||
token = ScimTokenManager.create(
|
||||
name="Last Used Token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert token.raw_token is not None
|
||||
|
||||
active = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active is not None
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
assert active_after is not None
|
||||
assert active_after.last_used_at is not None
|
||||
426
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
426
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Tests for the unified context file extraction logic (Phase 5).
|
||||
|
||||
Covers:
|
||||
- resolve_context_user_files: precedence rule (custom persona supersedes project)
|
||||
- extract_context_files: all-or-nothing context window fit check
|
||||
- Search filter / search_usage determination in the caller
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.process_message import determine_search_params
|
||||
from onyx.chat.process_message import extract_context_files
|
||||
from onyx.chat.process_message import resolve_context_user_files
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
token_count: int = 100,
|
||||
name: str = "file.txt",
|
||||
file_id: str | None = None,
|
||||
) -> UserFile:
|
||||
file_uuid = UUID(file_id) if file_id else uuid4()
|
||||
return UserFile(
|
||||
id=file_uuid,
|
||||
file_id=str(file_uuid),
|
||||
name=name,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_persona(
|
||||
persona_id: int,
|
||||
user_files: list | None = None,
|
||||
) -> MagicMock:
|
||||
persona = MagicMock()
|
||||
persona.id = persona_id
|
||||
persona.user_files = user_files or []
|
||||
return persona
|
||||
|
||||
|
||||
def _make_in_memory_file(
|
||||
file_id: str,
|
||||
content: str = "hello world",
|
||||
file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
|
||||
filename: str = "file.txt",
|
||||
) -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=content.encode("utf-8"),
|
||||
file_type=file_type,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# resolve_context_user_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestResolveContextUserFiles:
|
||||
"""Precedence rule: custom persona fully supersedes project."""
|
||||
|
||||
def test_custom_persona_with_files_returns_persona_files(self) -> None:
|
||||
persona_files = [_make_user_file(), _make_user_file()]
|
||||
persona = _make_persona(persona_id=42, user_files=persona_files)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == persona_files
|
||||
|
||||
def test_custom_persona_without_files_returns_empty(self) -> None:
|
||||
"""Custom persona with no files should NOT fall through to project."""
|
||||
persona = _make_persona(persona_id=42, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_custom_persona_none_files_returns_empty(self) -> None:
|
||||
"""Custom persona with user_files=None should NOT fall through."""
|
||||
persona = _make_persona(persona_id=42, user_files=None)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_default_persona_in_project_returns_project_files(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
project_files = [_make_user_file(), _make_user_file()]
|
||||
mock_get_files.return_value = project_files
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
user_id = uuid4()
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
assert result == project_files
|
||||
mock_get_files.assert_called_once_with(
|
||||
project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
def test_default_persona_no_project_returns_empty(self) -> None:
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_custom_persona_without_files_ignores_project(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
"""Even with a project_id, custom persona means project is invisible."""
|
||||
persona = _make_persona(persona_id=7, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_get_files.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# extract_context_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtractContextFiles:
|
||||
"""All-or-nothing context window fit check."""
|
||||
|
||||
def test_empty_user_files_returns_empty(self) -> None:
|
||||
db_session = MagicMock()
|
||||
result = extract_context_files(
|
||||
user_files=[],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=db_session,
|
||||
)
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.uncapped_token_count is None
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=100, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=file_id, content="file content")
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == ["file content"]
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.total_token_count == 100
|
||||
assert len(result.file_metadata) == 1
|
||||
assert result.file_metadata[0].file_id == file_id
|
||||
|
||||
def test_files_overflow_context_not_loaded(self) -> None:
|
||||
"""When aggregate tokens exceed 60% of available window, nothing is loaded."""
|
||||
uf = _make_user_file(token_count=7000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.uncapped_token_count == 7000
|
||||
assert result.total_token_count == 0
|
||||
|
||||
def test_overflow_boundary_exact(self) -> None:
|
||||
"""Token count exactly at the 60% boundary should trigger overflow."""
|
||||
# Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
|
||||
uf = _make_user_file(token_count=6000)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
|
||||
"""Token count just under the 60% boundary should load files."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=5999, file_id=file_id)
|
||||
mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.file_texts == ["data"]
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
|
||||
"""Multiple small files that individually fit but collectively overflow."""
|
||||
files = [_make_user_file(token_count=2500) for _ in range(3)]
|
||||
# 3 * 2500 = 7500 > 6000 threshold
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=files,
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.file_texts == []
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
|
||||
"""Reserved tokens shrink the available window."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=3000, file_id=file_id)
|
||||
# Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=5000,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=50, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert len(result.image_files) == 1
|
||||
assert result.image_files[0].file_id == file_id
|
||||
assert result.file_texts == []
|
||||
assert result.total_token_count == 50
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled, overflow produces FileToolMetadata."""
|
||||
uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Search filter + search_usage determination
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSearchFilterDetermination:
|
||||
"""Verify that determine_search_params correctly resolves
|
||||
search_project_id, search_persona_id, and search_usage based on
|
||||
the extraction result and the precedence rule.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_context(
|
||||
use_as_search_filter: bool = False,
|
||||
file_texts: list[str] | None = None,
|
||||
uncapped_token_count: int | None = None,
|
||||
) -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts or [],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=uncapped_token_count,
|
||||
)
|
||||
|
||||
def test_custom_persona_files_fit_no_filter(self) -> None:
|
||||
"""Custom persona, files fit → no search filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_files_overflow_persona_filter(self) -> None:
|
||||
"""Custom persona, files overflow → persona_id filter, AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(use_as_search_filter=True),
|
||||
)
|
||||
assert result.search_persona_id == 42
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_no_files_no_project_leak(self) -> None:
|
||||
"""Custom persona (no files) in project → nothing leaks from project."""
|
||||
result = determine_search_params(
|
||||
persona_id=42,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_files_fit_disables_search(self) -> None:
|
||||
"""Default persona, project files fit → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
|
||||
def test_default_persona_project_files_overflow_enables_search(self) -> None:
|
||||
"""Default persona, project files overflow → ENABLED + project_id filter."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(
|
||||
use_as_search_filter=True,
|
||||
uncapped_token_count=7000,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id == 99
|
||||
assert result.search_persona_id is None
|
||||
assert result.search_usage == SearchToolUsage.ENABLED
|
||||
|
||||
def test_default_persona_no_project_auto(self) -> None:
|
||||
"""Default persona, no project → AUTO."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=None,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_no_files_disables_search(self) -> None:
|
||||
"""Default persona in project with no files → DISABLED."""
|
||||
result = determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -74,20 +74,20 @@ def create_tool_response(
|
||||
)
|
||||
|
||||
|
||||
def create_project_files(
|
||||
def create_context_files(
|
||||
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Helper to create ExtractedProjectFiles for testing."""
|
||||
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
project_file_metadata = [
|
||||
ProjectFileMetadata(
|
||||
) -> ExtractedContextFiles:
|
||||
"""Helper to create ExtractedContextFiles for testing."""
|
||||
file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
file_metadata = [
|
||||
ContextFileMetadata(
|
||||
file_id=f"file_{i}",
|
||||
filename=f"file_{i}.txt",
|
||||
file_content=f"Project file {i} content",
|
||||
)
|
||||
for i in range(num_files)
|
||||
]
|
||||
project_image_files = [
|
||||
image_files = [
|
||||
ChatLoadedFile(
|
||||
file_id=f"image_{i}",
|
||||
content=b"",
|
||||
@@ -98,13 +98,13 @@ def create_project_files(
|
||||
)
|
||||
for i in range(num_images)
|
||||
]
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=False,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=num_files * tokens_per_file,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=num_files * tokens_per_file,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=num_files * tokens_per_file,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,14 +121,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("How are you?", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -148,14 +148,14 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom instructions", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_msg1, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -167,25 +167,25 @@ class TestConstructMessageHistory:
|
||||
assert result[3] == custom_agent # Before last user message
|
||||
assert result[4] == user_msg2
|
||||
|
||||
def test_with_project_files(self) -> None:
|
||||
def test_with_context_files(self) -> None:
|
||||
"""Test that project files are inserted before the last user message."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg1 = create_message("First message", MessageType.USER, 5)
|
||||
user_msg2 = create_message("Second message", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, project_files_message, user2
|
||||
# Should have: system, user1, context_files_message, user2
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -202,14 +202,14 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Remember to cite sources", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -235,14 +235,14 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -264,18 +264,18 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, user1, custom_agent, project_files, user2, assistant_with_tool
|
||||
# Should have: system, user1, custom_agent, context_files, user2, assistant_with_tool
|
||||
assert len(result) == 6
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -292,14 +292,14 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Second", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2]
|
||||
project_files = create_project_files(num_files=0, num_images=2)
|
||||
context_files = create_context_files(num_files=0, num_images=2)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -332,14 +332,14 @@ class TestConstructMessageHistory:
|
||||
)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=0, num_images=1)
|
||||
context_files = create_context_files(num_files=0, num_images=1)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -366,7 +366,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg2,
|
||||
user_msg3,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget only allows last 3 messages + system (10 + 20 + 20 + 20 = 70 tokens)
|
||||
result = construct_message_history(
|
||||
@@ -374,7 +374,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -395,7 +395,7 @@ class TestConstructMessageHistory:
|
||||
tool_response = create_tool_response("tc_1", "tool_response", 20)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool, tool_response]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget only allows last user message and messages after + system
|
||||
# (10 + 20 + 20 + 20 = 70 tokens)
|
||||
@@ -404,7 +404,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=80,
|
||||
)
|
||||
|
||||
@@ -432,7 +432,7 @@ class TestConstructMessageHistory:
|
||||
assistant_msg1,
|
||||
user_msg2,
|
||||
]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Remaining history budget is 10 tokens (30 total - 10 system - 10 last user):
|
||||
# keeps [tool_response, assistant_msg1] from history_before_last_user,
|
||||
@@ -442,7 +442,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=30,
|
||||
)
|
||||
|
||||
@@ -461,7 +461,7 @@ class TestConstructMessageHistory:
|
||||
user_msg2 = create_message("Latest question", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history = [user_msg1, assistant_with_tool, tool_response, user_msg2]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Remaining history budget is 25 tokens (45 total - 10 system - 10 last user):
|
||||
# keeps both assistant_with_tool and tool_response in history_before_last_user.
|
||||
@@ -470,7 +470,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=45,
|
||||
)
|
||||
|
||||
@@ -487,18 +487,18 @@ class TestConstructMessageHistory:
|
||||
reminder = create_message("Reminder", MessageType.USER, 10)
|
||||
|
||||
simple_chat_history: list[ChatMessageSimple] = []
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Should have: system, custom_agent, project_files, reminder
|
||||
# Should have: system, custom_agent, context_files, reminder
|
||||
assert len(result) == 4
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == custom_agent
|
||||
@@ -512,7 +512,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 5)
|
||||
|
||||
simple_chat_history = [assistant_msg, assistant_with_tool]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
with pytest.raises(ValueError, match="No user message found"):
|
||||
construct_message_history(
|
||||
@@ -520,7 +520,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -531,7 +531,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent = create_message("Custom", MessageType.USER, 50)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=1, tokens_per_file=100)
|
||||
context_files = create_context_files(num_files=1, tokens_per_file=100)
|
||||
|
||||
# Total required: 50 (system) + 50 (custom) + 100 (project) + 50 (user) = 250
|
||||
# But only 200 available
|
||||
@@ -541,7 +541,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=200,
|
||||
)
|
||||
|
||||
@@ -553,7 +553,7 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool = create_assistant_with_tool_call("tc_1", "tool", 30)
|
||||
|
||||
simple_chat_history = [user_msg1, user_msg2, assistant_with_tool]
|
||||
project_files = create_project_files()
|
||||
context_files = create_context_files()
|
||||
|
||||
# Budget: 50 tokens
|
||||
# Required: 10 (system) + 30 (user2) + 30 (assistant_with_tool) = 70 tokens
|
||||
@@ -566,7 +566,7 @@ class TestConstructMessageHistory:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=50,
|
||||
)
|
||||
|
||||
@@ -592,20 +592,20 @@ class TestConstructMessageHistory:
|
||||
assistant_with_tool,
|
||||
tool_response,
|
||||
]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=20)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=20)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=custom_agent,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
# Expected order:
|
||||
# system, user1, assistant1, user2, assistant2,
|
||||
# custom_agent, project_files, user3, assistant_with_tool, tool_response, reminder
|
||||
# custom_agent, context_files, user3, assistant_with_tool, tool_response, reminder
|
||||
assert len(result) == 11
|
||||
assert result[0] == system_prompt
|
||||
assert result[1] == user_msg1
|
||||
@@ -622,20 +622,20 @@ class TestConstructMessageHistory:
|
||||
assert result[9] == tool_response # After last user
|
||||
assert result[10] == reminder # At the very end
|
||||
|
||||
def test_project_files_json_format(self) -> None:
|
||||
def test_context_files_json_format(self) -> None:
|
||||
"""Test that project files are formatted correctly as JSON."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg = create_message("Hello", MessageType.USER, 5)
|
||||
|
||||
simple_chat_history = [user_msg]
|
||||
project_files = create_project_files(num_files=2, tokens_per_file=50)
|
||||
context_files = create_context_files(num_files=2, tokens_per_file=50)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -692,7 +692,7 @@ class TestForgottenFileMetadata:
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=create_project_files(),
|
||||
context_files=create_context_files(),
|
||||
available_tokens=available_tokens,
|
||||
token_counter=_simple_token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
|
||||
@@ -106,6 +106,9 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_ENDPOINT_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_tenant_ctx,
|
||||
patch(
|
||||
"onyx.server.metrics.postgres_connection_pool._connections_held"
|
||||
) as mock_gauge,
|
||||
@@ -114,12 +117,14 @@ def test_checkout_event_stores_endpoint_and_increments_gauge() -> None:
|
||||
mock_labels = MagicMock()
|
||||
mock_gauge.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "/api/chat/send-message"
|
||||
mock_tenant_ctx.get.return_value = "tenant_xyz"
|
||||
listeners["checkout"](None, conn_record, None)
|
||||
|
||||
assert conn_record.info["_metrics_endpoint"] == "/api/chat/send-message"
|
||||
assert conn_record.info["_metrics_tenant_id"] == "tenant_xyz"
|
||||
assert "_metrics_checkout_time" in conn_record.info
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/chat/send-message", engine="sync"
|
||||
handler="/api/chat/send-message", engine="sync", tenant_id="tenant_xyz"
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
@@ -144,6 +149,7 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
conn_record = _make_conn_record()
|
||||
conn_record.info["_metrics_endpoint"] = "/api/search"
|
||||
conn_record.info["_metrics_tenant_id"] = "tenant_abc"
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic() - 0.5
|
||||
|
||||
with (
|
||||
@@ -162,7 +168,9 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="/api/search", engine="sync", tenant_id="tenant_abc"
|
||||
)
|
||||
mock_labels.dec.assert_called_once()
|
||||
mock_hist.labels.assert_called_with(handler="/api/search", engine="sync")
|
||||
mock_hist_labels.observe.assert_called_once()
|
||||
@@ -172,11 +180,12 @@ def test_checkin_event_observes_hold_duration() -> None:
|
||||
|
||||
# conn_record.info should be cleaned up
|
||||
assert "_metrics_endpoint" not in conn_record.info
|
||||
assert "_metrics_tenant_id" not in conn_record.info
|
||||
assert "_metrics_checkout_time" not in conn_record.info
|
||||
|
||||
|
||||
def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
"""Verify checkin gracefully handles missing endpoint info."""
|
||||
"""Verify checkin gracefully handles missing endpoint and tenant info."""
|
||||
engine = MagicMock()
|
||||
engine.pool = MagicMock()
|
||||
listeners: dict[str, Any] = {}
|
||||
@@ -207,7 +216,9 @@ def test_checkin_with_missing_endpoint_uses_unknown() -> None:
|
||||
|
||||
listeners["checkin"](None, conn_record)
|
||||
|
||||
mock_gauge.labels.assert_called_with(handler="unknown", engine="sync")
|
||||
mock_gauge.labels.assert_called_with(
|
||||
handler="unknown", engine="sync", tenant_id="unknown"
|
||||
)
|
||||
|
||||
|
||||
# --- setup_postgres_connection_pool_metrics tests ---
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from prometheus_client import CollectorRegistry
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.server.metrics.per_tenant import per_tenant_request_callback
|
||||
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
@@ -81,7 +82,7 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=["/health", "/metrics", "/openapi.json"],
|
||||
)
|
||||
mock_instance.add.assert_called_once()
|
||||
assert mock_instance.add.call_count == 3
|
||||
mock_instance.instrument.assert_called_once_with(
|
||||
app,
|
||||
latency_lowr_buckets=(
|
||||
@@ -100,6 +101,56 @@ def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
mock_instance.expose.assert_called_once_with(app)
|
||||
|
||||
|
||||
def test_per_tenant_callback_increments_with_tenant_id() -> None:
|
||||
"""Verify per-tenant callback reads tenant from contextvar and increments."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = "tenant_abc"
|
||||
|
||||
info = _make_info(
|
||||
duration=0.1, method="POST", handler="/api/chat", status="200"
|
||||
)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="tenant_abc",
|
||||
method="POST",
|
||||
handler="/api/chat",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_per_tenant_callback_falls_back_to_unknown() -> None:
|
||||
"""Verify per-tenant callback uses 'unknown' when contextvar is None."""
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.per_tenant.CURRENT_TENANT_ID_CONTEXTVAR"
|
||||
) as mock_ctx,
|
||||
patch("onyx.server.metrics.per_tenant._requests_by_tenant") as mock_counter,
|
||||
):
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
mock_ctx.get.return_value = None
|
||||
|
||||
info = _make_info(duration=0.1)
|
||||
per_tenant_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
tenant_id="unknown",
|
||||
method="GET",
|
||||
handler="/api/test",
|
||||
status="200",
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_inprogress_gauge_increments_during_request() -> None:
|
||||
"""Verify the in-progress gauge goes up while a request is in flight."""
|
||||
registry = CollectorRegistry()
|
||||
|
||||
3
cli/.gitignore
vendored
Normal file
3
cli/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
onyx-cli
|
||||
cli
|
||||
onyx.cli
|
||||
85
cli/README.md
Normal file
85
cli/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Onyx CLI
|
||||
|
||||
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) assistant. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# From source
|
||||
cd cli
|
||||
go build -o onyx-cli .
|
||||
|
||||
# Or install directly
|
||||
go install github.com/onyx-dot-app/onyx/cli@latest
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Launch interactive chat (default)
|
||||
./onyx-cli
|
||||
|
||||
# First-run setup
|
||||
./onyx-cli configure
|
||||
|
||||
# One-shot question
|
||||
./onyx-cli ask "What is Onyx?"
|
||||
./onyx-cli ask --agent-id 5 "Summarize this topic"
|
||||
./onyx-cli ask --json "Hello"
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `chat` | Launch the interactive chat TUI (default) |
|
||||
| `ask` | Ask a one-shot question (non-interactive) |
|
||||
| `configure` | Configure server URL and API key |
|
||||
|
||||
## Slash Commands (in TUI)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show help message |
|
||||
| `/new` | Start a new chat session |
|
||||
| `/agent` | List and switch agents |
|
||||
| `/attach <path>` | Attach a file to next message |
|
||||
| `/sessions` | List recent chat sessions |
|
||||
| `/resume <id>` | Resume a previous session |
|
||||
| `/clear` | Clear the chat display |
|
||||
| `/configure` | Re-run connection setup |
|
||||
| `/connectors` | Open connectors in browser |
|
||||
| `/settings` | Open settings in browser |
|
||||
| `/quit` | Exit Onyx CLI |
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `Enter` | Send message |
|
||||
| `Escape` | Cancel current generation |
|
||||
| `Ctrl+O` | Toggle source citations |
|
||||
| `Ctrl+D` | Quit (press twice) |
|
||||
| `Scroll` / `Shift+Up/Down` | Scroll chat history |
|
||||
| `Page Up` / `Page Down` | Scroll half page |
|
||||
|
||||
## Configuration
|
||||
|
||||
Config is stored at `~/.config/onyx-cli/config.json`. Environment variables override file values:
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | Server URL |
|
||||
| `ONYX_API_KEY` | API key |
|
||||
| `DANSWER_API_KEY` | Legacy API key (fallback) |
|
||||
| `ONYX_PERSONA_ID` | Default persona ID |
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
# Run tests
|
||||
go test ./...
|
||||
|
||||
# Build
|
||||
go build -o onyx-cli .
|
||||
```
|
||||
80
cli/cmd/ask.go
Normal file
80
cli/cmd/ask.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
askAgentID int
|
||||
askJSON bool
|
||||
)
|
||||
|
||||
var askCmd = &cobra.Command{
|
||||
Use: "ask [question]",
|
||||
Short: "Ask a one-shot question (non-interactive)",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
question := args[0]
|
||||
agentID := cfg.DefaultAgentID
|
||||
if cmd.Flags().Changed("agent-id") {
|
||||
agentID = askAgentID
|
||||
}
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
parentID := -1
|
||||
ch := client.SendMessageStream(
|
||||
context.Background(),
|
||||
question,
|
||||
nil,
|
||||
agentID,
|
||||
&parentID,
|
||||
nil,
|
||||
)
|
||||
|
||||
for event := range ch {
|
||||
if askJSON {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marshaling event: %v\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
continue
|
||||
}
|
||||
|
||||
switch e := event.(type) {
|
||||
case models.MessageDeltaEvent:
|
||||
fmt.Print(e.Content)
|
||||
case models.ErrorEvent:
|
||||
return fmt.Errorf("%s", e.Error)
|
||||
case models.StopEvent:
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
askCmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
|
||||
askCmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
|
||||
rootCmd.AddCommand(askCmd)
|
||||
}
|
||||
41
cli/cmd/chat.go
Normal file
41
cli/cmd/chat.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var chatCmd = &cobra.Command{
|
||||
Use: "chat",
|
||||
Short: "Launch the interactive chat TUI (default)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
|
||||
// First-run: onboarding
|
||||
if !config.ConfigExists() || !cfg.IsConfigured() {
|
||||
result := onboarding.Run(&cfg)
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
cfg = *result
|
||||
}
|
||||
|
||||
m := tui.NewModel(cfg)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
|
||||
if _, err := p.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(chatCmd)
|
||||
}
|
||||
21
cli/cmd/configure.go
Normal file
21
cli/cmd/configure.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var configureCmd = &cobra.Command{
|
||||
Use: "configure",
|
||||
Short: "Configure server URL and API key",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
onboarding.Run(&cfg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(configureCmd)
|
||||
}
|
||||
31
cli/cmd/root.go
Normal file
31
cli/cmd/root.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Package cmd implements Cobra CLI commands for the Onyx CLI.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const version = "0.1.0"
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "onyx-cli",
|
||||
Short: "Terminal UI for chatting with Onyx",
|
||||
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.Version = version
|
||||
// Default command is chat
|
||||
rootCmd.RunE = chatCmd.RunE
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
func Execute() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
45
cli/go.mod
Normal file
45
cli/go.mod
Normal file
@@ -0,0 +1,45 @@
|
||||
module github.com/onyx-dot-app/onyx/cli
|
||||
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/bubbles v0.20.0
|
||||
github.com/charmbracelet/bubbletea v1.3.4
|
||||
github.com/charmbracelet/glamour v0.8.0
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/spf13/cobra v1.9.1
|
||||
golang.org/x/term v0.22.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.8.0 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.0 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yuin/goldmark v1.7.4 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.3 // indirect
|
||||
golang.org/x/net v0.27.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
)
|
||||
98
cli/go.sum
Normal file
98
cli/go.sum
Normal file
@@ -0,0 +1,98 @@
|
||||
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
|
||||
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
|
||||
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
|
||||
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
|
||||
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
|
||||
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
|
||||
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
|
||||
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
|
||||
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
|
||||
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
||||
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
|
||||
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
|
||||
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
|
||||
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
279
cli/internal/api/client.go
Normal file
279
cli/internal/api/client.go
Normal file
@@ -0,0 +1,279 @@
|
||||
// Package api provides the HTTP client for communicating with the Onyx server.
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// Client is the Onyx API client.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
httpClient *http.Client // default 30s timeout for quick requests
|
||||
longHTTPClient *http.Client // 5min timeout for streaming/uploads
|
||||
}
|
||||
|
||||
// NewClient creates a new API client from config.
|
||||
func NewClient(cfg config.OnyxCliConfig) *Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
|
||||
apiKey: cfg.APIKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
},
|
||||
longHTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig replaces the client's config.
|
||||
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
|
||||
c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
|
||||
c.apiKey = cfg.APIKey
|
||||
}
|
||||
|
||||
func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), method, c.baseURL+path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.apiKey != "" {
|
||||
bearer := "Bearer " + c.apiKey
|
||||
req.Header.Set("Authorization", bearer)
|
||||
req.Header.Set("X-Onyx-Authorization", bearer)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(method, path string, reqBody any, result any) error {
|
||||
var body io.Reader
|
||||
if reqBody != nil {
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := c.newRequest(method, path, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reqBody != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
return json.NewDecoder(resp.Body).Decode(result)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnection checks if the server is reachable and credentials are valid.
|
||||
// Returns nil on success, or an error with a descriptive message on failure.
|
||||
func (c *Client) TestConnection() error {
|
||||
// Step 1: Basic reachability
|
||||
req, err := c.newRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to %s — is the server running?", c.baseURL)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
serverHeader := strings.ToLower(resp.Header.Get("Server"))
|
||||
|
||||
if resp.StatusCode == 403 {
|
||||
if strings.Contains(serverHeader, "awselb") || strings.Contains(serverHeader, "amazons3") {
|
||||
return fmt.Errorf("blocked by AWS load balancer (HTTP 403 on all requests).\n Your IP address may not be in the ALB's security group or WAF allowlist")
|
||||
}
|
||||
return fmt.Errorf("HTTP 403 on base URL — the server is blocking all traffic.\n This is likely a firewall, WAF, or IP allowlist restriction")
|
||||
}
|
||||
|
||||
// Step 2: Authenticated check
|
||||
req2, err := c.newRequest("GET", "/api/chat/get-user-chat-sessions", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server reachable but API error: %w", err)
|
||||
}
|
||||
|
||||
resp2, err := c.longHTTPClient.Do(req2)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server reachable but API error: %w", err)
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
|
||||
if resp2.StatusCode == 200 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp2.Body, 300))
|
||||
body := string(bodyBytes)
|
||||
isHTML := strings.HasPrefix(strings.TrimSpace(body), "<")
|
||||
respServer := strings.ToLower(resp2.Header.Get("Server"))
|
||||
|
||||
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
|
||||
if isHTML || strings.Contains(respServer, "awselb") {
|
||||
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
|
||||
}
|
||||
if resp2.StatusCode == 401 {
|
||||
return fmt.Errorf("invalid API key or token.\n %s", body)
|
||||
}
|
||||
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
|
||||
}
|
||||
|
||||
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
|
||||
if body != "" {
|
||||
detail += fmt.Sprintf("\n Response: %s", body)
|
||||
}
|
||||
return fmt.Errorf("%s", detail)
|
||||
}
|
||||
|
||||
// ListAgents returns visible agents.
|
||||
func (c *Client) ListAgents() ([]models.AgentSummary, error) {
|
||||
var raw []models.AgentSummary
|
||||
if err := c.doJSON("GET", "/api/persona", nil, &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []models.AgentSummary
|
||||
for _, p := range raw {
|
||||
if p.IsVisible {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListChatSessions returns recent chat sessions.
|
||||
func (c *Client) ListChatSessions() ([]models.ChatSessionDetails, error) {
|
||||
var resp struct {
|
||||
Sessions []models.ChatSessionDetails `json:"sessions"`
|
||||
}
|
||||
if err := c.doJSON("GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Sessions, nil
|
||||
}
|
||||
|
||||
// GetChatSession returns full details for a session.
|
||||
func (c *Client) GetChatSession(sessionID string) (*models.ChatSessionDetailResponse, error) {
|
||||
var resp models.ChatSessionDetailResponse
|
||||
if err := c.doJSON("GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RenameChatSession renames a session. If name is empty, the backend auto-generates one.
|
||||
func (c *Client) RenameChatSession(sessionID string, name *string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
if name != nil {
|
||||
payload["name"] = *name
|
||||
}
|
||||
var resp struct {
|
||||
NewName string `json:"new_name"`
|
||||
}
|
||||
if err := c.doJSON("PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.NewName, nil
|
||||
}
|
||||
|
||||
// UploadFile uploads a file and returns a file descriptor.
|
||||
func (c *Client) UploadFile(filePath string) (*models.FileDescriptorPayload, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
part, err := writer.CreateFormFile("files", filepath.Base(filePath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req, err := c.newRequest("POST", "/api/user/projects/file/upload", &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := c.longHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(body)}
|
||||
}
|
||||
|
||||
var snapshot models.CategorizedFilesSnapshot
|
||||
if err := json.NewDecoder(resp.Body).Decode(&snapshot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(snapshot.UserFiles) == 0 {
|
||||
return nil, &OnyxAPIError{StatusCode: 400, Detail: "File upload returned no files"}
|
||||
}
|
||||
|
||||
uf := snapshot.UserFiles[0]
|
||||
return &models.FileDescriptorPayload{
|
||||
ID: uf.FileID,
|
||||
Type: uf.ChatFileType,
|
||||
Name: filepath.Base(filePath),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StopChatSession sends a stop signal for a streaming session (best-effort).
|
||||
func (c *Client) StopChatSession(sessionID string) {
|
||||
req, err := c.newRequest("POST", "/api/chat/stop-chat-session/"+sessionID, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
13
cli/internal/api/errors.go
Normal file
13
cli/internal/api/errors.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package api
|
||||
|
||||
import "fmt"
|
||||
|
||||
// OnyxAPIError is returned when an Onyx API call fails.
|
||||
type OnyxAPIError struct {
|
||||
StatusCode int
|
||||
Detail string
|
||||
}
|
||||
|
||||
func (e *OnyxAPIError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
|
||||
}
|
||||
133
cli/internal/api/stream.go
Normal file
133
cli/internal/api/stream.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/parser"
|
||||
)
|
||||
|
||||
// StreamEventMsg wraps a StreamEvent for Bubble Tea.
|
||||
type StreamEventMsg struct {
|
||||
Event models.StreamEvent
|
||||
}
|
||||
|
||||
// StreamDoneMsg signals the stream has ended.
|
||||
type StreamDoneMsg struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// SendMessageStream starts streaming a chat message response.
|
||||
// It reads NDJSON lines, parses them, and sends events on the returned channel.
|
||||
// The goroutine stops when ctx is cancelled or the stream ends.
|
||||
func (c *Client) SendMessageStream(
|
||||
ctx context.Context,
|
||||
message string,
|
||||
chatSessionID *string,
|
||||
agentID int,
|
||||
parentMessageID *int,
|
||||
fileDescriptors []models.FileDescriptorPayload,
|
||||
) <-chan models.StreamEvent {
|
||||
ch := make(chan models.StreamEvent, 64)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
payload := models.SendMessagePayload{
|
||||
Message: message,
|
||||
ParentMessageID: parentMessageID,
|
||||
FileDescriptors: fileDescriptors,
|
||||
Origin: "api",
|
||||
IncludeCitations: true,
|
||||
Stream: true,
|
||||
}
|
||||
if payload.FileDescriptors == nil {
|
||||
payload.FileDescriptors = []models.FileDescriptorPayload{}
|
||||
}
|
||||
|
||||
if chatSessionID != nil {
|
||||
payload.ChatSessionID = chatSessionID
|
||||
} else {
|
||||
payload.ChatSessionInfo = &models.ChatSessionCreationInfo{AgentID: agentID}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("marshal error: %v", err), IsRetryable: false}
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", nil)
|
||||
if err != nil {
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
|
||||
return
|
||||
}
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
req.ContentLength = int64(len(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
bearer := "Bearer " + c.apiKey
|
||||
req.Header.Set("Authorization", bearer)
|
||||
req.Header.Set("X-Onyx-Authorization", bearer)
|
||||
}
|
||||
|
||||
resp, err := c.longHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return // cancelled
|
||||
}
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("connection error: %v", err), IsRetryable: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
var respBody [4096]byte
|
||||
n, _ := resp.Body.Read(respBody[:])
|
||||
ch <- models.ErrorEvent{
|
||||
Error: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody[:n])),
|
||||
IsRetryable: resp.StatusCode >= 500,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
event := parser.ParseStreamLine(scanner.Text())
|
||||
if event != nil {
|
||||
select {
|
||||
case ch <- event:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// WaitForStreamEvent returns a tea.Cmd that reads one event from the channel.
|
||||
// On channel close, it returns StreamDoneMsg.
|
||||
func WaitForStreamEvent(ch <-chan models.StreamEvent) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
event, ok := <-ch
|
||||
if !ok {
|
||||
return StreamDoneMsg{}
|
||||
}
|
||||
return StreamEventMsg{Event: event}
|
||||
}
|
||||
}
|
||||
|
||||
102
cli/internal/config/config.go
Normal file
102
cli/internal/config/config.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvServerURL = "ONYX_SERVER_URL"
|
||||
EnvAPIKey = "ONYX_API_KEY"
|
||||
EnvAPIKeyLegacy = "DANSWER_API_KEY"
|
||||
EnvAgentID = "ONYX_PERSONA_ID"
|
||||
)
|
||||
|
||||
// OnyxCliConfig holds the CLI configuration.
|
||||
type OnyxCliConfig struct {
|
||||
ServerURL string `json:"server_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
DefaultAgentID int `json:"default_persona_id"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a config with default values.
|
||||
func DefaultConfig() OnyxCliConfig {
|
||||
return OnyxCliConfig{
|
||||
ServerURL: "http://localhost:3000",
|
||||
APIKey: "",
|
||||
DefaultAgentID: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// IsConfigured returns true if the config has an API key.
|
||||
func (c OnyxCliConfig) IsConfigured() bool {
|
||||
return c.APIKey != ""
|
||||
}
|
||||
|
||||
// configDir returns ~/.config/onyx-cli
|
||||
func configDir() string {
|
||||
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
|
||||
return filepath.Join(xdg, "onyx-cli")
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(".", ".config", "onyx-cli")
|
||||
}
|
||||
return filepath.Join(home, ".config", "onyx-cli")
|
||||
}
|
||||
|
||||
// ConfigFilePath returns the full path to the config file.
|
||||
func ConfigFilePath() string {
|
||||
return filepath.Join(configDir(), "config.json")
|
||||
}
|
||||
|
||||
// ConfigExists reports whether the config file is present on disk.
// Any stat error (including permission problems) is treated as "absent".
func ConfigExists() bool {
	if _, err := os.Stat(ConfigFilePath()); err != nil {
		return false
	}
	return true
}
|
||||
|
||||
// Load reads the config from disk, then applies environment overrides.
// A missing or corrupt file silently falls back to DefaultConfig values.
// ONYX_API_KEY wins over the legacy DANSWER_API_KEY, and a non-numeric
// ONYX_PERSONA_ID is ignored.
func Load() OnyxCliConfig {
	cfg := DefaultConfig()

	if raw, readErr := os.ReadFile(ConfigFilePath()); readErr == nil {
		// Unmarshal errors are deliberately ignored: a corrupt file
		// leaves cfg at its defaults.
		_ = json.Unmarshal(raw, &cfg)
	}

	// Environment variables override whatever came from the file.
	if url := os.Getenv(EnvServerURL); url != "" {
		cfg.ServerURL = url
	}
	key := os.Getenv(EnvAPIKey)
	if key == "" {
		key = os.Getenv(EnvAPIKeyLegacy)
	}
	if key != "" {
		cfg.APIKey = key
	}
	if raw := os.Getenv(EnvAgentID); raw != "" {
		if parsed, parseErr := strconv.Atoi(raw); parseErr == nil {
			cfg.DefaultAgentID = parsed
		}
	}

	return cfg
}
|
||||
|
||||
// Save writes cfg to ConfigFilePath as indented JSON, creating the parent
// directory (mode 0755) if needed. The file itself is written with 0644.
func Save(cfg OnyxCliConfig) error {
	if err := os.MkdirAll(configDir(), 0o755); err != nil {
		return err
	}

	encoded, marshalErr := json.MarshalIndent(cfg, "", " ")
	if marshalErr != nil {
		return marshalErr
	}

	return os.WriteFile(ConfigFilePath(), encoded, 0o644)
}
|
||||
222
cli/internal/config/config_test.go
Normal file
222
cli/internal/config/config_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
||||
// clearEnvVars removes all config-related environment variables so each test
// starts from a clean environment. t.Setenv is called first so the original
// value is restored on test cleanup; Unsetenv then removes it for the test.
func clearEnvVars(t *testing.T) {
	t.Helper()
	for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAPIKeyLegacy, EnvAgentID} {
		t.Setenv(key, "")
		os.Unsetenv(key)
	}
}

// TestDefaultConfig verifies the out-of-the-box config values.
func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()
	if cfg.ServerURL != "http://localhost:3000" {
		t.Errorf("expected default server URL, got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "" {
		t.Errorf("expected empty API key, got %s", cfg.APIKey)
	}
	if cfg.DefaultAgentID != 0 {
		t.Errorf("expected default agent ID 0, got %d", cfg.DefaultAgentID)
	}
}

// TestIsConfigured verifies that only a non-empty API key counts as configured.
func TestIsConfigured(t *testing.T) {
	cfg := DefaultConfig()
	if cfg.IsConfigured() {
		t.Error("empty config should not be configured")
	}
	cfg.APIKey = "some-key"
	if !cfg.IsConfigured() {
		t.Error("config with API key should be configured")
	}
}

// TestLoadDefaults: with no file and no env vars, Load returns the defaults.
func TestLoadDefaults(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	cfg := Load()
	if cfg.ServerURL != "http://localhost:3000" {
		t.Errorf("expected default URL, got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "" {
		t.Errorf("expected empty key, got %s", cfg.APIKey)
	}
}

// TestLoadFromFile: values in config.json are picked up by Load.
func TestLoadFromFile(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	onyxDir := filepath.Join(dir, "onyx-cli")
	os.MkdirAll(onyxDir, 0o755)

	data, _ := json.Marshal(map[string]interface{}{
		"server_url":         "https://my-onyx.example.com",
		"api_key":            "test-key-123",
		"default_persona_id": 5,
	})
	os.WriteFile(filepath.Join(onyxDir, "config.json"), data, 0o644)

	cfg := Load()
	if cfg.ServerURL != "https://my-onyx.example.com" {
		t.Errorf("got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "test-key-123" {
		t.Errorf("got %s", cfg.APIKey)
	}
	if cfg.DefaultAgentID != 5 {
		t.Errorf("got %d", cfg.DefaultAgentID)
	}
}

// TestLoadCorruptFile: unparseable JSON silently falls back to defaults.
func TestLoadCorruptFile(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	onyxDir := filepath.Join(dir, "onyx-cli")
	os.MkdirAll(onyxDir, 0o755)
	os.WriteFile(filepath.Join(onyxDir, "config.json"), []byte("not valid json {{{"), 0o644)

	cfg := Load()
	if cfg.ServerURL != "http://localhost:3000" {
		t.Errorf("expected default URL on corrupt file, got %s", cfg.ServerURL)
	}
}

// TestEnvOverrideServerURL: ONYX_SERVER_URL overrides the default URL.
func TestEnvOverrideServerURL(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvServerURL, "https://env-override.com")

	cfg := Load()
	if cfg.ServerURL != "https://env-override.com" {
		t.Errorf("got %s", cfg.ServerURL)
	}
}

// TestEnvOverrideAPIKey: ONYX_API_KEY overrides the (empty) default key.
func TestEnvOverrideAPIKey(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAPIKey, "env-key")

	cfg := Load()
	if cfg.APIKey != "env-key" {
		t.Errorf("got %s", cfg.APIKey)
	}
}

// TestEnvOverrideLegacyAPIKey: DANSWER_API_KEY is honored when ONYX_API_KEY
// is unset.
func TestEnvOverrideLegacyAPIKey(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAPIKeyLegacy, "legacy-key")

	cfg := Load()
	if cfg.APIKey != "legacy-key" {
		t.Errorf("got %s", cfg.APIKey)
	}
}

// TestEnvOverrideAgentID: a numeric ONYX_PERSONA_ID overrides the agent ID.
func TestEnvOverrideAgentID(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAgentID, "42")

	cfg := Load()
	if cfg.DefaultAgentID != 42 {
		t.Errorf("got %d", cfg.DefaultAgentID)
	}
}

// TestEnvOverrideInvalidAgentID: a non-numeric ONYX_PERSONA_ID is ignored.
func TestEnvOverrideInvalidAgentID(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAgentID, "not-a-number")

	cfg := Load()
	if cfg.DefaultAgentID != 0 {
		t.Errorf("got %d", cfg.DefaultAgentID)
	}
}

// TestEnvOverridesFileValues: env vars beat file values, but file values not
// shadowed by an env var are kept.
func TestEnvOverridesFileValues(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	onyxDir := filepath.Join(dir, "onyx-cli")
	os.MkdirAll(onyxDir, 0o755)
	data, _ := json.Marshal(map[string]interface{}{
		"server_url": "https://file-url.com",
		"api_key":    "file-key",
	})
	os.WriteFile(filepath.Join(onyxDir, "config.json"), data, 0o644)

	t.Setenv(EnvServerURL, "https://env-url.com")

	cfg := Load()
	if cfg.ServerURL != "https://env-url.com" {
		t.Errorf("env should override file, got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "file-key" {
		t.Errorf("file value should be kept, got %s", cfg.APIKey)
	}
}

// TestSaveAndReload: a Save followed by Load round-trips all fields.
func TestSaveAndReload(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	cfg := OnyxCliConfig{
		ServerURL:      "https://saved.example.com",
		APIKey:         "saved-key",
		DefaultAgentID: 10,
	}
	if err := Save(cfg); err != nil {
		t.Fatal(err)
	}

	loaded := Load()
	if loaded.ServerURL != "https://saved.example.com" {
		t.Errorf("got %s", loaded.ServerURL)
	}
	if loaded.APIKey != "saved-key" {
		t.Errorf("got %s", loaded.APIKey)
	}
	if loaded.DefaultAgentID != 10 {
		t.Errorf("got %d", loaded.DefaultAgentID)
	}
}

// TestSaveCreatesParentDirs: Save creates missing nested config directories.
func TestSaveCreatesParentDirs(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	nested := filepath.Join(dir, "deep", "nested")
	t.Setenv("XDG_CONFIG_HOME", nested)

	if err := Save(OnyxCliConfig{APIKey: "test"}); err != nil {
		t.Fatal(err)
	}

	if !ConfigExists() {
		t.Error("config file should exist after save")
	}
}
|
||||
193
cli/internal/models/events.go
Normal file
193
cli/internal/models/events.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package models
|
||||
|
||||
// StreamEvent is the interface for all parsed stream events.
type StreamEvent interface {
	// EventType returns the wire-level event type string for this event.
	EventType() string
}

// Event type constants matching the Python StreamEventType enum.
// These are the strings returned by EventType() and (for packet events)
// the "type" field values on the wire.
const (
	EventSessionCreated       = "session_created"
	EventMessageIDInfo        = "message_id_info"
	EventStop                 = "stop"
	EventError                = "error"
	EventMessageStart         = "message_start"
	EventMessageDelta         = "message_delta"
	EventSearchStart          = "search_tool_start"
	EventSearchQueries        = "search_tool_queries_delta"
	EventSearchDocuments      = "search_tool_documents_delta"
	EventReasoningStart       = "reasoning_start"
	EventReasoningDelta       = "reasoning_delta"
	EventReasoningDone        = "reasoning_done"
	EventCitationInfo         = "citation_info"
	EventOpenURLStart         = "open_url_start"
	EventImageGenStart        = "image_generation_start"
	EventPythonToolStart      = "python_tool_start"
	EventCustomToolStart      = "custom_tool_start"
	EventFileReaderStart      = "file_reader_start"
	EventDeepResearchPlan     = "deep_research_plan_start"
	EventDeepResearchDelta    = "deep_research_plan_delta"
	EventResearchAgentStart   = "research_agent_start"
	EventIntermediateReport   = "intermediate_report_start"
	EventIntermediateReportDt = "intermediate_report_delta"
	EventUnknown              = "unknown"
)

// SessionCreatedEvent is emitted when a new chat session is created.
type SessionCreatedEvent struct {
	// ChatSessionID is the server-assigned session identifier.
	ChatSessionID string
}

func (e SessionCreatedEvent) EventType() string { return EventSessionCreated }

// MessageIDEvent carries the user and agent message IDs.
type MessageIDEvent struct {
	// UserMessageID may be nil when no user message was recorded.
	UserMessageID          *int
	ReservedAgentMessageID int
}

func (e MessageIDEvent) EventType() string { return EventMessageIDInfo }

// StopEvent signals the end of a stream.
type StopEvent struct {
	Placement  *Placement
	StopReason *string
}

func (e StopEvent) EventType() string { return EventStop }

// ErrorEvent signals an error.
type ErrorEvent struct {
	// Placement is nil for top-level streaming errors.
	Placement   *Placement
	Error       string
	StackTrace  *string
	IsRetryable bool
}

func (e ErrorEvent) EventType() string { return EventError }

// MessageStartEvent signals the beginning of an agent message.
type MessageStartEvent struct {
	Placement *Placement
	Documents []SearchDoc
}

func (e MessageStartEvent) EventType() string { return EventMessageStart }

// MessageDeltaEvent carries a token of agent content.
type MessageDeltaEvent struct {
	Placement *Placement
	Content   string
}

func (e MessageDeltaEvent) EventType() string { return EventMessageDelta }

// SearchStartEvent signals the beginning of a search.
type SearchStartEvent struct {
	Placement        *Placement
	IsInternetSearch bool
}

func (e SearchStartEvent) EventType() string { return EventSearchStart }

// SearchQueriesEvent carries search queries.
type SearchQueriesEvent struct {
	Placement *Placement
	Queries   []string
}

func (e SearchQueriesEvent) EventType() string { return EventSearchQueries }

// SearchDocumentsEvent carries found documents.
type SearchDocumentsEvent struct {
	Placement *Placement
	Documents []SearchDoc
}

func (e SearchDocumentsEvent) EventType() string { return EventSearchDocuments }

// ReasoningStartEvent signals the beginning of a reasoning block.
type ReasoningStartEvent struct {
	Placement *Placement
}

func (e ReasoningStartEvent) EventType() string { return EventReasoningStart }

// ReasoningDeltaEvent carries reasoning text.
type ReasoningDeltaEvent struct {
	Placement *Placement
	Reasoning string
}

func (e ReasoningDeltaEvent) EventType() string { return EventReasoningDelta }

// ReasoningDoneEvent signals the end of reasoning.
type ReasoningDoneEvent struct {
	Placement *Placement
}

func (e ReasoningDoneEvent) EventType() string { return EventReasoningDone }

// CitationEvent carries citation info.
type CitationEvent struct {
	Placement      *Placement
	CitationNumber int
	DocumentID     string
}

func (e CitationEvent) EventType() string { return EventCitationInfo }

// ToolStartEvent signals the start of a tool usage. It covers several wire
// types (open_url_start, image_generation_start, python_tool_start,
// file_reader_start, custom_tool_start), distinguished by Type.
type ToolStartEvent struct {
	Placement *Placement
	Type      string // The specific event type (e.g. "open_url_start")
	ToolName  string
}

// EventType returns the carried wire type rather than a fixed constant.
func (e ToolStartEvent) EventType() string { return e.Type }

// DeepResearchPlanStartEvent signals the start of a deep research plan.
type DeepResearchPlanStartEvent struct {
	Placement *Placement
}

func (e DeepResearchPlanStartEvent) EventType() string { return EventDeepResearchPlan }

// DeepResearchPlanDeltaEvent carries deep research plan content.
type DeepResearchPlanDeltaEvent struct {
	Placement *Placement
	Content   string
}

func (e DeepResearchPlanDeltaEvent) EventType() string { return EventDeepResearchDelta }

// ResearchAgentStartEvent signals a research sub-task.
type ResearchAgentStartEvent struct {
	Placement    *Placement
	ResearchTask string
}

func (e ResearchAgentStartEvent) EventType() string { return EventResearchAgentStart }

// IntermediateReportStartEvent signals the start of an intermediate report.
type IntermediateReportStartEvent struct {
	Placement *Placement
}

func (e IntermediateReportStartEvent) EventType() string { return EventIntermediateReport }

// IntermediateReportDeltaEvent carries intermediate report content.
type IntermediateReportDeltaEvent struct {
	Placement *Placement
	Content   string
}

func (e IntermediateReportDeltaEvent) EventType() string { return EventIntermediateReportDt }

// UnknownEvent is a catch-all for unrecognized stream data; RawData keeps
// the decoded JSON so callers can inspect or log it.
type UnknownEvent struct {
	Placement *Placement
	RawData   map[string]any
}

func (e UnknownEvent) EventType() string { return EventUnknown }
|
||||
112
cli/internal/models/models.go
Normal file
112
cli/internal/models/models.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package models defines API request/response types for the Onyx CLI.
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// AgentSummary represents an agent (persona) from the API.
type AgentSummary struct {
	ID               int    `json:"id"`
	Name             string `json:"name"`
	Description      string `json:"description"`
	IsDefaultPersona bool   `json:"is_default_persona"`
	IsVisible        bool   `json:"is_visible"`
}

// ChatSessionSummary is a brief session listing.
type ChatSessionSummary struct {
	ID      string    `json:"id"`
	Name    *string   `json:"name"`
	AgentID *int      `json:"persona_id"`
	Created time.Time `json:"time_created"`
}

// ChatSessionDetails is a session with timestamps as strings.
type ChatSessionDetails struct {
	ID      string  `json:"id"`
	Name    *string `json:"name"`
	AgentID *int    `json:"persona_id"`
	Created string  `json:"time_created"`
	Updated string  `json:"time_updated"`
}

// ChatMessageDetail is a single message in a session.
type ChatMessageDetail struct {
	MessageID          int     `json:"message_id"`
	ParentMessage      *int    `json:"parent_message"`
	LatestChildMessage *int    `json:"latest_child_message"`
	Message            string  `json:"message"`
	MessageType        string  `json:"message_type"`
	TimeSent           string  `json:"time_sent"`
	Error              *string `json:"error"`
}

// ChatSessionDetailResponse is the full session detail from the API.
type ChatSessionDetailResponse struct {
	ChatSessionID string              `json:"chat_session_id"`
	Description   *string             `json:"description"`
	AgentID       *int                `json:"persona_id"`
	AgentName     *string             `json:"persona_name"`
	Messages      []ChatMessageDetail `json:"messages"`
}

// ChatFileType represents a file type for uploads.
type ChatFileType string

// Known ChatFileType values.
const (
	ChatFileImage     ChatFileType = "image"
	ChatFileDoc       ChatFileType = "document"
	ChatFilePlainText ChatFileType = "plain_text"
	ChatFileCSV       ChatFileType = "csv"
)

// FileDescriptorPayload is a file descriptor for send-message requests.
type FileDescriptorPayload struct {
	ID   string       `json:"id"`
	Type ChatFileType `json:"type"`
	Name string       `json:"name,omitempty"`
}

// UserFileSnapshot represents an uploaded file.
type UserFileSnapshot struct {
	ID           string       `json:"id"`
	Name         string       `json:"name"`
	FileID       string       `json:"file_id"`
	ChatFileType ChatFileType `json:"chat_file_type"`
}

// CategorizedFilesSnapshot is the response from file upload.
type CategorizedFilesSnapshot struct {
	UserFiles []UserFileSnapshot `json:"user_files"`
}

// ChatSessionCreationInfo is included when creating a new session inline.
type ChatSessionCreationInfo struct {
	AgentID int `json:"persona_id"`
}

// SendMessagePayload is the request body for POST /api/chat/send-chat-message.
type SendMessagePayload struct {
	Message          string                   `json:"message"`
	ChatSessionID    *string                  `json:"chat_session_id,omitempty"`
	ChatSessionInfo  *ChatSessionCreationInfo `json:"chat_session_info,omitempty"`
	ParentMessageID  *int                     `json:"parent_message_id"`
	FileDescriptors  []FileDescriptorPayload  `json:"file_descriptors"`
	Origin           string                   `json:"origin"`
	IncludeCitations bool                     `json:"include_citations"`
	Stream           bool                     `json:"stream"`
}

// SearchDoc represents a document found during search.
type SearchDoc struct {
	DocumentID         string  `json:"document_id"`
	SemanticIdentifier string  `json:"semantic_identifier"`
	Link               *string `json:"link"`
	SourceType         string  `json:"source_type"`
}

// Placement indicates where a stream event belongs in the conversation.
type Placement struct {
	TurnIndex int  `json:"turn_index"`
	TabIndex  int  `json:"tab_index"`
	// SubTurnIndex is nil when the event is not part of a sub-turn.
	SubTurnIndex *int `json:"sub_turn_index"`
}
|
||||
144
cli/internal/onboarding/onboarding.go
Normal file
144
cli/internal/onboarding/onboarding.go
Normal file
@@ -0,0 +1,144 @@
|
||||
// Package onboarding handles the first-run setup flow for Onyx CLI.
|
||||
package onboarding
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/util"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Shared lipgloss styles used throughout the onboarding output.
var (
	boldStyle   = lipgloss.NewStyle().Bold(true)
	dimStyle    = lipgloss.NewStyle().Foreground(lipgloss.Color("#555577"))
	greenStyle  = lipgloss.NewStyle().Foreground(lipgloss.Color("#00cc66")).Bold(true)
	redStyle    = lipgloss.NewStyle().Foreground(lipgloss.Color("#ff5555")).Bold(true)
	yellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
)
|
||||
|
||||
// getTermSize returns the terminal's (width, height), defaulting to 80x24
// when stdout is not a terminal or the size cannot be determined.
func getTermSize() (int, int) {
	if w, h, err := term.GetSize(int(os.Stdout.Fd())); err == nil {
		return w, h
	}
	return 80, 24
}
|
||||
|
||||
// Run executes the interactive onboarding flow: splash screen, server URL
// and API key prompts, a live connection test, and saving the config.
// Returns the validated config, or nil if the user cancels, supplies no
// API key, or the connection test fails.
func Run(existing *config.OnyxCliConfig) *config.OnyxCliConfig {
	cfg := config.DefaultConfig()
	if existing != nil {
		// Re-running onboarding: pre-fill prompts with the current values.
		cfg = *existing
	}

	w, h := getTermSize()
	fmt.Print(tui.RenderSplashOnboarding(w, h))

	fmt.Println()
	fmt.Println(" Welcome to " + boldStyle.Render("Onyx CLI") + ".")
	fmt.Println()

	reader := bufio.NewReader(os.Stdin)

	// Server URL: an empty answer (no default, no input) cancels onboarding.
	serverURL := prompt(reader, " Onyx server URL", cfg.ServerURL)
	if serverURL == "" {
		return nil
	}

	// API Key
	fmt.Println()
	fmt.Println(" " + dimStyle.Render("Need an API key? Press Enter to open the admin panel in your browser,"))
	fmt.Println(" " + dimStyle.Render("or paste your key below."))
	fmt.Println()

	apiKey := prompt(reader, " API key", cfg.APIKey)

	if apiKey == "" {
		// Open browser to API key page, then prompt one more time.
		url := strings.TrimRight(serverURL, "/") + "/admin/api-key"
		fmt.Printf("\n Opening %s ...\n", url)
		util.OpenBrowser(url)
		fmt.Println(" " + dimStyle.Render("Copy your API key, then paste it here."))
		fmt.Println()

		apiKey = prompt(reader, " API key", "")
		if apiKey == "" {
			fmt.Println("\n " + redStyle.Render("No API key provided. Exiting."))
			return nil
		}
	}

	// Test connection with the freshly entered values before saving.
	cfg = config.OnyxCliConfig{
		ServerURL:      serverURL,
		APIKey:         apiKey,
		DefaultAgentID: cfg.DefaultAgentID,
	}

	fmt.Println("\n " + yellowStyle.Render("Testing connection..."))

	client := api.NewClient(cfg)
	if err := client.TestConnection(); err != nil {
		fmt.Println(" " + redStyle.Render("Connection failed.") + " " + err.Error())
		fmt.Println()
		fmt.Println(" " + dimStyle.Render("Run ") + boldStyle.Render("onyx-cli configure") + dimStyle.Render(" to try again."))
		return nil
	}

	// A save failure is reported but does not abort: the validated config
	// is still returned for this session.
	if err := config.Save(cfg); err != nil {
		fmt.Println(" " + redStyle.Render("Could not save config: "+err.Error()))
	}
	fmt.Println(" " + greenStyle.Render("Connected and authenticated."))
	fmt.Println()
	printQuickStart()
	return &cfg
}
|
||||
|
||||
// prompt prints "label [default]: " (the bracketed default only when one is
// set) and reads a single line from reader. Returns defaultVal on read error
// or empty input; otherwise the whitespace-trimmed input.
func prompt(reader *bufio.Reader, label, defaultVal string) string {
	if defaultVal == "" {
		fmt.Printf("%s: ", label)
	} else {
		fmt.Printf("%s %s: ", label, dimStyle.Render("["+defaultVal+"]"))
	}

	raw, err := reader.ReadString('\n')
	if err != nil {
		return defaultVal
	}
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return defaultVal
}
|
||||
|
||||
// printQuickStart prints the post-onboarding cheat sheet: the chat hint and
// a two-column table of slash commands and key bindings.
func printQuickStart() {
	fmt.Println(" " + boldStyle.Render("Quick start"))
	fmt.Println()
	fmt.Println(" Just type to chat with your Onyx agent.")
	fmt.Println()

	entries := [][2]string{
		{"/help", "Show all commands"},
		{"/attach", "Attach a file"},
		{"/agent", "Switch agent"},
		{"/new", "New conversation"},
		{"/sessions", "Browse previous chats"},
		{"Esc", "Cancel generation"},
		{"Ctrl+D", "Quit"},
	}
	for _, entry := range entries {
		fmt.Printf(" %-12s %s\n", boldStyle.Render(entry[0]), dimStyle.Render(entry[1]))
	}
	fmt.Println()
}
|
||||
|
||||
247
cli/internal/parser/parser.go
Normal file
247
cli/internal/parser/parser.go
Normal file
@@ -0,0 +1,247 @@
|
||||
// Package parser handles NDJSON stream parsing for Onyx chat responses.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// ParseStreamLine parses a single NDJSON line into a typed StreamEvent.
// Returns nil for empty lines or unparseable content.
//
// The four cases below are checked in order, and the order matters:
// session-created and top-level-error lines are distinguished from packet
// events by the ABSENCE of a "placement" key.
func ParseStreamLine(line string) models.StreamEvent {
	line = strings.TrimSpace(line)
	if line == "" {
		return nil
	}

	var data map[string]any
	if err := json.Unmarshal([]byte(line), &data); err != nil {
		// Non-JSON lines are silently dropped.
		return nil
	}

	// Case 1: CreateChatSessionID — "chat_session_id" with no "placement".
	if _, ok := data["chat_session_id"]; ok {
		if _, hasPlacement := data["placement"]; !hasPlacement {
			sid, _ := data["chat_session_id"].(string)
			return models.SessionCreatedEvent{ChatSessionID: sid}
		}
	}

	// Case 2: MessageResponseIDInfo — keyed on the reserved assistant ID.
	if _, ok := data["reserved_assistant_message_id"]; ok {
		reservedID := jsonInt(data["reserved_assistant_message_id"])
		var userMsgID *int
		if v, ok := data["user_message_id"]; ok && v != nil {
			id := jsonInt(v)
			userMsgID = &id
		}
		return models.MessageIDEvent{
			UserMessageID:          userMsgID,
			ReservedAgentMessageID: reservedID,
		}
	}

	// Case 3: StreamingError (top-level error without placement).
	if _, ok := data["error"]; ok {
		if _, hasPlacement := data["placement"]; !hasPlacement {
			errStr, _ := data["error"].(string)
			var stackTrace *string
			if st, ok := data["stack_trace"].(string); ok {
				stackTrace = &st
			}
			// Errors are retryable unless the server says otherwise.
			isRetryable := true
			if v, ok := data["is_retryable"].(bool); ok {
				isRetryable = v
			}
			return models.ErrorEvent{
				Error:       errStr,
				StackTrace:  stackTrace,
				IsRetryable: isRetryable,
			}
		}
	}

	// Case 4: Packet with placement + obj — the common streaming shape.
	if rawPlacement, ok := data["placement"]; ok {
		if rawObj, ok := data["obj"]; ok {
			placement := parsePlacement(rawPlacement)
			obj, _ := rawObj.(map[string]any)
			if obj == nil {
				// "obj" present but not an object: keep the raw line.
				return models.UnknownEvent{Placement: placement, RawData: data}
			}
			return parsePacketObj(obj, placement)
		}
	}

	// Fallback: valid JSON that matches no known shape.
	return models.UnknownEvent{RawData: data}
}
|
||||
|
||||
func parsePlacement(raw interface{}) *models.Placement {
|
||||
m, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
p := &models.Placement{
|
||||
TurnIndex: jsonInt(m["turn_index"]),
|
||||
TabIndex: jsonInt(m["tab_index"]),
|
||||
}
|
||||
if v, ok := m["sub_turn_index"]; ok && v != nil {
|
||||
st := jsonInt(v)
|
||||
p.SubTurnIndex = &st
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// parsePacketObj maps a packet's "obj" payload to a typed event based on its
// "type" field. Unrecognized types become UnknownEvent with the raw payload
// preserved for debugging.
func parsePacketObj(obj map[string]any, placement *models.Placement) models.StreamEvent {
	objType, _ := obj["type"].(string)

	switch objType {
	case "stop":
		var reason *string
		if r, ok := obj["stop_reason"].(string); ok {
			reason = &r
		}
		return models.StopEvent{Placement: placement, StopReason: reason}

	case "error":
		// Packet-level errors carry the message under "exception" and are
		// always treated as retryable.
		errMsg := "Unknown error"
		if e, ok := obj["exception"]; ok {
			errMsg = toString(e)
		}
		return models.ErrorEvent{Placement: placement, Error: errMsg, IsRetryable: true}

	case "message_start":
		var docs []models.SearchDoc
		if rawDocs, ok := obj["final_documents"].([]any); ok {
			docs = parseSearchDocs(rawDocs)
		}
		return models.MessageStartEvent{Placement: placement, Documents: docs}

	case "message_delta":
		content, _ := obj["content"].(string)
		return models.MessageDeltaEvent{Placement: placement, Content: content}

	case "search_tool_start":
		isInternet, _ := obj["is_internet_search"].(bool)
		return models.SearchStartEvent{Placement: placement, IsInternetSearch: isInternet}

	case "search_tool_queries_delta":
		// Keep only string entries; non-strings are silently skipped.
		var queries []string
		if raw, ok := obj["queries"].([]any); ok {
			for _, q := range raw {
				if s, ok := q.(string); ok {
					queries = append(queries, s)
				}
			}
		}
		return models.SearchQueriesEvent{Placement: placement, Queries: queries}

	case "search_tool_documents_delta":
		var docs []models.SearchDoc
		if rawDocs, ok := obj["documents"].([]any); ok {
			docs = parseSearchDocs(rawDocs)
		}
		return models.SearchDocumentsEvent{Placement: placement, Documents: docs}

	case "reasoning_start":
		return models.ReasoningStartEvent{Placement: placement}

	case "reasoning_delta":
		reasoning, _ := obj["reasoning"].(string)
		return models.ReasoningDeltaEvent{Placement: placement, Reasoning: reasoning}

	case "reasoning_done":
		return models.ReasoningDoneEvent{Placement: placement}

	case "citation_info":
		return models.CitationEvent{
			Placement:      placement,
			CitationNumber: jsonInt(obj["citation_number"]),
			DocumentID:     jsonString(obj["document_id"]),
		}

	case "open_url_start", "image_generation_start", "python_tool_start", "file_reader_start":
		// Derive a display name from the type: strip "_start", replace
		// underscores, title-case (e.g. "open_url_start" -> "Open Url").
		toolName := strings.ReplaceAll(strings.TrimSuffix(objType, "_start"), "_", " ")
		toolName = cases.Title(language.English).String(toolName)
		return models.ToolStartEvent{Placement: placement, Type: objType, ToolName: toolName}

	case "custom_tool_start":
		// Custom tools supply their own name; fall back to a generic label.
		toolName := jsonString(obj["tool_name"])
		if toolName == "" {
			toolName = "Custom Tool"
		}
		return models.ToolStartEvent{Placement: placement, Type: models.EventCustomToolStart, ToolName: toolName}

	case "deep_research_plan_start":
		return models.DeepResearchPlanStartEvent{Placement: placement}

	case "deep_research_plan_delta":
		content, _ := obj["content"].(string)
		return models.DeepResearchPlanDeltaEvent{Placement: placement, Content: content}

	case "research_agent_start":
		task, _ := obj["research_task"].(string)
		return models.ResearchAgentStartEvent{Placement: placement, ResearchTask: task}

	case "intermediate_report_start":
		return models.IntermediateReportStartEvent{Placement: placement}

	case "intermediate_report_delta":
		content, _ := obj["content"].(string)
		return models.IntermediateReportDeltaEvent{Placement: placement, Content: content}

	default:
		return models.UnknownEvent{Placement: placement, RawData: obj}
	}
}
|
||||
|
||||
func parseSearchDocs(raw []any) []models.SearchDoc {
|
||||
var docs []models.SearchDoc
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
doc := models.SearchDoc{
|
||||
DocumentID: jsonString(m["document_id"]),
|
||||
SemanticIdentifier: jsonString(m["semantic_identifier"]),
|
||||
SourceType: jsonString(m["source_type"]),
|
||||
}
|
||||
if link, ok := m["link"].(string); ok {
|
||||
doc.Link = &link
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
return docs
|
||||
}
|
||||
|
||||
// jsonInt coerces a decoded JSON value to int. encoding/json produces
// float64 for numbers by default and json.Number when Decoder.UseNumber is
// enabled; both are handled here, along with native int/int64 for values
// constructed in code. Any other type (including nil) yields 0.
func jsonInt(v any) int {
	switch n := v.(type) {
	case float64:
		return int(n)
	case int:
		return n
	case int64:
		return int(n)
	case json.Number:
		if i, err := n.Int64(); err == nil {
			return int(i)
		}
		return 0
	default:
		return 0
	}
}
|
||||
|
||||
// jsonString coerces a decoded JSON value to string, returning "" for any
// non-string value (including nil).
func jsonString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
|
||||
|
||||
// toString renders v as display text: strings pass through unchanged,
// while every other value is JSON-encoded. A marshal failure is
// deliberately discarded (best-effort formatting), yielding "".
func toString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	encoded, _ := json.Marshal(v)
	return string(encoded)
}
|
||||
414
cli/internal/parser/parser_test.go
Normal file
414
cli/internal/parser/parser_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// TestEmptyLineReturnsNil verifies that blank and whitespace-only stream
// lines are ignored (nil event) by the parser.
func TestEmptyLineReturnsNil(t *testing.T) {
	for _, line := range []string{"", " ", "\n"} {
		if ParseStreamLine(line) != nil {
			t.Errorf("expected nil for %q", line)
		}
	}
}

// TestInvalidJSONReturnsNil verifies that malformed JSON lines are
// dropped (nil event) instead of producing an error event.
func TestInvalidJSONReturnsNil(t *testing.T) {
	for _, line := range []string{"not json", "{broken"} {
		if ParseStreamLine(line) != nil {
			t.Errorf("expected nil for %q", line)
		}
	}
}
|
||||
|
||||
// TestSessionCreated verifies a top-level chat_session_id line is parsed
// into a SessionCreatedEvent carrying the session UUID.
func TestSessionCreated(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"chat_session_id": "550e8400-e29b-41d4-a716-446655440000",
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.SessionCreatedEvent)
	if !ok {
		t.Fatalf("expected SessionCreatedEvent, got %T", event)
	}
	if e.ChatSessionID != "550e8400-e29b-41d4-a716-446655440000" {
		t.Errorf("got %s", e.ChatSessionID)
	}
}

// TestMessageIDInfo verifies both message ids in a message-id line are
// captured (user id as pointer, reserved assistant id as value).
func TestMessageIDInfo(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"user_message_id":               1,
		"reserved_assistant_message_id": 2,
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageIDEvent)
	if !ok {
		t.Fatalf("expected MessageIDEvent, got %T", event)
	}
	if e.UserMessageID == nil || *e.UserMessageID != 1 {
		t.Errorf("expected user_message_id=1")
	}
	if e.ReservedAgentMessageID != 2 {
		t.Errorf("got %d", e.ReservedAgentMessageID)
	}
}

// TestMessageIDInfoNullUserID verifies a JSON null user_message_id maps
// to a nil pointer while the assistant id is still parsed.
func TestMessageIDInfoNullUserID(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"user_message_id":               nil,
		"reserved_assistant_message_id": 5,
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageIDEvent)
	if !ok {
		t.Fatalf("expected MessageIDEvent, got %T", event)
	}
	if e.UserMessageID != nil {
		t.Error("expected nil user_message_id")
	}
	if e.ReservedAgentMessageID != 5 {
		t.Errorf("got %d", e.ReservedAgentMessageID)
	}
}
|
||||
|
||||
// TestTopLevelError verifies a fully-populated top-level error line maps
// to an ErrorEvent with stack trace and retryability preserved.
func TestTopLevelError(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"error":        "Rate limit exceeded",
		"stack_trace":  "...",
		"is_retryable": true,
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.ErrorEvent)
	if !ok {
		t.Fatalf("expected ErrorEvent, got %T", event)
	}
	if e.Error != "Rate limit exceeded" {
		t.Errorf("got %s", e.Error)
	}
	if e.StackTrace == nil || *e.StackTrace != "..." {
		t.Error("expected stack_trace")
	}
	if !e.IsRetryable {
		t.Error("expected retryable")
	}
}

// TestTopLevelErrorMinimal verifies that when only "error" is present,
// IsRetryable defaults to true.
func TestTopLevelErrorMinimal(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"error": "Something broke",
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.ErrorEvent)
	if !ok {
		t.Fatalf("expected ErrorEvent, got %T", event)
	}
	if e.Error != "Something broke" {
		t.Errorf("got %s", e.Error)
	}
	if !e.IsRetryable {
		t.Error("expected default retryable=true")
	}
}
|
||||
|
||||
// makePacket wraps obj in the standard stream-packet envelope
// ({"placement": {...}, "obj": {...}}) with the given placement indices
// and returns it as a single JSON line.
func makePacket(obj map[string]interface{}, turnIndex, tabIndex int) string {
	return mustJSON(map[string]interface{}{
		"placement": map[string]interface{}{"turn_index": turnIndex, "tab_index": tabIndex},
		"obj":       obj,
	})
}
|
||||
|
||||
// TestStopPacket verifies a stop packet carries its stop_reason and the
// envelope placement.
func TestStopPacket(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "stop", "stop_reason": "completed"}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.StopEvent)
	if !ok {
		t.Fatalf("expected StopEvent, got %T", event)
	}
	if e.StopReason == nil || *e.StopReason != "completed" {
		t.Error("expected stop_reason=completed")
	}
	if e.Placement == nil || e.Placement.TurnIndex != 0 {
		t.Error("expected placement")
	}
}

// TestStopPacketNoReason verifies a missing stop_reason maps to nil.
func TestStopPacketNoReason(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "stop"}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.StopEvent)
	if !ok {
		t.Fatalf("expected StopEvent, got %T", event)
	}
	if e.StopReason != nil {
		t.Error("expected nil stop_reason")
	}
}

// TestMessageStart verifies a bare message_start packet type-maps.
func TestMessageStart(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "message_start"}, 0, 0)
	event := ParseStreamLine(line)
	_, ok := event.(models.MessageStartEvent)
	if !ok {
		t.Fatalf("expected MessageStartEvent, got %T", event)
	}
}

// TestMessageStartWithDocuments verifies final_documents on a
// message_start packet are parsed into the event's Documents slice.
func TestMessageStartWithDocuments(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "message_start",
		"final_documents": []interface{}{
			map[string]interface{}{"document_id": "doc1", "semantic_identifier": "Doc 1"},
		},
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageStartEvent)
	if !ok {
		t.Fatalf("expected MessageStartEvent, got %T", event)
	}
	if len(e.Documents) != 1 || e.Documents[0].DocumentID != "doc1" {
		t.Error("expected 1 document with id doc1")
	}
}

// TestMessageDelta verifies content text on a message_delta packet.
func TestMessageDelta(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "message_delta", "content": "Hello"}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageDeltaEvent)
	if !ok {
		t.Fatalf("expected MessageDeltaEvent, got %T", event)
	}
	if e.Content != "Hello" {
		t.Errorf("got %s", e.Content)
	}
}

// TestMessageDeltaEmpty verifies an empty-content delta is preserved
// (not dropped or turned into an error).
func TestMessageDeltaEmpty(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "message_delta", "content": ""}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageDeltaEvent)
	if !ok {
		t.Fatalf("expected MessageDeltaEvent, got %T", event)
	}
	if e.Content != "" {
		t.Errorf("expected empty, got %s", e.Content)
	}
}
|
||||
|
||||
// TestSearchToolStart verifies the is_internet_search flag survives
// parsing of a search_tool_start packet.
func TestSearchToolStart(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "search_tool_start", "is_internet_search": true,
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.SearchStartEvent)
	if !ok {
		t.Fatalf("expected SearchStartEvent, got %T", event)
	}
	if !e.IsInternetSearch {
		t.Error("expected internet search")
	}
}

// TestSearchToolQueries verifies the queries list of a
// search_tool_queries_delta packet is parsed in order.
func TestSearchToolQueries(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type":    "search_tool_queries_delta",
		"queries": []interface{}{"query 1", "query 2"},
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.SearchQueriesEvent)
	if !ok {
		t.Fatalf("expected SearchQueriesEvent, got %T", event)
	}
	if len(e.Queries) != 2 || e.Queries[0] != "query 1" {
		t.Error("unexpected queries")
	}
}

// TestSearchToolDocuments verifies document parsing from a
// search_tool_documents_delta packet, including the optional link.
func TestSearchToolDocuments(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "search_tool_documents_delta",
		"documents": []interface{}{
			map[string]interface{}{"document_id": "d1", "semantic_identifier": "First Doc", "link": "http://example.com"},
			map[string]interface{}{"document_id": "d2", "semantic_identifier": "Second Doc"},
		},
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.SearchDocumentsEvent)
	if !ok {
		t.Fatalf("expected SearchDocumentsEvent, got %T", event)
	}
	if len(e.Documents) != 2 {
		t.Errorf("expected 2 docs, got %d", len(e.Documents))
	}
	if e.Documents[0].Link == nil || *e.Documents[0].Link != "http://example.com" {
		t.Error("expected link on first doc")
	}
}
|
||||
|
||||
// TestReasoningStart verifies reasoning_start type-maps to its event.
func TestReasoningStart(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "reasoning_start"}, 0, 0)
	event := ParseStreamLine(line)
	if _, ok := event.(models.ReasoningStartEvent); !ok {
		t.Fatalf("expected ReasoningStartEvent, got %T", event)
	}
}

// TestReasoningDelta verifies reasoning text on a reasoning_delta packet.
func TestReasoningDelta(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "reasoning_delta", "reasoning": "Let me think...",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.ReasoningDeltaEvent)
	if !ok {
		t.Fatalf("expected ReasoningDeltaEvent, got %T", event)
	}
	if e.Reasoning != "Let me think..." {
		t.Errorf("got %s", e.Reasoning)
	}
}

// TestReasoningDone verifies reasoning_done type-maps to its event.
func TestReasoningDone(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "reasoning_done"}, 0, 0)
	event := ParseStreamLine(line)
	if _, ok := event.(models.ReasoningDoneEvent); !ok {
		t.Fatalf("expected ReasoningDoneEvent, got %T", event)
	}
}

// TestCitationInfo verifies the citation number and document id of a
// citation_info packet are both preserved.
func TestCitationInfo(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "citation_info", "citation_number": 1, "document_id": "doc_abc",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.CitationEvent)
	if !ok {
		t.Fatalf("expected CitationEvent, got %T", event)
	}
	if e.CitationNumber != 1 || e.DocumentID != "doc_abc" {
		t.Errorf("got %d, %s", e.CitationNumber, e.DocumentID)
	}
}
|
||||
|
||||
// TestOpenURLStart verifies open_url_start maps to a generic
// ToolStartEvent whose Type field records the original packet type.
func TestOpenURLStart(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "open_url_start"}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.ToolStartEvent)
	if !ok {
		t.Fatalf("expected ToolStartEvent, got %T", event)
	}
	if e.Type != "open_url_start" {
		t.Errorf("got type %s", e.Type)
	}
}

// TestPythonToolStart verifies python_tool_start gets the fixed
// "Python Tool" display name.
func TestPythonToolStart(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "python_tool_start", "code": "print('hi')",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.ToolStartEvent)
	if !ok {
		t.Fatalf("expected ToolStartEvent, got %T", event)
	}
	if e.ToolName != "Python Tool" {
		t.Errorf("got %s", e.ToolName)
	}
}

// TestCustomToolStart verifies custom_tool_start uses the tool_name
// field from the packet as the display name.
func TestCustomToolStart(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "custom_tool_start", "tool_name": "MyTool",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.ToolStartEvent)
	if !ok {
		t.Fatalf("expected ToolStartEvent, got %T", event)
	}
	if e.ToolName != "MyTool" {
		t.Errorf("got %s", e.ToolName)
	}
}
|
||||
|
||||
// TestDeepResearchPlanDelta verifies plan text on a
// deep_research_plan_delta packet.
func TestDeepResearchPlanDelta(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "deep_research_plan_delta", "content": "Step 1: ...",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.DeepResearchPlanDeltaEvent)
	if !ok {
		t.Fatalf("expected DeepResearchPlanDeltaEvent, got %T", event)
	}
	if e.Content != "Step 1: ..." {
		t.Errorf("got %s", e.Content)
	}
}

// TestResearchAgentStart verifies the research_task text on a
// research_agent_start packet.
func TestResearchAgentStart(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "research_agent_start", "research_task": "Find info about X",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.ResearchAgentStartEvent)
	if !ok {
		t.Fatalf("expected ResearchAgentStartEvent, got %T", event)
	}
	if e.ResearchTask != "Find info about X" {
		t.Errorf("got %s", e.ResearchTask)
	}
}

// TestIntermediateReportDelta verifies report text on an
// intermediate_report_delta packet.
func TestIntermediateReportDelta(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "intermediate_report_delta", "content": "Report text",
	}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.IntermediateReportDeltaEvent)
	if !ok {
		t.Fatalf("expected IntermediateReportDeltaEvent, got %T", event)
	}
	if e.Content != "Report text" {
		t.Errorf("got %s", e.Content)
	}
}
|
||||
|
||||
// TestUnknownPacketType verifies an unrecognized packet type falls
// through to UnknownEvent rather than being dropped.
func TestUnknownPacketType(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "section_end"}, 0, 0)
	event := ParseStreamLine(line)
	if _, ok := event.(models.UnknownEvent); !ok {
		t.Fatalf("expected UnknownEvent, got %T", event)
	}
}

// TestUnknownTopLevel verifies an unrecognized top-level (non-packet)
// line also maps to UnknownEvent.
func TestUnknownTopLevel(t *testing.T) {
	line := mustJSON(map[string]interface{}{"some_unknown_field": "value"})
	event := ParseStreamLine(line)
	if _, ok := event.(models.UnknownEvent); !ok {
		t.Fatalf("expected UnknownEvent, got %T", event)
	}
}

// TestPlacementPreserved verifies non-zero turn/tab indices survive
// parsing onto the event's Placement.
func TestPlacementPreserved(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "message_delta", "content": "x",
	}, 3, 1)
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageDeltaEvent)
	if !ok {
		t.Fatalf("expected MessageDeltaEvent, got %T", event)
	}
	if e.Placement == nil {
		t.Fatal("expected placement")
	}
	if e.Placement.TurnIndex != 3 || e.Placement.TabIndex != 1 {
		t.Errorf("got turn=%d tab=%d", e.Placement.TurnIndex, e.Placement.TabIndex)
	}
}
|
||||
|
||||
// mustJSON marshals v to a JSON string, panicking on failure; the tests
// use it to build stream lines inline without error plumbing.
func mustJSON(v interface{}) string {
	encoded, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return string(encoded)
}
|
||||
626
cli/internal/tui/app.go
Normal file
626
cli/internal/tui/app.go
Normal file
@@ -0,0 +1,626 @@
|
||||
// Package tui implements the Bubble Tea TUI for Onyx CLI.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// Model is the root Bubble Tea model. It owns the API client, the three
// UI components (viewport, input, status bar), and all chat/stream
// bookkeeping for the active session.
type Model struct {
	config config.OnyxCliConfig // loaded CLI configuration
	client *api.Client          // API client built from config

	viewport *viewport  // scrollback transcript area
	input    inputModel // text input, command menu, file badges
	status   statusBar  // bottom status line

	width  int // last known terminal width
	height int // last known terminal height

	// Chat state
	chatSessionID   *string                        // nil until the server creates a session
	agentID         int                            // currently selected agent id
	agentName       string                         // agent display name for the status bar
	agents          []models.AgentSummary          // cached agent list for the picker
	parentMessageID *int                           // parent for the next message (-1 = root of thread)
	isStreaming     bool                           // true while a response stream is in flight
	streamCancel    context.CancelFunc             // cancels the in-flight stream request
	streamCh        <-chan models.StreamEvent      // parsed events from the active stream
	citations       map[int]string                 // citation number -> document id for the current turn
	attachedFiles   []models.FileDescriptorPayload // files queued for the next message
	needsRename     bool                           // auto-rename the session after the first exchange
	agentStarted    bool                           // agent output has begun for this turn

	// Quit state
	quitPending    bool // first Ctrl+D seen; second within the window quits
	splashShown    bool
	initInputReady bool // true once terminal init responses have passed
}
|
||||
|
||||
// NewModel creates a new TUI model.
// The parent message id starts at -1 (sentinel for "no parent", i.e. the
// root of a new thread) and the agent name is a placeholder until the
// agent list loads.
func NewModel(cfg config.OnyxCliConfig) Model {
	client := api.NewClient(cfg)
	parentID := -1

	return Model{
		config:          cfg,
		client:          client,
		viewport:        newViewport(80), // width corrected on first WindowSizeMsg
		input:           newInputModel(),
		status:          newStatusBar(),
		agentID:         cfg.DefaultAgentID,
		agentName:       "Default",
		parentMessageID: &parentID,
		citations:       make(map[int]string),
	}
}
|
||||
|
||||
// Init initializes the model. The only startup command is the agent-list
// fetch (loadAgentsCmd is defined elsewhere in this package).
func (m Model) Init() tea.Cmd {
	return loadAgentsCmd(m.client)
}
|
||||
|
||||
// Update handles messages. It is the single Bubble Tea dispatch point:
// window sizing, mouse wheel, key routing, stream events, and async load
// results all flow through here. Anything unhandled falls through to the
// text input (once focused).
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	// Filter out terminal query responses (OSC 11 background color, cursor
	// position reports, etc.) that arrive as key events with raw escape content.
	// These arrive split across multiple key events, so we use a brief window
	// after startup to swallow them all.
	if keyMsg, ok := msg.(tea.KeyMsg); ok && !m.initInputReady {
		// During init, drop ALL key events — they're terminal query responses
		_ = keyMsg
		return m, nil
	}

	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		m.width = msg.Width
		m.height = msg.Height
		m.viewport.setWidth(msg.Width)
		m.status.setWidth(msg.Width)
		m.input.textInput.Width = msg.Width - 4
		// First size message: show the splash and schedule input focus.
		if !m.splashShown {
			m.splashShown = true
			// bottomHeight = sep + input + sep + status = 4 (approx)
			viewportHeight := msg.Height - 4
			if viewportHeight < 1 {
				viewportHeight = msg.Height
			}
			m.viewport.addSplash(viewportHeight)
			// Delay input focus to let terminal query responses flush
			return m, tea.Tick(100*time.Millisecond, func(time.Time) tea.Msg {
				return inputReadyMsg{}
			})
		}
		return m, nil

	case tea.MouseMsg:
		// Wheel scrolling only; other mouse buttons fall through.
		switch msg.Button {
		case tea.MouseButtonWheelUp:
			m.viewport.scrollUp(3)
			return m, nil
		case tea.MouseButtonWheelDown:
			m.viewport.scrollDown(3)
			return m, nil
		}

	case tea.KeyMsg:
		return m.handleKey(msg)

	case submitMsg:
		return m.handleSubmit(msg.text)

	case fileDropMsg:
		return m.handleFileDrop(msg.path)

	case InitDoneMsg:
		return m.handleInitDone(msg)

	case api.StreamEventMsg:
		return m.handleStreamEvent(msg)

	case api.StreamDoneMsg:
		return m.handleStreamDone(msg)

	case AgentsLoadedMsg:
		return m.handleAgentsLoaded(msg)

	case SessionsLoadedMsg:
		return m.handleSessionsLoaded(msg)

	case SessionResumedMsg:
		return m.handleSessionResumed(msg)

	case FileUploadedMsg:
		return m.handleFileUploaded(msg)

	case inputReadyMsg:
		// Startup flush window elapsed: focus the input and clear any
		// residue the terminal responses may have typed into it.
		m.initInputReady = true
		m.input.textInput.Focus()
		m.input.textInput.SetValue("")
		return m, m.input.textInput.Cursor.BlinkCmd()

	case resetQuitMsg:
		m.quitPending = false
		return m, nil
	}

	// Only forward messages to the text input after it's been focused
	if m.splashShown {
		var cmd tea.Cmd
		m.input, cmd = m.input.update(msg)
		return m, cmd
	}
	return m, nil
}
|
||||
|
||||
// View renders the UI. Layout, top to bottom: viewport, separator,
// optional command menu, input line, separator, status bar. Heights are
// recomputed each frame so the viewport absorbs whatever space the
// bottom chrome does not use.
func (m Model) View() string {
	// No WindowSizeMsg yet — terminal size unknown, render nothing.
	if m.width == 0 || m.height == 0 {
		return ""
	}

	separator := lipgloss.NewStyle().Foreground(separatorColor).Render(
		strings.Repeat("─", m.width),
	)

	// Calculate heights: separator(1) + menu + file badges + input(1) + separator(1) + status(1)
	menuView := m.input.viewMenu(m.width)
	menuHeight := 0
	if menuView != "" {
		menuHeight = strings.Count(menuView, "\n") + 1
	}

	fileHeight := 0
	if len(m.input.attachedFiles) > 0 {
		fileHeight = 1
	}

	bottomHeight := 1 + menuHeight + fileHeight + 1 + 1 + 1 // sep + menu + files + input + sep + status
	viewportHeight := m.height - bottomHeight
	if viewportHeight < 1 {
		viewportHeight = 1
	}

	var parts []string
	parts = append(parts, m.viewport.view(viewportHeight))
	parts = append(parts, separator)
	if menuView != "" {
		parts = append(parts, menuView)
	}
	parts = append(parts, m.input.viewInput())
	parts = append(parts, separator)
	parts = append(parts, m.status.view())

	return strings.Join(parts, "\n")
}
|
||||
|
||||
// handleKey processes keyboard input. Global shortcuts (Esc, Ctrl+D,
// Ctrl+O, picker navigation, scrolling) are handled here; anything else
// is forwarded to the text input component.
func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
	switch msg.Type {
	case tea.KeyEscape:
		// Cancel streaming or close menu
		// Priority order: close command menu, then stop streaming, then
		// dismiss an open picker.
		if m.input.menuVisible {
			m.input.menuVisible = false
			return m, nil
		}
		if m.isStreaming {
			return m.cancelStream()
		}
		// Dismiss picker
		if m.viewport.pickerActive {
			m.viewport.pickerActive = false
			return m, nil
		}
		return m, nil

	case tea.KeyCtrlD:
		// If streaming, cancel first; require a fresh Ctrl+D pair to quit
		if m.isStreaming {
			return m.cancelStream()
		}
		if m.quitPending {
			return m, tea.Quit
		}
		// First press arms the quit; it disarms after 2 seconds.
		m.quitPending = true
		m.viewport.addInfo("Press Ctrl+D again to quit.")
		return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
			return resetQuitMsg{}
		})

	case tea.KeyCtrlO:
		// Toggle source/citation visibility in the transcript.
		m.viewport.showSources = !m.viewport.showSources
		return m, nil

	case tea.KeyEnter:
		// If picker is active, handle selection
		if m.viewport.pickerActive && len(m.viewport.pickerItems) > 0 {
			item := m.viewport.pickerItems[m.viewport.pickerIndex]
			m.viewport.pickerActive = false
			switch m.viewport.pickerType {
			case pickerSession:
				return cmdResume(m, item.id)
			case pickerAgent:
				return cmdSelectAgent(m, item.id)
			}
		}
		// Otherwise fall through to the input (submits the line there).

	case tea.KeyUp:
		if m.viewport.pickerActive && m.viewport.pickerIndex > 0 {
			m.viewport.pickerIndex--
			return m, nil
		}

	case tea.KeyDown:
		if m.viewport.pickerActive && m.viewport.pickerIndex < len(m.viewport.pickerItems)-1 {
			m.viewport.pickerIndex++
			return m, nil
		}

	case tea.KeyPgUp:
		// Half-page scroll; the -4 mirrors the approximate bottom-chrome
		// height used elsewhere.
		viewportHeight := m.height - 4
		if viewportHeight < 1 {
			viewportHeight = 1
		}
		m.viewport.scrollUp(viewportHeight / 2)
		return m, nil

	case tea.KeyPgDown:
		viewportHeight := m.height - 4
		if viewportHeight < 1 {
			viewportHeight = 1
		}
		m.viewport.scrollDown(viewportHeight / 2)
		return m, nil

	case tea.KeyShiftUp:
		m.viewport.scrollUp(3)
		return m, nil

	case tea.KeyShiftDown:
		m.viewport.scrollDown(3)
		return m, nil
	}

	// Pass to input
	var cmd tea.Cmd
	m.input, cmd = m.input.update(msg)
	return m, cmd
}
|
||||
|
||||
func (m Model) handleSubmit(text string) (tea.Model, tea.Cmd) {
|
||||
if strings.HasPrefix(text, "/") {
|
||||
return handleSlashCommand(m, text)
|
||||
}
|
||||
return m.sendMessage(text)
|
||||
}
|
||||
|
||||
// handleFileDrop attaches a file dropped onto the terminal by delegating
// to the same path as the /attach command.
func (m Model) handleFileDrop(path string) (tea.Model, tea.Cmd) {
	return cmdAttach(m, path)
}
|
||||
|
||||
// cancelStream aborts an in-flight response: it cancels the local stream
// context, asks the server (best-effort, fire-and-forget goroutine) to
// stop generation, and finalizes whatever partial output was rendered.
func (m Model) cancelStream() (Model, tea.Cmd) {
	if m.streamCancel != nil {
		m.streamCancel()
	}
	if m.chatSessionID != nil {
		// Copy the id before the goroutine captures it.
		sid := *m.chatSessionID
		go m.client.StopChatSession(sid)
	}
	m, cmd := m.finishStream(nil)
	m.viewport.addInfo("Generation stopped.")
	return m, cmd
}
|
||||
|
||||
// sendMessage starts streaming a user message to the server. It is a
// no-op while a previous stream is active. The message is echoed to the
// viewport, queued attachments are consumed (they apply to this message
// only), and a cancellable stream is opened whose events arrive back as
// api.StreamEventMsg via WaitForStreamEvent.
func (m Model) sendMessage(message string) (Model, tea.Cmd) {
	if m.isStreaming {
		return m, nil
	}

	m.viewport.addUserMessage(message)
	m.viewport.startAgent()

	// Prepare file descriptors
	fileDescs := make([]models.FileDescriptorPayload, len(m.attachedFiles))
	copy(fileDescs, m.attachedFiles)
	m.attachedFiles = nil
	m.input.clearFiles()

	m.isStreaming = true
	m.agentStarted = false
	m.citations = make(map[int]string) // fresh citation map per turn
	m.status.setStreaming(true)

	ctx, cancel := context.WithCancel(context.Background())
	m.streamCancel = cancel

	ch := m.client.SendMessageStream(
		ctx,
		message,
		m.chatSessionID,
		m.agentID,
		m.parentMessageID,
		fileDescs,
	)
	m.streamCh = ch

	return m, api.WaitForStreamEvent(ch)
}
|
||||
|
||||
// handleStreamEvent applies one parsed stream event to the UI and, in
// the common case, re-arms WaitForStreamEvent to pull the next one.
// StopEvent and ErrorEvent terminate the stream instead.
func (m Model) handleStreamEvent(msg api.StreamEventMsg) (tea.Model, tea.Cmd) {
	// Ignore stale events after cancellation
	if !m.isStreaming {
		return m, nil
	}
	// nil event (e.g. an unparseable line): skip and keep listening.
	if msg.Event == nil {
		return m, api.WaitForStreamEvent(m.streamCh)
	}

	switch e := msg.Event.(type) {
	case models.SessionCreatedEvent:
		m.chatSessionID = &e.ChatSessionID
		m.needsRename = true // new session gets auto-renamed after this turn
		m.status.setSession(e.ChatSessionID)

	case models.MessageIDEvent:
		// Thread the next message off this assistant message.
		m.parentMessageID = &e.ReservedAgentMessageID

	case models.MessageStartEvent:
		m.agentStarted = true

	case models.MessageDeltaEvent:
		m.agentStarted = true
		m.viewport.appendToken(e.Content)

	case models.SearchStartEvent:
		if e.IsInternetSearch {
			m.viewport.addInfo("Web search…")
		} else {
			m.viewport.addInfo("Searching…")
		}

	case models.SearchQueriesEvent:
		if len(e.Queries) > 0 {
			// Show at most the first three queries.
			queries := e.Queries
			if len(queries) > 3 {
				queries = queries[:3]
			}
			parts := make([]string, len(queries))
			for i, q := range queries {
				parts[i] = "\"" + q + "\""
			}
			m.viewport.addInfo("Searching: " + strings.Join(parts, ", "))
		}

	case models.SearchDocumentsEvent:
		count := len(e.Documents)
		suffix := "s"
		if count == 1 {
			suffix = ""
		}
		m.viewport.addInfo("Found " + strconv.Itoa(count) + " document" + suffix)

	case models.ReasoningStartEvent:
		m.viewport.addInfo("Thinking…")

	case models.ReasoningDeltaEvent:
		// We don't display reasoning text, just the indicator

	case models.ReasoningDoneEvent:
		// No-op

	case models.CitationEvent:
		// Collected per turn and rendered when the stream finishes.
		m.citations[e.CitationNumber] = e.DocumentID

	case models.ToolStartEvent:
		m.viewport.addInfo("Using " + e.ToolName + "…")

	case models.ResearchAgentStartEvent:
		m.viewport.addInfo("Researching: " + e.ResearchTask)

	case models.DeepResearchPlanDeltaEvent:
		m.viewport.appendToken(e.Content)

	case models.IntermediateReportDeltaEvent:
		m.viewport.appendToken(e.Content)

	case models.StopEvent:
		return m.finishStream(nil)

	case models.ErrorEvent:
		m.viewport.addError(e.Error)
		return m.finishStream(nil)
	}

	return m, api.WaitForStreamEvent(m.streamCh)
}
|
||||
|
||||
// handleStreamDone finalizes a stream when the event channel closes,
// passing through any transport-level error. Stale completions after a
// user-initiated cancel are ignored.
func (m Model) handleStreamDone(msg api.StreamDoneMsg) (tea.Model, tea.Cmd) {
	// Ignore if already cancelled
	if !m.isStreaming {
		return m, nil
	}
	return m.finishStream(msg.Err)
}
|
||||
|
||||
func (m Model) finishStream(err error) (Model, tea.Cmd) {
|
||||
if m.agentStarted {
|
||||
m.viewport.finishAgent()
|
||||
if len(m.citations) > 0 {
|
||||
m.viewport.addCitations(m.citations)
|
||||
}
|
||||
}
|
||||
m.isStreaming = false
|
||||
m.agentStarted = false
|
||||
m.status.setStreaming(false)
|
||||
m.streamCancel = nil
|
||||
m.streamCh = nil
|
||||
|
||||
// Auto-rename new sessions
|
||||
if m.needsRename && m.chatSessionID != nil {
|
||||
m.needsRename = false
|
||||
sessionID := *m.chatSessionID
|
||||
client := m.client
|
||||
go func() {
|
||||
_, _ = client.RenameChatSession(sessionID, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// handleInitDone applies the startup agent-list result: on failure it
// keeps the default agent with a warning; on success it resolves the
// configured agent id to its display name. Either way the status bar is
// populated with server and agent info.
func (m Model) handleInitDone(msg InitDoneMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addWarning("Could not load agents. Using default.")
	} else {
		m.agents = msg.Agents
		for _, p := range m.agents {
			if p.ID == m.agentID {
				m.agentName = p.Name
				break
			}
		}
	}
	m.status.setServer(m.config.ServerURL)
	m.status.setAgent(m.agentName)
	return m, nil
}
|
||||
|
||||
func (m Model) handleAgentsLoaded(msg AgentsLoadedMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addError("Could not load agents: " + msg.Err.Error())
|
||||
return m, nil
|
||||
}
|
||||
m.agents = msg.Agents
|
||||
if len(m.agents) == 0 {
|
||||
m.viewport.addInfo("No agents available.")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.viewport.addInfo("Select an agent (Enter to select, Esc to cancel):")
|
||||
|
||||
var items []pickerItem
|
||||
for _, p := range m.agents {
|
||||
label := fmt.Sprintf("%d: %s", p.ID, p.Name)
|
||||
if p.ID == m.agentID {
|
||||
label += " *"
|
||||
}
|
||||
desc := p.Description
|
||||
if len(desc) > 50 {
|
||||
desc = desc[:50] + "..."
|
||||
}
|
||||
if desc != "" {
|
||||
label += " - " + desc
|
||||
}
|
||||
items = append(items, pickerItem{
|
||||
id: strconv.Itoa(p.ID),
|
||||
label: label,
|
||||
})
|
||||
}
|
||||
m.viewport.showPicker(pickerAgent, items)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// handleSessionsLoaded renders the resume-session picker, showing at
// most the 15 most recent sessions as "shortid name (created)".
func (m Model) handleSessionsLoaded(msg SessionsLoadedMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addError("Could not load sessions: " + msg.Err.Error())
		return m, nil
	}
	if len(msg.Sessions) == 0 {
		m.viewport.addInfo("No previous sessions found.")
		return m, nil
	}

	m.viewport.addInfo("Select a session to resume (Enter to select, Esc to cancel):")

	var items []pickerItem
	for i, s := range msg.Sessions {
		// Cap the picker at 15 entries.
		if i >= 15 {
			break
		}
		name := "Untitled"
		if s.Name != nil && *s.Name != "" {
			name = *s.Name
		}
		// Abbreviate the session id to its first 8 bytes for display.
		// NOTE(review): byte-wise slicing — fine if ids are ASCII UUIDs,
		// as they appear to be; confirm if ids can ever be non-ASCII.
		sid := s.ID
		if len(sid) > 8 {
			sid = sid[:8]
		}
		items = append(items, pickerItem{
			id:    s.ID,
			label: sid + " " + name + " (" + s.Created + ")",
		})
	}
	m.viewport.showPicker(pickerSession, items)
	return m, nil
}
|
||||
|
||||
// handleSessionResumed swaps the UI onto a previously saved session:
// it clears the transcript, replays the stored user/assistant messages,
// and points parentMessageID at the last message so the conversation
// continues from where it left off.
func (m Model) handleSessionResumed(msg SessionResumedMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addError("Could not load session: " + msg.Err.Error())
		return m, nil
	}

	detail := msg.Detail
	m.chatSessionID = &detail.ChatSessionID
	m.viewport.clearDisplay()
	m.status.setSession(detail.ChatSessionID)

	if detail.AgentName != nil {
		m.agentName = *detail.AgentName
		m.status.setAgent(*detail.AgentName)
	}
	if detail.AgentID != nil {
		m.agentID = *detail.AgentID
	}

	// Replay messages
	// (loop variable intentionally shadows the outer `msg` parameter)
	for _, msg := range detail.Messages {
		switch msg.MessageType {
		case "user":
			m.viewport.addUserMessage(msg.Message)
		case "assistant":
			m.viewport.startAgent()
			m.viewport.appendToken(msg.Message)
			m.viewport.finishAgent()
		}
	}

	// Set parent to last message
	if len(detail.Messages) > 0 {
		lastID := detail.Messages[len(detail.Messages)-1].MessageID
		m.parentMessageID = &lastID
	}

	desc := "Untitled"
	if detail.Description != nil && *detail.Description != "" {
		desc = *detail.Description
	}
	m.viewport.addInfo("Resumed session: " + desc)
	return m, nil
}
|
||||
|
||||
func (m Model) handleFileUploaded(msg FileUploadedMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addError("Upload failed: " + msg.Err.Error())
|
||||
return m, nil
|
||||
}
|
||||
m.attachedFiles = append(m.attachedFiles, *msg.Descriptor)
|
||||
m.input.addFile(msg.FileName)
|
||||
m.viewport.addInfo("Attached: " + msg.FileName)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type inputReadyMsg struct{}
|
||||
type resetQuitMsg struct{}
|
||||
|
||||
197
cli/internal/tui/commands.go
Normal file
197
cli/internal/tui/commands.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/util"
|
||||
)
|
||||
|
||||
// handleSlashCommand dispatches slash commands and returns updated model + cmd.
|
||||
func handleSlashCommand(m Model, text string) (Model, tea.Cmd) {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
command := strings.ToLower(parts[0])
|
||||
arg := ""
|
||||
if len(parts) > 1 {
|
||||
arg = parts[1]
|
||||
}
|
||||
|
||||
switch command {
|
||||
case "/help":
|
||||
m.viewport.addInfo(helpText)
|
||||
return m, nil
|
||||
|
||||
case "/new":
|
||||
return cmdNew(m)
|
||||
|
||||
case "/agent":
|
||||
if arg != "" {
|
||||
return cmdSelectAgent(m, arg)
|
||||
}
|
||||
return cmdShowAgents(m)
|
||||
|
||||
case "/attach":
|
||||
return cmdAttach(m, arg)
|
||||
|
||||
case "/sessions", "/resume":
|
||||
if strings.TrimSpace(arg) != "" {
|
||||
return cmdResume(m, arg)
|
||||
}
|
||||
return cmdSessions(m)
|
||||
|
||||
case "/configure":
|
||||
m.viewport.addInfo("Run 'onyx-cli configure' to change connection settings.")
|
||||
return m, nil
|
||||
|
||||
case "/clear":
|
||||
m.viewport.clearDisplay()
|
||||
return m, nil
|
||||
|
||||
case "/connectors":
|
||||
url := m.config.ServerURL + "/admin/indexing/status"
|
||||
util.OpenBrowser(url)
|
||||
m.viewport.addInfo("Opened " + url + " in browser")
|
||||
return m, nil
|
||||
|
||||
case "/settings":
|
||||
url := m.config.ServerURL + "/app/settings/general"
|
||||
util.OpenBrowser(url)
|
||||
m.viewport.addInfo("Opened " + url + " in browser")
|
||||
return m, nil
|
||||
|
||||
case "/quit":
|
||||
return m, tea.Quit
|
||||
|
||||
default:
|
||||
m.viewport.addWarning(fmt.Sprintf("Unknown command: %s. Type /help for available commands.", command))
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
func cmdNew(m Model) (Model, tea.Cmd) {
|
||||
m.chatSessionID = nil
|
||||
parentID := -1
|
||||
m.parentMessageID = &parentID
|
||||
m.needsRename = false
|
||||
m.citations = nil
|
||||
m.viewport.clearAll()
|
||||
// Re-add splash as a scrollable entry
|
||||
viewportHeight := m.height - 4
|
||||
if viewportHeight < 1 {
|
||||
viewportHeight = m.height
|
||||
}
|
||||
m.viewport.addSplash(viewportHeight)
|
||||
m.status.setSession("")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func cmdShowAgents(m Model) (Model, tea.Cmd) {
|
||||
m.viewport.addInfo("Loading agents...")
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
agents, err := client.ListAgents()
|
||||
return AgentsLoadedMsg{Agents: agents, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdSelectAgent(m Model, idStr string) (Model, tea.Cmd) {
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(idStr))
|
||||
if err != nil {
|
||||
m.viewport.addWarning("Invalid agent ID. Use a number.")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var target *models.AgentSummary
|
||||
for i := range m.agents {
|
||||
if m.agents[i].ID == pid {
|
||||
target = &m.agents[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
m.viewport.addWarning(fmt.Sprintf("Agent %d not found. Use /agent to see available agents.", pid))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.agentID = target.ID
|
||||
m.agentName = target.Name
|
||||
m.status.setAgent(target.Name)
|
||||
m.viewport.addInfo("Switched to agent: " + target.Name)
|
||||
|
||||
// Save preference
|
||||
m.config.DefaultAgentID = target.ID
|
||||
_ = config.Save(m.config)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func cmdAttach(m Model, pathStr string) (Model, tea.Cmd) {
|
||||
if pathStr == "" {
|
||||
m.viewport.addWarning("Usage: /attach <file_path>")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.viewport.addInfo("Uploading " + pathStr + "...")
|
||||
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
fd, err := client.UploadFile(pathStr)
|
||||
if err != nil {
|
||||
return FileUploadedMsg{Err: err, FileName: pathStr}
|
||||
}
|
||||
return FileUploadedMsg{Descriptor: fd, FileName: pathStr}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdSessions(m Model) (Model, tea.Cmd) {
|
||||
m.viewport.addInfo("Loading sessions...")
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
sessions, err := client.ListChatSessions()
|
||||
return SessionsLoadedMsg{Sessions: sessions, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdResume(m Model, sessionIDStr string) (Model, tea.Cmd) {
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
// Try to find session by prefix match
|
||||
sessions, err := client.ListChatSessions()
|
||||
if err != nil {
|
||||
return SessionResumedMsg{Err: err}
|
||||
}
|
||||
|
||||
var targetID string
|
||||
for _, s := range sessions {
|
||||
if strings.HasPrefix(s.ID, sessionIDStr) {
|
||||
targetID = s.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetID == "" {
|
||||
// Try as full UUID
|
||||
targetID = sessionIDStr
|
||||
}
|
||||
|
||||
detail, err := client.GetChatSession(targetID)
|
||||
if err != nil {
|
||||
return SessionResumedMsg{Err: fmt.Errorf("session not found: %s", sessionIDStr)}
|
||||
}
|
||||
return SessionResumedMsg{Detail: detail}
|
||||
}
|
||||
}
|
||||
|
||||
// loadAgentsCmd returns a tea.Cmd that loads agents from the API.
|
||||
func loadAgentsCmd(client *api.Client) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
agents, err := client.ListAgents()
|
||||
return InitDoneMsg{Agents: agents, Err: err}
|
||||
}
|
||||
}
|
||||
24
cli/internal/tui/help.go
Normal file
24
cli/internal/tui/help.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package tui
|
||||
|
||||
const helpText = `Onyx CLI Commands
|
||||
|
||||
/help Show this help message
|
||||
/new Start a new chat session
|
||||
/agent List and switch agents
|
||||
/attach <path> Attach a file to next message
|
||||
/sessions Browse and resume previous sessions
|
||||
/clear Clear the chat display
|
||||
/configure Re-run connection setup
|
||||
/connectors Open connectors page in browser
|
||||
/settings Open Onyx settings in browser
|
||||
/quit Exit Onyx CLI
|
||||
|
||||
Keyboard Shortcuts
|
||||
|
||||
Enter Send message
|
||||
Escape Cancel current generation
|
||||
Ctrl+O Toggle source citations
|
||||
Ctrl+D Quit (press twice)
|
||||
Scroll Up/Down Mouse wheel or Shift+Up/Down
|
||||
Page Up/Down Scroll half page
|
||||
`
|
||||
237
cli/internal/tui/input.go
Normal file
237
cli/internal/tui/input.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// slashCommand defines a slash command with its description.
|
||||
type slashCommand struct {
|
||||
command string
|
||||
description string
|
||||
}
|
||||
|
||||
var slashCommands = []slashCommand{
|
||||
{"/help", "Show help message"},
|
||||
{"/new", "Start a new chat session"},
|
||||
{"/agent", "List and switch agents"},
|
||||
{"/attach", "Attach a file to next message"},
|
||||
{"/sessions", "Browse and resume previous sessions"},
|
||||
{"/clear", "Clear the chat display"},
|
||||
{"/configure", "Re-run connection setup"},
|
||||
{"/connectors", "Open connectors in browser"},
|
||||
{"/settings", "Open settings in browser"},
|
||||
{"/quit", "Exit Onyx CLI"},
|
||||
}
|
||||
|
||||
// Commands that take arguments (filled in with trailing space on Tab/Enter).
|
||||
var argCommands = map[string]bool{
|
||||
"/attach": true,
|
||||
}
|
||||
|
||||
// inputModel manages the text input and slash command menu.
|
||||
type inputModel struct {
|
||||
textInput textinput.Model
|
||||
menuVisible bool
|
||||
menuItems []slashCommand
|
||||
menuIndex int
|
||||
attachedFiles []string
|
||||
}
|
||||
|
||||
func newInputModel() inputModel {
|
||||
ti := textinput.New()
|
||||
ti.Prompt = "" // We render our own prompt in viewInput()
|
||||
ti.Placeholder = "Send a message…"
|
||||
ti.CharLimit = 10000
|
||||
// Don't focus here — focus after first WindowSizeMsg to avoid
|
||||
// capturing terminal init escape sequences as input.
|
||||
|
||||
return inputModel{
|
||||
textInput: ti,
|
||||
}
|
||||
}
|
||||
|
||||
func (m inputModel) update(msg tea.Msg) (inputModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
return m.handleKey(msg)
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.textInput, cmd = m.textInput.Update(msg)
|
||||
m = m.updateMenu()
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m inputModel) handleKey(msg tea.KeyMsg) (inputModel, tea.Cmd) {
|
||||
switch msg.Type {
|
||||
case tea.KeyUp:
|
||||
if m.menuVisible && m.menuIndex > 0 {
|
||||
m.menuIndex--
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyDown:
|
||||
if m.menuVisible && m.menuIndex < len(m.menuItems)-1 {
|
||||
m.menuIndex++
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyTab:
|
||||
if m.menuVisible && len(m.menuItems) > 0 {
|
||||
cmd := m.menuItems[m.menuIndex].command
|
||||
if argCommands[cmd] {
|
||||
m.textInput.SetValue(cmd + " ")
|
||||
m.textInput.SetCursor(len(cmd) + 1)
|
||||
} else {
|
||||
m.textInput.SetValue(cmd)
|
||||
m.textInput.SetCursor(len(cmd))
|
||||
}
|
||||
m.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyEnter:
|
||||
if m.menuVisible && len(m.menuItems) > 0 {
|
||||
cmd := m.menuItems[m.menuIndex].command
|
||||
if argCommands[cmd] {
|
||||
m.textInput.SetValue(cmd + " ")
|
||||
m.textInput.SetCursor(len(cmd) + 1)
|
||||
m.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
// Execute immediately
|
||||
m.textInput.SetValue("")
|
||||
m.menuVisible = false
|
||||
return m, func() tea.Msg { return submitMsg{text: cmd} }
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(m.textInput.Value())
|
||||
if text == "" {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Check for file path (drag-and-drop)
|
||||
if dropped := detectFileDrop(text); dropped != "" {
|
||||
m.textInput.SetValue("")
|
||||
return m, func() tea.Msg { return fileDropMsg{path: dropped} }
|
||||
}
|
||||
|
||||
m.textInput.SetValue("")
|
||||
m.menuVisible = false
|
||||
return m, func() tea.Msg { return submitMsg{text: text} }
|
||||
|
||||
case tea.KeyEscape:
|
||||
if m.menuVisible {
|
||||
m.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.textInput, cmd = m.textInput.Update(msg)
|
||||
m = m.updateMenu()
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m inputModel) updateMenu() inputModel {
|
||||
val := strings.TrimSpace(m.textInput.Value())
|
||||
if strings.HasPrefix(val, "/") && !strings.Contains(val, " ") {
|
||||
needle := strings.ToLower(val)
|
||||
var filtered []slashCommand
|
||||
for _, sc := range slashCommands {
|
||||
if strings.HasPrefix(sc.command, needle) {
|
||||
filtered = append(filtered, sc)
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
m.menuVisible = true
|
||||
m.menuItems = filtered
|
||||
if m.menuIndex >= len(filtered) {
|
||||
m.menuIndex = 0
|
||||
}
|
||||
} else {
|
||||
m.menuVisible = false
|
||||
}
|
||||
} else {
|
||||
m.menuVisible = false
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *inputModel) addFile(name string) {
|
||||
m.attachedFiles = append(m.attachedFiles, name)
|
||||
}
|
||||
|
||||
func (m *inputModel) clearFiles() {
|
||||
m.attachedFiles = nil
|
||||
}
|
||||
|
||||
// submitMsg is sent when user submits text.
|
||||
type submitMsg struct {
|
||||
text string
|
||||
}
|
||||
|
||||
// fileDropMsg is sent when a file path is detected.
|
||||
type fileDropMsg struct {
|
||||
path string
|
||||
}
|
||||
|
||||
// detectFileDrop checks if the text looks like a file path.
|
||||
func detectFileDrop(text string) string {
|
||||
cleaned := strings.Trim(text, "'\"")
|
||||
if cleaned == "" {
|
||||
return ""
|
||||
}
|
||||
// Expand ~ to home dir
|
||||
if strings.HasPrefix(cleaned, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
cleaned = filepath.Join(home, cleaned[1:])
|
||||
}
|
||||
}
|
||||
abs, err := filepath.Abs(cleaned)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if info.IsDir() {
|
||||
return ""
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
// viewMenu renders the slash command menu.
|
||||
func (m inputModel) viewMenu(width int) string {
|
||||
if !m.menuVisible || len(m.menuItems) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
for i, item := range m.menuItems {
|
||||
prefix := " "
|
||||
if i == m.menuIndex {
|
||||
prefix = "> "
|
||||
}
|
||||
line := prefix + item.command + " " + statusMsgStyle.Render(item.description)
|
||||
lines = append(lines, line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// viewInput renders the input line with prompt and optional file badges.
|
||||
func (m inputModel) viewInput() string {
|
||||
var parts []string
|
||||
|
||||
if len(m.attachedFiles) > 0 {
|
||||
badges := strings.Join(m.attachedFiles, "] [")
|
||||
parts = append(parts, statusMsgStyle.Render("Attached: ["+badges+"]"))
|
||||
}
|
||||
|
||||
parts = append(parts, inputPrompt+m.textInput.View())
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
46
cli/internal/tui/messages.go
Normal file
46
cli/internal/tui/messages.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// InitDoneMsg signals that async initialization is complete.
|
||||
type InitDoneMsg struct {
|
||||
Agents []models.AgentSummary
|
||||
Err error
|
||||
}
|
||||
|
||||
// SessionsLoadedMsg carries loaded chat sessions.
|
||||
type SessionsLoadedMsg struct {
|
||||
Sessions []models.ChatSessionDetails
|
||||
Err error
|
||||
}
|
||||
|
||||
// SessionResumedMsg carries a loaded session detail.
|
||||
type SessionResumedMsg struct {
|
||||
Detail *models.ChatSessionDetailResponse
|
||||
Err error
|
||||
}
|
||||
|
||||
// FileUploadedMsg carries an uploaded file descriptor.
|
||||
type FileUploadedMsg struct {
|
||||
Descriptor *models.FileDescriptorPayload
|
||||
FileName string
|
||||
Err error
|
||||
}
|
||||
|
||||
// AgentsLoadedMsg carries freshly fetched agents from the API.
|
||||
type AgentsLoadedMsg struct {
|
||||
Agents []models.AgentSummary
|
||||
Err error
|
||||
}
|
||||
|
||||
// InfoMsg is a simple informational message for display.
|
||||
type InfoMsg struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
// ErrorMsg wraps an error for display.
|
||||
type ErrorMsg struct {
|
||||
Err error
|
||||
}
|
||||
79
cli/internal/tui/splash.go
Normal file
79
cli/internal/tui/splash.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const onyxLogo = ` ██████╗ ███╗ ██╗██╗ ██╗██╗ ██╗
|
||||
██╔═══██╗████╗ ██║╚██╗ ██╔╝╚██╗██╔╝
|
||||
██║ ██║██╔██╗ ██║ ╚████╔╝ ╚███╔╝
|
||||
██║ ██║██║╚██╗██║ ╚██╔╝ ██╔██╗
|
||||
╚██████╔╝██║ ╚████║ ██║ ██╔╝ ██╗
|
||||
╚═════╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝`
|
||||
|
||||
const tagline = "Your terminal interface for Onyx"
|
||||
const splashHint = "Type a message to begin · /help for commands"
|
||||
|
||||
// renderSplash renders the splash screen centered for the given dimensions.
|
||||
func renderSplash(width, height int) string {
|
||||
// Render the logo as a single block (don't center individual lines)
|
||||
logo := splashStyle.Render(onyxLogo)
|
||||
|
||||
// Center tagline and hint relative to the logo block width
|
||||
logoWidth := lipgloss.Width(logo)
|
||||
tag := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
|
||||
taglineStyle.Render(tagline),
|
||||
)
|
||||
hint := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
|
||||
hintStyle.Render(splashHint),
|
||||
)
|
||||
|
||||
block := lipgloss.JoinVertical(lipgloss.Left, logo, "", tag, hint)
|
||||
|
||||
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, block)
|
||||
}
|
||||
|
||||
// RenderSplashOnboarding renders splash for the terminal onboarding screen.
|
||||
func RenderSplashOnboarding(width, height int) string {
|
||||
// Render the logo as a styled block, then center it as a unit
|
||||
styledLogo := splashStyle.Render(onyxLogo)
|
||||
logoWidth := lipgloss.Width(styledLogo)
|
||||
logoLines := strings.Split(styledLogo, "\n")
|
||||
|
||||
logoHeight := len(logoLines)
|
||||
contentHeight := logoHeight + 2 // logo + blank + tagline
|
||||
topPad := (height - contentHeight) / 2
|
||||
if topPad < 1 {
|
||||
topPad = 1
|
||||
}
|
||||
|
||||
// Center the entire logo block horizontally
|
||||
blockPad := (width - logoWidth) / 2
|
||||
if blockPad < 0 {
|
||||
blockPad = 0
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < topPad; i++ {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
for _, line := range logoLines {
|
||||
b.WriteString(strings.Repeat(" ", blockPad))
|
||||
b.WriteString(line)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
b.WriteByte('\n')
|
||||
tagPad := (width - len(tagline)) / 2
|
||||
if tagPad < 0 {
|
||||
tagPad = 0
|
||||
}
|
||||
b.WriteString(strings.Repeat(" ", tagPad))
|
||||
b.WriteString(taglineStyle.Render(tagline))
|
||||
b.WriteByte('\n')
|
||||
|
||||
return b.String()
|
||||
}
|
||||
60
cli/internal/tui/statusbar.go
Normal file
60
cli/internal/tui/statusbar.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// statusBar manages the footer status display.
|
||||
type statusBar struct {
|
||||
agentName string
|
||||
serverURL string
|
||||
sessionID string
|
||||
streaming bool
|
||||
width int
|
||||
}
|
||||
|
||||
func newStatusBar() statusBar {
|
||||
return statusBar{
|
||||
agentName: "Default",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusBar) setAgent(name string) { s.agentName = name }
|
||||
func (s *statusBar) setServer(url string) { s.serverURL = url }
|
||||
func (s *statusBar) setSession(id string) {
|
||||
if len(id) > 8 {
|
||||
id = id[:8]
|
||||
}
|
||||
s.sessionID = id
|
||||
}
|
||||
func (s *statusBar) setStreaming(v bool) { s.streaming = v }
|
||||
func (s *statusBar) setWidth(w int) { s.width = w }
|
||||
|
||||
func (s statusBar) view() string {
|
||||
var leftParts []string
|
||||
if s.serverURL != "" {
|
||||
leftParts = append(leftParts, s.serverURL)
|
||||
}
|
||||
name := s.agentName
|
||||
if name == "" {
|
||||
name = "Default"
|
||||
}
|
||||
leftParts = append(leftParts, name)
|
||||
left := statusBarStyle.Render(strings.Join(leftParts, " · "))
|
||||
|
||||
right := "Ctrl+D to quit"
|
||||
if s.streaming {
|
||||
right = "Esc to cancel"
|
||||
}
|
||||
rightRendered := statusBarStyle.Render(right)
|
||||
|
||||
// Fill space between left and right
|
||||
gap := s.width - lipgloss.Width(left) - lipgloss.Width(rightRendered)
|
||||
if gap < 1 {
|
||||
gap = 1
|
||||
}
|
||||
|
||||
return left + strings.Repeat(" ", gap) + rightRendered
|
||||
}
|
||||
29
cli/internal/tui/styles.go
Normal file
29
cli/internal/tui/styles.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tui
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
var (
|
||||
// Colors
|
||||
accentColor = lipgloss.Color("#6c8ebf")
|
||||
dimColor = lipgloss.Color("#555577")
|
||||
errorColor = lipgloss.Color("#ff5555")
|
||||
splashColor = lipgloss.Color("#7C6AEF")
|
||||
separatorColor = lipgloss.Color("#333355")
|
||||
citationColor = lipgloss.Color("#666688")
|
||||
|
||||
// Styles
|
||||
userPrefixStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
agentDot = lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("◉")
|
||||
infoStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#b0b0cc"))
|
||||
dimInfoStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
statusMsgStyle = dimInfoStyle // used for slash menu descriptions, file badges
|
||||
errorStyle = lipgloss.NewStyle().Foreground(errorColor).Bold(true)
|
||||
warnStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
|
||||
citationStyle = lipgloss.NewStyle().Foreground(citationColor)
|
||||
statusBarStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
inputPrompt = lipgloss.NewStyle().Foreground(accentColor).Render("❯ ")
|
||||
|
||||
splashStyle = lipgloss.NewStyle().Foreground(splashColor).Bold(true)
|
||||
taglineStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#A0A0A0"))
|
||||
hintStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
)
|
||||
400
cli/internal/tui/viewport.go
Normal file
400
cli/internal/tui/viewport.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/glamour/styles"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// entryKind is the type of chat entry.
|
||||
type entryKind int
|
||||
|
||||
const (
|
||||
entryUser entryKind = iota
|
||||
entryAgent
|
||||
entryInfo
|
||||
entryError
|
||||
entryCitation
|
||||
)
|
||||
|
||||
// chatEntry is a single rendered entry in the chat history.
|
||||
type chatEntry struct {
|
||||
kind entryKind
|
||||
content string // raw content (for agent: the markdown source)
|
||||
rendered string // pre-rendered output
|
||||
citations []string // citation lines (for citation entries)
|
||||
}
|
||||
|
||||
// pickerKind distinguishes what the picker is selecting.
|
||||
type pickerKind int
|
||||
|
||||
const (
|
||||
pickerSession pickerKind = iota
|
||||
pickerAgent
|
||||
)
|
||||
|
||||
// pickerItem is a selectable item in the picker.
|
||||
type pickerItem struct {
|
||||
id string
|
||||
label string
|
||||
}
|
||||
|
||||
// viewport manages the chat display.
|
||||
type viewport struct {
|
||||
entries []chatEntry
|
||||
width int
|
||||
streaming bool
|
||||
streamBuf string
|
||||
showSources bool
|
||||
renderer *glamour.TermRenderer
|
||||
pickerItems []pickerItem
|
||||
pickerActive bool
|
||||
pickerIndex int
|
||||
pickerType pickerKind
|
||||
scrollOffset int // lines scrolled up from bottom (0 = pinned to bottom)
|
||||
}
|
||||
|
||||
// newMarkdownRenderer creates a Glamour renderer with zero left margin.
|
||||
func newMarkdownRenderer(width int) *glamour.TermRenderer {
|
||||
style := styles.DarkStyleConfig
|
||||
zero := uint(0)
|
||||
style.Document.Margin = &zero
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(style),
|
||||
glamour.WithWordWrap(width-4),
|
||||
)
|
||||
return r
|
||||
}
|
||||
|
||||
func newViewport(width int) *viewport {
|
||||
return &viewport{
|
||||
width: width,
|
||||
renderer: newMarkdownRenderer(width),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *viewport) addSplash(height int) {
|
||||
splash := renderSplash(v.width, height)
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryInfo,
|
||||
rendered: splash,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) setWidth(w int) {
|
||||
v.width = w
|
||||
v.renderer = newMarkdownRenderer(w)
|
||||
}
|
||||
|
||||
func (v *viewport) addUserMessage(msg string) {
|
||||
rendered := "\n" + userPrefixStyle.Render("❯ ") + msg
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryUser,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) startAgent() {
|
||||
v.streaming = true
|
||||
v.streamBuf = ""
|
||||
// Add a blank-line spacer entry before the agent message
|
||||
v.entries = append(v.entries, chatEntry{kind: entryInfo, rendered: ""})
|
||||
}
|
||||
|
||||
func (v *viewport) appendToken(token string) {
|
||||
v.streamBuf += token
|
||||
v.scrollOffset = 0 // auto-scroll to bottom on new content
|
||||
}
|
||||
|
||||
func (v *viewport) finishAgent() {
|
||||
if v.streamBuf == "" {
|
||||
v.streaming = false
|
||||
return
|
||||
}
|
||||
|
||||
// Render markdown with Glamour (zero left margin style)
|
||||
rendered := v.renderMarkdown(v.streamBuf)
|
||||
rendered = strings.TrimLeft(rendered, "\n")
|
||||
rendered = strings.TrimRight(rendered, "\n")
|
||||
lines := strings.Split(rendered, "\n")
|
||||
// Prefix first line with dot, indent continuation lines
|
||||
if len(lines) > 0 {
|
||||
lines[0] = agentDot + " " + lines[0]
|
||||
for i := 1; i < len(lines); i++ {
|
||||
lines[i] = " " + lines[i]
|
||||
}
|
||||
}
|
||||
rendered = strings.Join(lines, "\n")
|
||||
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryAgent,
|
||||
content: v.streamBuf,
|
||||
rendered: rendered,
|
||||
})
|
||||
v.streaming = false
|
||||
v.streamBuf = ""
|
||||
}
|
||||
|
||||
func (v *viewport) renderMarkdown(md string) string {
|
||||
if v.renderer == nil {
|
||||
return md
|
||||
}
|
||||
out, err := v.renderer.Render(md)
|
||||
if err != nil {
|
||||
return md
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (v *viewport) addInfo(msg string) {
|
||||
rendered := infoStyle.Render("● " + msg)
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryInfo,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) addWarning(msg string) {
|
||||
rendered := warnStyle.Render("● " + msg)
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryError,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) addError(msg string) {
|
||||
rendered := errorStyle.Render("● Error: ") + msg
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryError,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) addCitations(citations map[int]string) {
|
||||
if len(citations) == 0 {
|
||||
return
|
||||
}
|
||||
keys := make([]int, 0, len(citations))
|
||||
for k := range citations {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Ints(keys)
|
||||
var parts []string
|
||||
for _, num := range keys {
|
||||
parts = append(parts, fmt.Sprintf("[%d] %s", num, citations[num]))
|
||||
}
|
||||
text := fmt.Sprintf("Sources (%d): %s", len(citations), strings.Join(parts, " "))
|
||||
var citLines []string
|
||||
citLines = append(citLines, text)
|
||||
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryCitation,
|
||||
content: text,
|
||||
rendered: citationStyle.Render("● "+text),
|
||||
citations: citLines,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) showPicker(kind pickerKind, items []pickerItem) {
|
||||
v.pickerItems = items
|
||||
v.pickerType = kind
|
||||
v.pickerActive = true
|
||||
v.pickerIndex = 0
|
||||
}
|
||||
|
||||
func (v *viewport) scrollUp(n int) {
|
||||
v.scrollOffset += n
|
||||
// Clamped in view() since we need to know total content height
|
||||
}
|
||||
|
||||
func (v *viewport) scrollDown(n int) {
|
||||
v.scrollOffset -= n
|
||||
if v.scrollOffset < 0 {
|
||||
v.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (v *viewport) clearAll() {
|
||||
v.entries = nil
|
||||
v.streaming = false
|
||||
v.streamBuf = ""
|
||||
v.pickerItems = nil
|
||||
v.pickerActive = false
|
||||
v.scrollOffset = 0
|
||||
}
|
||||
|
||||
func (v *viewport) clearDisplay() {
|
||||
v.entries = nil
|
||||
v.scrollOffset = 0
|
||||
}
|
||||
|
||||
// pickerTitle returns a title for the current picker kind.
|
||||
func (v *viewport) pickerTitle() string {
|
||||
switch v.pickerType {
|
||||
case pickerAgent:
|
||||
return "Select Agent"
|
||||
case pickerSession:
|
||||
return "Resume Session"
|
||||
default:
|
||||
return "Select"
|
||||
}
|
||||
}
|
||||
|
||||
// renderPicker renders the picker as a bordered overlay.
|
||||
func (v *viewport) renderPicker(width, height int) string {
|
||||
title := v.pickerTitle()
|
||||
|
||||
// Determine picker dimensions
|
||||
maxItems := len(v.pickerItems)
|
||||
panelWidth := width - 4
|
||||
if panelWidth < 30 {
|
||||
panelWidth = 30
|
||||
}
|
||||
if panelWidth > 70 {
|
||||
panelWidth = 70
|
||||
}
|
||||
innerWidth := panelWidth - 4 // border + padding
|
||||
|
||||
// Visible window of items (scroll if too many)
|
||||
maxVisible := height - 6 // room for border, title, hint
|
||||
if maxVisible < 3 {
|
||||
maxVisible = 3
|
||||
}
|
||||
if maxVisible > maxItems {
|
||||
maxVisible = maxItems
|
||||
}
|
||||
|
||||
// Calculate scroll window around current index
|
||||
startIdx := 0
|
||||
if v.pickerIndex >= maxVisible {
|
||||
startIdx = v.pickerIndex - maxVisible + 1
|
||||
}
|
||||
endIdx := startIdx + maxVisible
|
||||
if endIdx > maxItems {
|
||||
endIdx = maxItems
|
||||
startIdx = endIdx - maxVisible
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
}
|
||||
|
||||
var itemLines []string
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
item := v.pickerItems[i]
|
||||
label := item.label
|
||||
if len(label) > innerWidth-4 {
|
||||
label = label[:innerWidth-7] + "..."
|
||||
}
|
||||
if i == v.pickerIndex {
|
||||
line := lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("> " + label)
|
||||
itemLines = append(itemLines, line)
|
||||
} else {
|
||||
itemLines = append(itemLines, " "+label)
|
||||
}
|
||||
}
|
||||
|
||||
hint := lipgloss.NewStyle().Foreground(dimColor).Render("↑↓ navigate • enter select • esc cancel")
|
||||
|
||||
body := strings.Join(itemLines, "\n") + "\n\n" + hint
|
||||
|
||||
panel := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accentColor).
|
||||
Padding(1, 2).
|
||||
Width(panelWidth).
|
||||
Render(body)
|
||||
|
||||
titleRendered := lipgloss.NewStyle().
|
||||
Foreground(accentColor).
|
||||
Bold(true).
|
||||
Render(" " + title + " ")
|
||||
|
||||
// Place title on top border
|
||||
panelLines := strings.Split(panel, "\n")
|
||||
if len(panelLines) > 0 {
|
||||
border := panelLines[0]
|
||||
runes := []rune(border)
|
||||
if len(runes) > 4 {
|
||||
// Insert title after the 2nd rune of the border
|
||||
titleRunes := []rune(titleRendered)
|
||||
panelLines[0] = string(runes[:2]) + string(titleRunes) + string(runes[2:])
|
||||
}
|
||||
}
|
||||
panel = strings.Join(panelLines, "\n")
|
||||
|
||||
// Center the panel in the viewport
|
||||
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, panel)
|
||||
}
|
||||
|
||||
// view renders the full viewport content.
|
||||
func (v *viewport) view(height int) string {
|
||||
// If picker is active, render it as an overlay
|
||||
if v.pickerActive && len(v.pickerItems) > 0 {
|
||||
return v.renderPicker(v.width, height)
|
||||
}
|
||||
|
||||
var lines []string
|
||||
|
||||
for _, e := range v.entries {
|
||||
if e.kind == entryCitation && !v.showSources {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, e.rendered)
|
||||
}
|
||||
|
||||
// Streaming buffer (plain text, not markdown)
|
||||
if v.streaming && v.streamBuf != "" {
|
||||
bufLines := strings.Split(v.streamBuf, "\n")
|
||||
if len(bufLines) > 0 {
|
||||
bufLines[0] = agentDot + " " + bufLines[0]
|
||||
for i := 1; i < len(bufLines); i++ {
|
||||
bufLines[i] = " " + bufLines[i]
|
||||
}
|
||||
}
|
||||
lines = append(lines, strings.Join(bufLines, "\n"))
|
||||
} else if v.streaming {
|
||||
lines = append(lines, agentDot+" ")
|
||||
}
|
||||
|
||||
content := strings.Join(lines, "\n")
|
||||
contentLines := strings.Split(content, "\n")
|
||||
total := len(contentLines)
|
||||
|
||||
// Clamp scroll offset
|
||||
maxScroll := total - height
|
||||
if maxScroll < 0 {
|
||||
maxScroll = 0
|
||||
}
|
||||
if v.scrollOffset > maxScroll {
|
||||
v.scrollOffset = maxScroll
|
||||
}
|
||||
|
||||
if total <= height {
|
||||
// Content fits — pad with empty lines at top to push content down
|
||||
padding := make([]string, height-total)
|
||||
for i := range padding {
|
||||
padding[i] = ""
|
||||
}
|
||||
contentLines = append(padding, contentLines...)
|
||||
} else {
|
||||
// Show a window: end is (total - scrollOffset), start is (end - height)
|
||||
end := total - v.scrollOffset
|
||||
start := end - height
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
contentLines = contentLines[start:end]
|
||||
}
|
||||
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
264
cli/internal/tui/viewport_test.go
Normal file
264
cli/internal/tui/viewport_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stripANSI removes ANSI escape sequences for test comparisons.
|
||||
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
func stripANSI(s string) string {
|
||||
return ansiRegex.ReplaceAllString(s, "")
|
||||
}
|
||||
|
||||
func TestAddUserMessage(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addUserMessage("hello world")
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryUser {
|
||||
t.Errorf("expected entryUser, got %d", e.kind)
|
||||
}
|
||||
if e.content != "hello world" {
|
||||
t.Errorf("expected content 'hello world', got %q", e.content)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "❯") {
|
||||
t.Errorf("expected rendered to contain ❯, got %q", plain)
|
||||
}
|
||||
if !strings.Contains(plain, "hello world") {
|
||||
t.Errorf("expected rendered to contain message text, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAndFinishAgent(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
|
||||
if !v.streaming {
|
||||
t.Error("expected streaming to be true after startAgent")
|
||||
}
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 spacer entry, got %d", len(v.entries))
|
||||
}
|
||||
if v.entries[0].rendered != "" {
|
||||
t.Errorf("expected empty spacer, got %q", v.entries[0].rendered)
|
||||
}
|
||||
|
||||
v.appendToken("Hello ")
|
||||
v.appendToken("world")
|
||||
|
||||
if v.streamBuf != "Hello world" {
|
||||
t.Errorf("expected streamBuf 'Hello world', got %q", v.streamBuf)
|
||||
}
|
||||
|
||||
v.finishAgent()
|
||||
|
||||
if v.streaming {
|
||||
t.Error("expected streaming to be false after finishAgent")
|
||||
}
|
||||
if v.streamBuf != "" {
|
||||
t.Errorf("expected empty streamBuf after finish, got %q", v.streamBuf)
|
||||
}
|
||||
if len(v.entries) != 2 {
|
||||
t.Fatalf("expected 2 entries (spacer + agent), got %d", len(v.entries))
|
||||
}
|
||||
|
||||
e := v.entries[1]
|
||||
if e.kind != entryAgent {
|
||||
t.Errorf("expected entryAgent, got %d", e.kind)
|
||||
}
|
||||
if e.content != "Hello world" {
|
||||
t.Errorf("expected content 'Hello world', got %q", e.content)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "Hello world") {
|
||||
t.Errorf("expected rendered to contain message text, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishAgentNoPadding(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
v.appendToken("Test message")
|
||||
v.finishAgent()
|
||||
|
||||
e := v.entries[1]
|
||||
// First line should not start with plain spaces (ANSI codes are OK)
|
||||
plain := stripANSI(e.rendered)
|
||||
lines := strings.Split(plain, "\n")
|
||||
if strings.HasPrefix(lines[0], " ") {
|
||||
t.Errorf("first line should not start with spaces, got %q", lines[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishAgentMultiline(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
v.appendToken("Line one\n\nLine three")
|
||||
v.finishAgent()
|
||||
|
||||
e := v.entries[1]
|
||||
plain := stripANSI(e.rendered)
|
||||
// Glamour may merge or reformat lines; just check content is present
|
||||
if !strings.Contains(plain, "Line one") {
|
||||
t.Errorf("expected 'Line one' in rendered, got %q", plain)
|
||||
}
|
||||
if !strings.Contains(plain, "Line three") {
|
||||
t.Errorf("expected 'Line three' in rendered, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishAgentEmpty(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
v.finishAgent()
|
||||
|
||||
if v.streaming {
|
||||
t.Error("expected streaming to be false")
|
||||
}
|
||||
if len(v.entries) != 1 {
|
||||
t.Errorf("expected 1 entry (spacer only), got %d", len(v.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddInfo(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addInfo("test info")
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryInfo {
|
||||
t.Errorf("expected entryInfo, got %d", e.kind)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if strings.HasPrefix(plain, " ") {
|
||||
t.Errorf("info should not have leading spaces, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddError(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addError("something broke")
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryError {
|
||||
t.Errorf("expected entryError, got %d", e.kind)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "something broke") {
|
||||
t.Errorf("expected error message in rendered, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCitations(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addCitations(map[int]string{1: "doc-a", 2: "doc-b"})
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryCitation {
|
||||
t.Errorf("expected entryCitation, got %d", e.kind)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "Sources (2)") {
|
||||
t.Errorf("expected sources count in rendered, got %q", plain)
|
||||
}
|
||||
if strings.HasPrefix(plain, " ") {
|
||||
t.Errorf("citation should not have leading spaces, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCitationsEmpty(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addCitations(map[int]string{})
|
||||
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected no entries for empty citations, got %d", len(v.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationVisibility(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addInfo("hello")
|
||||
v.addCitations(map[int]string{1: "doc"})
|
||||
|
||||
v.showSources = false
|
||||
view := v.view(20)
|
||||
plain := stripANSI(view)
|
||||
if strings.Contains(plain, "Sources") {
|
||||
t.Error("expected citations hidden when showSources=false")
|
||||
}
|
||||
|
||||
v.showSources = true
|
||||
view = v.view(20)
|
||||
plain = stripANSI(view)
|
||||
if !strings.Contains(plain, "Sources") {
|
||||
t.Error("expected citations visible when showSources=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearAll(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addUserMessage("test")
|
||||
v.startAgent()
|
||||
v.appendToken("response")
|
||||
|
||||
v.clearAll()
|
||||
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected no entries after clearAll, got %d", len(v.entries))
|
||||
}
|
||||
if v.streaming {
|
||||
t.Error("expected streaming=false after clearAll")
|
||||
}
|
||||
if v.streamBuf != "" {
|
||||
t.Errorf("expected empty streamBuf after clearAll, got %q", v.streamBuf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearDisplay(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addUserMessage("test")
|
||||
v.clearDisplay()
|
||||
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected no entries after clearDisplay, got %d", len(v.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewPadsShortContent(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addInfo("hello")
|
||||
|
||||
view := v.view(10)
|
||||
lines := strings.Split(view, "\n")
|
||||
if len(lines) != 10 {
|
||||
t.Errorf("expected 10 lines (padded), got %d", len(lines))
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewTruncatesTallContent(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
for i := 0; i < 20; i++ {
|
||||
v.addInfo("line")
|
||||
}
|
||||
|
||||
view := v.view(5)
|
||||
lines := strings.Split(view, "\n")
|
||||
if len(lines) != 5 {
|
||||
t.Errorf("expected 5 lines (truncated), got %d", len(lines))
|
||||
}
|
||||
}
|
||||
26
cli/internal/util/browser.go
Normal file
26
cli/internal/util/browser.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Package util provides shared utility functions.
|
||||
package util
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// OpenBrowser opens the given URL in the user's default browser.
|
||||
func OpenBrowser(url string) {
|
||||
var cmd *exec.Cmd
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "linux":
|
||||
cmd = exec.Command("xdg-open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
}
|
||||
if cmd != nil {
|
||||
if err := cmd.Start(); err == nil {
|
||||
// Reap the child process to avoid zombies.
|
||||
go cmd.Wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
7
cli/main.go
Normal file
7
cli/main.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package main
|
||||
|
||||
import "github.com/onyx-dot-app/onyx/cli/cmd"
|
||||
|
||||
func main() {
|
||||
cmd.Execute()
|
||||
}
|
||||
@@ -163,3 +163,16 @@ Add clear comments:
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username
|
||||
of the owner of that TODO, or an issue number for an issue referencing that
|
||||
piece of work.
|
||||
- Avoid module-level logic that runs on import, which leads to import-time side
|
||||
effects. Essentially every piece of meaningful logic should exist within some
|
||||
function that has to be explicitly invoked. Acceptable exceptions to this may
|
||||
include loading environment variables or setting up loggers.
|
||||
- If you find yourself needing something like this, you may want that logic to
|
||||
exist in a file dedicated for manual execution (contains `if __name__ ==
|
||||
"__main__":`) which should not be imported by anything else.
|
||||
- Related to the above, do not conflate Python scripts you intend to run from
|
||||
the command line (contains `if __name__ == "__main__":`) with modules you
|
||||
intend to import from elsewhere. If for some unlikely reason they have to be
|
||||
the same file, any logic specific to executing the file (including imports)
|
||||
should be contained in the `if __name__ == "__main__":` block.
|
||||
- Generally these executable files exist in `backend/scripts/`.
|
||||
|
||||
@@ -468,7 +468,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -293,7 +293,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -298,7 +298,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -335,7 +335,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -232,7 +232,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
@@ -520,7 +520,7 @@ services:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
test: ["CMD", "mc", "ready", "local"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@@ -534,9 +534,10 @@ services:
|
||||
required: false
|
||||
|
||||
# Below is needed for the `docker-out-of-docker` execution mode
|
||||
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
|
||||
user: root
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
|
||||
|
||||
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
|
||||
# privileged: true
|
||||
|
||||
@@ -10,7 +10,7 @@ requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aioboto3==15.1.0",
|
||||
"cohere==5.6.1",
|
||||
"fastapi==0.128.0",
|
||||
"fastapi==0.133.1",
|
||||
"google-cloud-aiplatform==1.121.0",
|
||||
"google-genai==1.52.0",
|
||||
"litellm==1.81.6",
|
||||
@@ -92,7 +92,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.6.2",
|
||||
"pypdf==6.7.3",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
|
||||
@@ -51,6 +51,7 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewRunCICommand())
|
||||
cmd.AddCommand(NewScreenshotDiffCommand())
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
159
tools/ods/cmd/whois.go
Normal file
159
tools/ods/cmd/whois.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/kube"
|
||||
)
|
||||
|
||||
var safeIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_\-]+$`)
|
||||
|
||||
// NewWhoisCommand creates the whois command for looking up users/tenants.
|
||||
func NewWhoisCommand() *cobra.Command {
|
||||
var ctx string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "whois <email-fragment or tenant-id>",
|
||||
Short: "Look up users and admins by email or tenant ID",
|
||||
Long: `Look up tenant and user information from the data plane PostgreSQL database.
|
||||
|
||||
Requires: AWS SSO login, kubectl access to the EKS cluster.
|
||||
|
||||
Two modes (auto-detected):
|
||||
|
||||
Email fragment:
|
||||
ods whois chris
|
||||
→ Searches user_tenant_mapping for emails matching '%chris%'
|
||||
|
||||
Tenant ID:
|
||||
ods whois tenant_abcd1234-...
|
||||
→ Lists all admin emails in that tenant
|
||||
|
||||
Cluster connection is configured via KUBE_CTX_* environment variables.
|
||||
Each variable is a space-separated tuple: "cluster region namespace"
|
||||
|
||||
export KUBE_CTX_DATA_PLANE="<cluster> <region> <namespace>"
|
||||
export KUBE_CTX_CONTROL_PLANE="<cluster> <region> <namespace>"
|
||||
etc...
|
||||
|
||||
Use -c to select which context (default: data_plane).`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runWhois(args[0], ctx)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&ctx, "context", "c", "data_plane", "cluster context name (maps to KUBE_CTX_<NAME> env var)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func clusterFromEnv(name string) *kube.Cluster {
|
||||
envKey := "KUBE_CTX_" + strings.ToUpper(name)
|
||||
val := os.Getenv(envKey)
|
||||
if val == "" {
|
||||
log.Fatalf("Environment variable %s is not set.\n\nSet it as a space-separated tuple:\n export %s=\"<cluster> <region> <namespace>\"", envKey, envKey)
|
||||
}
|
||||
|
||||
parts := strings.Fields(val)
|
||||
if len(parts) != 3 {
|
||||
log.Fatalf("%s must be a space-separated tuple of 3 values (cluster region namespace), got: %q", envKey, val)
|
||||
}
|
||||
|
||||
return &kube.Cluster{Name: parts[0], Region: parts[1], Namespace: parts[2]}
|
||||
}
|
||||
|
||||
// queryPod runs a SQL query via pginto on the given pod and returns cleaned output lines.
|
||||
func queryPod(c *kube.Cluster, pod, sql string) []string {
|
||||
raw, err := c.ExecOnPod(pod, "pginto", "-A", "-t", "-F", "\t", "-c", sql)
|
||||
if err != nil {
|
||||
log.Fatalf("Query failed: %v", err)
|
||||
}
|
||||
|
||||
var lines []string
|
||||
for _, line := range strings.Split(strings.TrimSpace(raw), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && !strings.HasPrefix(line, "Connecting to ") {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func runWhois(query string, ctx string) {
|
||||
c := clusterFromEnv(ctx)
|
||||
|
||||
if err := c.EnsureContext(); err != nil {
|
||||
log.Fatalf("Failed to ensure cluster context: %v", err)
|
||||
}
|
||||
|
||||
log.Info("Finding api-server pod...")
|
||||
pod, err := c.FindPod("api-server")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find api-server pod: %v", err)
|
||||
}
|
||||
log.Debugf("Using pod: %s", pod)
|
||||
|
||||
if strings.HasPrefix(query, "tenant_") {
|
||||
findAdminsByTenant(c, pod, query)
|
||||
} else {
|
||||
findByEmail(c, pod, query)
|
||||
}
|
||||
}
|
||||
|
||||
func findByEmail(c *kube.Cluster, pod, fragment string) {
|
||||
fragment = strings.NewReplacer("'", "", `"`, "", `;`, "", `\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(fragment)
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT email, tenant_id, active FROM public.user_tenant_mapping WHERE email LIKE '%%%s%%' ORDER BY email;`,
|
||||
fragment,
|
||||
)
|
||||
|
||||
log.Infof("Searching for emails matching '%%%s%%'...", fragment)
|
||||
lines := queryPod(c, pod, sql)
|
||||
if len(lines) == 0 {
|
||||
fmt.Println("No results found.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "EMAIL\tTENANT ID\tACTIVE")
|
||||
_, _ = fmt.Fprintln(w, "-----\t---------\t------")
|
||||
for _, line := range lines {
|
||||
_, _ = fmt.Fprintln(w, line)
|
||||
}
|
||||
_ = w.Flush()
|
||||
}
|
||||
|
||||
func findAdminsByTenant(c *kube.Cluster, pod, tenantID string) {
|
||||
if !safeIdentifier.MatchString(tenantID) {
|
||||
log.Fatalf("Invalid tenant ID: %q (must be alphanumeric, hyphens, underscores only)", tenantID)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT email FROM "%s"."user" WHERE role = 'ADMIN' AND is_active = true AND email NOT LIKE 'api_key__%%' ORDER BY email;`,
|
||||
tenantID,
|
||||
)
|
||||
|
||||
log.Infof("Fetching admin emails for %s...", tenantID)
|
||||
lines := queryPod(c, pod, sql)
|
||||
if len(lines) == 0 {
|
||||
fmt.Println("No admin users found for this tenant.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("EMAIL")
|
||||
fmt.Println("-----")
|
||||
for _, line := range lines {
|
||||
fmt.Println(line)
|
||||
}
|
||||
}
|
||||
90
tools/ods/internal/kube/kube.go
Normal file
90
tools/ods/internal/kube/kube.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package kube
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Cluster holds the connection info for a Kubernetes cluster.
|
||||
type Cluster struct {
|
||||
Name string
|
||||
Region string
|
||||
Namespace string
|
||||
}
|
||||
|
||||
// EnsureContext makes sure the cluster exists in kubeconfig, calling
|
||||
// aws eks update-kubeconfig only if the context is missing.
|
||||
func (c *Cluster) EnsureContext() error {
|
||||
// Check if context already exists in kubeconfig
|
||||
cmd := exec.Command("kubectl", "config", "get-contexts", c.Name, "--no-headers")
|
||||
if err := cmd.Run(); err == nil {
|
||||
log.Debugf("Context %s already exists, skipping aws eks update-kubeconfig", c.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Context %s not found, fetching kubeconfig from AWS...", c.Name)
|
||||
cmd = exec.Command("aws", "eks", "update-kubeconfig", "--region", c.Region, "--name", c.Name, "--alias", c.Name)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("aws eks update-kubeconfig failed: %w\n%s", err, string(out))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// kubectlArgs returns common kubectl flags to target this cluster without mutating global context.
|
||||
func (c *Cluster) kubectlArgs() []string {
|
||||
return []string{"--context", c.Name, "--namespace", c.Namespace}
|
||||
}
|
||||
|
||||
// FindPod returns the name of the first Running/Ready pod matching the given substring.
|
||||
func (c *Cluster) FindPod(substring string) (string, error) {
|
||||
args := append(c.kubectlArgs(), "get", "po",
|
||||
"--field-selector", "status.phase=Running",
|
||||
"--no-headers",
|
||||
"-o", "custom-columns=NAME:.metadata.name,READY:.status.conditions[?(@.type=='Ready')].status",
|
||||
)
|
||||
cmd := exec.Command("kubectl", args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("kubectl get po failed: %w\n%s", err, string(exitErr.Stderr))
|
||||
}
|
||||
return "", fmt.Errorf("kubectl get po failed: %w", err)
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
name, ready := fields[0], fields[1]
|
||||
if strings.Contains(name, substring) && ready == "True" {
|
||||
log.Debugf("Found pod: %s", name)
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no ready pod found matching %q", substring)
|
||||
}
|
||||
|
||||
// ExecOnPod runs a command on a pod and returns its stdout.
|
||||
func (c *Cluster) ExecOnPod(pod string, command ...string) (string, error) {
|
||||
args := append(c.kubectlArgs(), "exec", pod, "--")
|
||||
args = append(args, command...)
|
||||
log.Debugf("Running: kubectl %s", strings.Join(args, " "))
|
||||
|
||||
cmd := exec.Command("kubectl", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("kubectl exec failed: %w\n%s", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
}
|
||||
23
uv.lock
generated
23
uv.lock
generated
@@ -1688,17 +1688,18 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
version = "0.133.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "typing-inspection" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/6f/0eafed8349eea1fa462238b54a624c8b408cd1ba2795c8e64aa6c34f8ab7/fastapi-0.133.1.tar.gz", hash = "sha256:ed152a45912f102592976fde6cbce7dae1a8a1053da94202e51dd35d184fadd6", size = 378741, upload-time = "2026-02-25T18:18:17.398Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/c9/a175a7779f3599dfa4adfc97a6ce0e157237b3d7941538604aadaf97bfb6/fastapi-0.133.1-py3-none-any.whl", hash = "sha256:658f34ba334605b1617a65adf2ea6461901bdb9af3a3080d63ff791ecf7dc2e2", size = 109029, upload-time = "2026-02-25T18:18:18.578Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4612,7 +4613,7 @@ requires-dist = [
|
||||
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
|
||||
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
|
||||
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
|
||||
{ name = "fastapi", specifier = "==0.128.0" },
|
||||
{ name = "fastapi", specifier = "==0.133.1" },
|
||||
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
|
||||
@@ -4677,7 +4678,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.6.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.3" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -5924,11 +5925,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.6.2"
|
||||
version = "6.7.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/53/9b/63e767042fc852384dc71e5ff6f990ee4e1b165b1526cf3f9c23a4eebb47/pypdf-6.7.3.tar.gz", hash = "sha256:eca55c78d0ec7baa06f9288e2be5c4e8242d5cbb62c7a4b94f2716f8e50076d2", size = 5303304, upload-time = "2026-02-24T17:23:11.42Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/90/3308a9b8b46c1424181fdf3f4580d2b423c5471425799e7fc62f92d183f4/pypdf-6.7.3-py3-none-any.whl", hash = "sha256:cd25ac508f20b554a9fafd825186e3ba29591a69b78c156783c5d8a2d63a1c0a", size = 331263, upload-time = "2026-02-24T17:23:09.932Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8079,14 +8080,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.1.5"
|
||||
version = "3.1.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markupsafe" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/61/f1/ee81806690a87dab5f5653c1f146c92bc066d7f4cebc603ef88eb9e13957/werkzeug-3.1.6.tar.gz", hash = "sha256:210c6bede5a420a913956b4791a7f4d6843a43b6fcee4dfa08a65e93007d0d25", size = 864736, upload-time = "2026-02-19T15:17:18.884Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user