mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-03-05 23:55:47 +00:00

Compare commits: worktree-o...docx_varia (4 commits)

- fe7c02a3a9
- ac9f5a5f1d
- 5f6b348864
- 47bb69c3ea
@@ -1,161 +0,0 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---

# Onyx CLI — Agent Tool

Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` tool provides non-interactive commands to query the Onyx knowledge base and list available agents.

## Prerequisites

### 1. Check if installed

```bash
which onyx-cli
```

### 2. Install (if needed)

**Primary — pip:**

```bash
pip install onyx-cli
```

**From source (Go):**

```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```

### 3. Check if configured

```bash
onyx-cli validate-config
```

This checks that the config file exists and an API key is present, then tests the server connection via `/api/me`. Exit code 0 on success; non-zero with a descriptive error on failure.

If unconfigured, you have two options:

**Option A — Interactive setup (requires user input):**

```bash
onyx-cli configure
```

This prompts for the Onyx server URL and API key, tests the connection, and saves the config.

**Option B — Environment variables (non-interactive, preferred for agents):**

```bash
export ONYX_SERVER_URL="https://your-onyx-server.com"  # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```

Environment variables override the config file. If these are set, no config file is needed.

| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |

If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:

- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
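
The non-interactive readiness check can be sketched in a few lines. The variable names and default URL come from this document; the helper function itself is illustrative, not part of `onyx-cli`:

```python
import os

# Sketch of the agent-side readiness check described above.
# ONYX_SERVER_URL has a default, so only ONYX_API_KEY is strictly required.
def onyx_env_ready() -> bool:
    server_url = os.environ.get("ONYX_SERVER_URL", "https://cloud.onyx.app")  # optional
    api_key = os.environ.get("ONYX_API_KEY")  # required
    return bool(server_url and api_key)

if not onyx_env_ready():
    print("onyx-cli is unconfigured: set ONYX_API_KEY or run 'onyx-cli configure'")
```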

## Commands

### Validate configuration

```bash
onyx-cli validate-config
```

Checks that the config file exists and an API key is present, then tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.

### List available agents

```bash
onyx-cli agents
```

Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:

```bash
onyx-cli agents --json
```

Use agent IDs with `ask --agent-id` to query a specific agent.

### Basic query (plain text output)

```bash
onyx-cli ask "What is our company's PTO policy?"
```

Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.

### JSON output (structured events)

```bash
onyx-cli ask --json "What authentication methods do we support?"
```

Outputs parsed stream events as JSON, one object per line (NDJSON). Key event objects include message deltas, stop, errors, search-start, and citation payloads.

| Event Type | Description |
|------------|-------------|
| `MessageDeltaEvent` | Content token — concatenate all `content` fields for the full answer |
| `StopEvent` | Stream complete |
| `ErrorEvent` | Error with `error` message field |
| `SearchStartEvent` | Onyx started searching documents |
| `CitationEvent` | Source citation with `citation_number` and `document_id` |
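
Assembling the final answer from `--json` output means filtering for `MessageDeltaEvent` objects and concatenating their `content` fields. A minimal sketch — the sample events and the exact `type`/`content` field names are assumptions; only the event names come from the table above:

```python
import json

# Hypothetical sample of `onyx-cli ask --json` output (NDJSON, one event per line).
sample = """\
{"type": "SearchStartEvent"}
{"type": "MessageDeltaEvent", "content": "Our PTO policy "}
{"type": "MessageDeltaEvent", "content": "allows 20 days per year."}
{"type": "CitationEvent", "citation_number": 1, "document_id": "doc-42"}
{"type": "StopEvent"}
"""

answer = ""
for line in sample.splitlines():
    event = json.loads(line)
    # Only message deltas carry answer text; other events are metadata.
    if event.get("type") == "MessageDeltaEvent":
        answer += event.get("content", "")

print(answer)
```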

### Specify an agent

```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```

Uses a specific Onyx agent/persona instead of the default.

### All flags

| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |

## When to Use

Use `onyx-cli ask` when:

- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources

Do NOT use when:

- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data

## Examples

```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"

# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"

# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"

# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

.github/workflows/pr-integration-tests.yml (vendored, 1 changed line)
@@ -335,6 +335,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
.github/workflows/pr-playwright-tests.yml (vendored, 108 changed lines)
@@ -268,11 +268,10 @@ jobs:
          persist-credentials: false

      - name: Setup node
        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
        with:
          node-version: 22
          cache: "npm" # zizmor: ignore[cache-poisoning]
          cache: "npm"
          cache-dependency-path: ./web/package-lock.json

      - name: Install node dependencies
@@ -280,7 +279,6 @@ jobs:
        run: npm ci

      - name: Cache playwright cache
        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        with:
          path: ~/.cache/ms-playwright
@@ -592,108 +590,6 @@ jobs:
          name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
          path: ${{ github.workspace }}/docker-compose.log

  playwright-tests-lite:
    needs: [build-web-image, build-backend-image]
    name: Playwright Tests (lite)
    runs-on:
      - runs-on
      - runner=4cpu-linux-arm64
      - "run-id=${{ github.run_id }}-playwright-tests-lite"
      - "extras=ecr-cache"
    timeout-minutes: 30
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup node
        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
        with:
          node-version: 22
          cache: "npm" # zizmor: ignore[cache-poisoning]
          cache-dependency-path: ./web/package-lock.json

      - name: Install node dependencies
        working-directory: ./web
        run: npm ci

      - name: Cache playwright cache
        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        with:
          path: ~/.cache/ms-playwright
          key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
          restore-keys: |
            ${{ runner.os }}-playwright-npm-

      - name: Install playwright browsers
        working-directory: ./web
        run: npx playwright install --with-deps

      - name: Create .env file for Docker Compose
        env:
          OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
          RUN_ID: ${{ github.run_id }}
        run: |
          cat <<EOF > deployment/docker_compose/.env
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
          LICENSE_ENFORCEMENT_ENABLED=false
          AUTH_TYPE=basic
          INTEGRATION_TESTS_MODE=true
          GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
          MOCK_LLM_RESPONSE=true
          REQUIRE_EMAIL_VERIFICATION=false
          DISABLE_TELEMETRY=true
          ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
          ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
          EOF

      # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
      # https://docs.docker.com/docker-hub/usage/
      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Start Docker containers (lite)
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
        id: start_docker

      - name: Run Playwright tests (lite)
        working-directory: ./web
        run: npx playwright test --project lite

      - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
        if: always()
        with:
          name: playwright-test-results-lite-${{ github.run_id }}
          path: ./web/output/playwright/
          retention-days: 30

      - name: Save Docker logs
        if: success() || failure()
        env:
          WORKSPACE: ${{ github.workspace }}
        run: |
          cd deployment/docker_compose
          docker compose logs > docker-compose.log
          mv docker-compose.log ${WORKSPACE}/docker-compose.log

      - name: Upload logs
        if: success() || failure()
        uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
        with:
          name: docker-logs-lite-${{ github.run_id }}
          path: ${{ github.workspace }}/docker-compose.log

  # Post a single combined visual regression comment after all matrix jobs finish
  visual-regression-comment:
    needs: [playwright-tests]
@@ -790,7 +686,7 @@ jobs:
    # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
    runs-on: ubuntu-slim
    timeout-minutes: 45
    needs: [playwright-tests, playwright-tests-lite]
    needs: [playwright-tests]
    if: ${{ always() }}
    steps:
      - name: Check job status

@@ -119,11 +119,10 @@ repos:
        ]

  - repo: https://github.com/golangci/golangci-lint
    rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
    rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
    hooks:
      - id: golangci-lint
        language_version: "1.26.0"
        entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
        entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
.vscode/launch.json (vendored, 58 changed lines)
@@ -40,7 +40,19 @@
        }
      },
      {
        "name": "Celery",
        "name": "Celery (lightweight mode)",
        "configurations": [
          "Celery primary",
          "Celery background",
          "Celery beat"
        ],
        "presentation": {
          "group": "1"
        },
        "stopAll": true
      },
      {
        "name": "Celery (standard mode)",
        "configurations": [
          "Celery primary",
          "Celery light",
@@ -241,6 +253,35 @@
      },
      "consoleTitle": "Celery light Console"
    },
    {
      "name": "Celery background",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_LEVEL": "INFO",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "onyx.background.celery.versioned_apps.background",
        "worker",
        "--pool=threads",
        "--concurrency=20",
        "--prefetch-multiplier=4",
        "--loglevel=INFO",
        "--hostname=background@%n",
        "-Q",
        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
      ],
      "presentation": {
        "group": "2"
      },
      "consoleTitle": "Celery background Console"
    },
    {
      "name": "Celery heavy",
      "type": "debugpy",
@@ -485,6 +526,21 @@
        "group": "3"
      }
    },
    {
      "name": "Clear and Restart OpenSearch Container",
      // Generic debugger type, required arg but has no bearing on bash.
      "type": "node",
      "request": "launch",
      "runtimeExecutable": "bash",
      "runtimeArgs": [
        "${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
      ],
      "cwd": "${workspaceFolder}",
      "console": "integratedTerminal",
      "presentation": {
        "group": "3"
      }
    },
    {
      "name": "Eval CLI",
      "type": "debugpy",
AGENTS.md (31 changed lines)
@@ -86,6 +86,37 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)

#### Worker Deployment Modes

Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:

**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):

- Runs a single consolidated `background` worker that handles all background tasks:
  - Light worker tasks (Vespa operations, permissions sync, deletion)
  - Document processing (indexing pipeline)
  - Document fetching (connector data retrieval)
  - Pruning operations (from `heavy` worker)
  - Knowledge graph processing (from `kg_processing` worker)
  - Monitoring tasks (from `monitoring` worker)
  - User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)

**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):

- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load

The deployment mode affects:

- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
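
The mode switch can be pictured as a small selection function. The worker names come from this section; the helper itself is a hypothetical sketch, not the actual supervisord/Helm logic:

```python
import os

# Standard-mode workers listed in this section; the selection helper is a
# hypothetical sketch of the USE_LIGHTWEIGHT_BACKGROUND_WORKER switch.
STANDARD_WORKERS = [
    "light", "docprocessing", "docfetching", "heavy",
    "kg_processing", "monitoring", "user_file_processing",
]

def workers_to_spawn() -> list[str]:
    lightweight = (
        os.environ.get("USE_LIGHTWEIGHT_BACKGROUND_WORKER", "true").lower() == "true"
    )
    if lightweight:
        # One consolidated worker handles all queues.
        return ["background"]
    return STANDARD_WORKERS
```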

#### Key Features

- **Thread-based Workers**: All workers use thread pools (not processes) for stability
backend/ee/onyx/background/celery/apps/background.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app


celery_app.autodiscover_tasks(
    app_base.filter_task_modules(
        [
            "ee.onyx.background.celery.tasks.doc_permission_syncing",
            "ee.onyx.background.celery.tasks.external_group_syncing",
            "ee.onyx.background.celery.tasks.cleanup",
            "ee.onyx.background.celery.tasks.tenant_provisioning",
            "ee.onyx.background.celery.tasks.query_history",
        ]
    )
)
@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -472,9 +471,7 @@ def _add_user_group__cc_pair_relationships__no_commit(

def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
    db_user_group = UserGroup(
        name=user_group.name,
        time_last_modified_by_user=func.now(),
        is_up_to_date=DISABLE_VECTOR_DB,
        name=user_group.name, time_last_modified_by_user=func.now()
    )
    db_session.add(db_user_group)
    db_session.flush()  # give the group an ID
@@ -777,7 +774,8 @@ def update_user_group(
        cc_pair_ids=user_group_update.cc_pair_ids,
    )

    if cc_pairs_updated and not DISABLE_VECTOR_DB:
    # only needs to sync with Vespa if the cc_pairs have been updated
    if cc_pairs_updated:
        db_user_group.is_up_to_date = False

    removed_users = db_session.scalars(
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -152,9 +153,12 @@ def get_application() -> FastAPI:
    # License management
    include_router_with_global_prefix_prepended(application, license_router)

    # Unified billing API - always registered in EE.
    # Each endpoint is protected by the `current_admin_user` dependency (admin auth).
    include_router_with_global_prefix_prepended(application, billing_router)
    # Unified billing API - available when license system is enabled
    # Works for both self-hosted and cloud deployments
    # TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
    # primary billing API and /tenants/* billing endpoints can be removed
    if LICENSE_ENFORCEMENT_ENABLED:
        include_router_with_global_prefix_prepended(application, billing_router)

    if MULTI_TENANT:
        # Tenant management
@@ -223,15 +223,6 @@ def get_active_scim_token(
    token = dal.get_active_token()
    if not token:
        raise HTTPException(status_code=404, detail="No active SCIM token")

    # Derive the IdP domain from the first synced user as a heuristic.
    idp_domain: str | None = None
    mappings, _total = dal.list_user_mappings(start_index=1, count=1)
    if mappings:
        user = dal.get_user(mappings[0].user_id)
        if user and "@" in user.email:
            idp_domain = user.email.rsplit("@", 1)[1]

    return ScimTokenResponse(
        id=token.id,
        name=token.name,
@@ -239,7 +230,6 @@ def get_active_scim_token(
        is_active=token.is_active,
        created_at=token.created_at,
        last_used_at=token.last_used_at,
        idp_domain=idp_domain,
    )
@@ -365,7 +365,6 @@ class ScimTokenResponse(BaseModel):
    is_active: bool
    created_at: datetime
    last_used_at: datetime | None = None
    idp_domain: str | None = None


class ScimTokenCreatedResponse(ScimTokenResponse):
@@ -5,8 +5,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -22,7 +20,6 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -156,8 +153,3 @@ def delete_user_group(
        prepare_user_group_for_deletion(db_session, user_group_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

    if DISABLE_VECTOR_DB:
        user_group = fetch_user_group(db_session, user_group_id)
        if user_group:
            db_delete_user_group(db_session, user_group)
backend/onyx/background/celery/apps/background.py (new file, 142 lines)
@@ -0,0 +1,142 @@
from typing import Any
from typing import cast

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    EXTRA_CONCURRENCY = 8  # small extra fudge factor for connection limits

    logger.info("worker_init signal received for consolidated background worker.")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)

    # Initialize Vespa httpx pool (needed for light worker tasks)
    if MANAGED_VESPA:
        httpx_init_vespa_pool(
            sender.concurrency + EXTRA_CONCURRENCY,  # type: ignore
            ssl_cert=VESPA_CLOUD_CERT_PATH,
            ssl_key=VESPA_CLOUD_KEY_PATH,
        )
    else:
        httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY)  # type: ignore

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Fewer startup checks in the multi-tenant case
    if MULTI_TENANT:
        return

    app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:  # noqa: ARG001
    SqlEngine.reset_engine()


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    app_base.filter_task_modules(
        [
            # Original background worker tasks
            "onyx.background.celery.tasks.pruning",
            "onyx.background.celery.tasks.monitoring",
            "onyx.background.celery.tasks.user_file_processing",
            "onyx.background.celery.tasks.llm_model_update",
            # Light worker tasks
            "onyx.background.celery.tasks.shared",
            "onyx.background.celery.tasks.vespa",
            "onyx.background.celery.tasks.connector_deletion",
            "onyx.background.celery.tasks.doc_permission_syncing",
            "onyx.background.celery.tasks.opensearch_migration",
            # Docprocessing worker tasks
            "onyx.background.celery.tasks.docprocessing",
            # Docfetching worker tasks
            "onyx.background.celery.tasks.docfetching",
            # Sandbox cleanup tasks (isolated in build feature)
            "onyx.server.features.build.sandbox.tasks",
        ]
    )
)
@@ -39,13 +39,9 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SlimConnectorExtractionResult(BaseModel):
    """Result of extracting document IDs and hierarchy nodes from a connector.
    """Result of extracting document IDs and hierarchy nodes from a connector."""

    raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
    Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
    """

    raw_id_to_parent: dict[str, str | None]
    doc_ids: set[str]
    hierarchy_nodes: list[HierarchyNode]
@@ -97,37 +93,30 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
    return None


class BatchResult(BaseModel):
    raw_id_to_parent: dict[str, str | None]
    hierarchy_nodes: list[HierarchyNode]


def _extract_from_batch(
    doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> BatchResult:
    """Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
) -> tuple[set[str], list[HierarchyNode]]:
    """Separate a batch into document IDs and hierarchy nodes.

    ConnectorFailure items have their failed document/entity IDs added to the
    ID dict so that failed-to-retrieve documents are not accidentally pruned.
    ID set so that failed-to-retrieve documents are not accidentally pruned.
    """
    ids: dict[str, str | None] = {}
    ids: set[str] = set()
    hierarchy_nodes: list[HierarchyNode] = []
    for item in doc_list:
        if isinstance(item, HierarchyNode):
            hierarchy_nodes.append(item)
            if item.raw_node_id not in ids:
                ids[item.raw_node_id] = None
            ids.add(item.raw_node_id)
        elif isinstance(item, ConnectorFailure):
            failed_id = _get_failure_id(item)
            if failed_id:
                ids[failed_id] = None
                ids.add(failed_id)
                logger.warning(
                    f"Failed to retrieve document {failed_id}: {item.failure_message}"
                )
        else:
            parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
            ids[item.id] = parent_raw
    return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
            ids.add(item.id)
    return ids, hierarchy_nodes


def extract_ids_from_runnable_connector(
@@ -143,7 +132,7 @@ def extract_ids_from_runnable_connector(

    Optionally, a callback can be passed to handle the length of each document batch.
    """
    all_raw_id_to_parent: dict[str, str | None] = {}
    all_connector_doc_ids: set[str] = set()
    all_hierarchy_nodes: list[HierarchyNode] = []

    # Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -188,20 +177,15 @@ def extract_ids_from_runnable_connector(
                "extract_ids_from_runnable_connector: Stop signal detected"
            )

        batch_result = _extract_from_batch(doc_list)
        batch_ids = batch_result.raw_id_to_parent
        batch_nodes = batch_result.hierarchy_nodes
        doc_batch_processing_func(batch_ids)
        for k, v in batch_ids.items():
            if v is not None or k not in all_raw_id_to_parent:
                all_raw_id_to_parent[k] = v
        batch_ids, batch_nodes = _extract_from_batch(doc_list)
        all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
        all_hierarchy_nodes.extend(batch_nodes)

        if callback:
            callback.progress("extract_ids_from_runnable_connector", len(batch_ids))

    return SlimConnectorExtractionResult(
        raw_id_to_parent=all_raw_id_to_parent,
        doc_ids=all_connector_doc_ids,
        hierarchy_nodes=all_hierarchy_nodes,
    )
23
backend/onyx/background/celery/configs/background.py
Normal file
@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4
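The new config file is a plain module of settings attributes, which is the shape Celery's `config_from_object` consumes: it reads attributes directly off a module or object. A minimal stand-in for that consumption, with a `SimpleNamespace` mirroring a few names from the diff (values illustrative):

```python
import types

# Stand-in for a settings module like configs/background.py above.
config = types.SimpleNamespace(
    worker_pool="threads",
    worker_prefetch_multiplier=4,
    worker_concurrency=20,
)


def load_worker_settings(cfg: object) -> dict:
    # Read every public attribute, the way a config loader would.
    return {k: v for k, v in vars(cfg).items() if not k.startswith("_")}
```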
@@ -29,7 +29,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -48,8 +47,6 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
@@ -60,8 +57,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -118,38 +113,6 @@ class PruneCallback(IndexingCallbackBase):
        super().progress(tag, amount)

def _resolve_and_update_document_parents(
    db_session: Session,
    redis_client: Redis,
    source: DocumentSource,
    raw_id_to_parent: dict[str, str | None],
) -> None:
    """Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
    each document and bulk-update the DB. Mirrors the resolution logic in
    run_docfetching.py."""
    source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)

    resolved: dict[str, int | None] = {}
    for doc_id, raw_parent_id in raw_id_to_parent.items():
        if raw_parent_id is None:
            continue
        node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
        resolved[doc_id] = node_id if found else source_node_id

    if not resolved:
        return

    update_document_parent_hierarchy_nodes(
        db_session=db_session,
        doc_parent_map=resolved,
        commit=True,
    )
    task_logger.info(
        f"Pruning: resolved and updated parent hierarchy for "
        f"{len(resolved)} documents (source={source.value})"
    )

"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
@@ -572,22 +535,22 @@ def connector_pruning_generator_task(
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
all_connector_doc_ids = extraction_result.doc_ids
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
ensure_source_node_exists(
|
||||
redis_client, db_session, cc_pair.connector.source
|
||||
)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
source=cc_pair.connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
@@ -598,7 +561,7 @@ def connector_pruning_generator_task(
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
source=cc_pair.connector.source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
@@ -607,26 +570,6 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
db_session=db_session,
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
raw_id_to_parent=all_connector_doc_ids,
|
||||
)
|
||||
|
||||
# Link hierarchy nodes to documents for sources where pages can be
|
||||
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
|
||||
all_doc_id_list = list(all_connector_doc_ids.keys())
|
||||
link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=all_doc_id_list,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
@@ -638,9 +581,7 @@ def connector_pruning_generator_task(
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
|
||||
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
@@ -0,0 +1,10 @@
from celery import Celery

from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
    "onyx.background.celery.apps.background",
    "celery_app",
)
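The versioned-app shim above resolves an implementation from a dotted module path plus an attribute name at import time. A minimal importlib-based stand-in for that lookup pattern (using a stdlib module as the hypothetical target; this is not the repo's `fetch_versioned_implementation`, which also handles EE overrides):

```python
import importlib


def fetch_impl(module_name: str, attr: str):
    # Import the module by dotted path and pull the named attribute off it.
    return getattr(importlib.import_module(module_name), attr)
```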
@@ -36,6 +36,7 @@ from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
@@ -83,6 +84,28 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
    )


def _should_keep_bedrock_tool_definitions(
    llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
    """Bedrock requires tool config when history includes toolUse/toolResult blocks."""
    model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
    if model_provider not in {
        LlmProviderNames.BEDROCK,
        LlmProviderNames.BEDROCK_CONVERSE,
    }:
        return False

    return any(
        (
            msg.message_type == MessageType.ASSISTANT
            and msg.tool_calls
            and len(msg.tool_calls) > 0
        )
        or msg.message_type == MessageType.TOOL_CALL_RESPONSE
        for msg in simple_chat_history
    )

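The history scan inside `_should_keep_bedrock_tool_definitions` reduces to one `any()` over the messages: keep the tool definitions if any assistant message carries tool calls, or any message is a tool-call response. A simplified version of that predicate with a tiny stand-in message type (the string type names here are illustrative, not the repo's `MessageType` enum):

```python
from dataclasses import dataclass, field


@dataclass
class Msg:
    message_type: str
    tool_calls: list = field(default_factory=list)


def history_has_tool_traffic(history: list[Msg]) -> bool:
    # Mirrors the any() shape above: assistant-with-tool-calls OR tool response.
    return any(
        (m.message_type == "assistant" and len(m.tool_calls) > 0)
        or m.message_type == "tool_call_response"
        for m in history
    )
```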
def _try_fallback_tool_extraction(
    llm_step_result: LlmStepResult,
    tool_choice: ToolChoiceOptions,
@@ -663,7 +686,12 @@ def run_llm_loop(
    elif out_of_cycles or ran_image_gen:
        # Last cycle, no tools allowed, just answer!
        tool_choice = ToolChoiceOptions.NONE
        final_tools = []
        # Bedrock requires tool config in requests that include toolUse/toolResult history.
        final_tools = (
            tools
            if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
            else []
        )
    else:
        tool_choice = ToolChoiceOptions.AUTO
        final_tools = tools

@@ -495,7 +495,14 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
    os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)

# Individual worker concurrency settings
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
)

# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kubernetes deployments)
CELERY_WORKER_HEAVY_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)

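Note the pattern these settings use: `int(os.environ.get(NAME) or default)` rather than `int(os.environ.get(NAME, default))`. The `or` also covers a variable that is set but empty, which would otherwise crash on `int("")`. The pattern in isolation:

```python
import os


def int_env(name: str, default: int) -> int:
    # Empty string and unset both fall through to the default.
    return int(os.environ.get(name) or default)
```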
@@ -84,6 +84,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (

@@ -943,9 +943,6 @@ class ConfluenceConnector(
                    if include_permissions
                    else None
                ),
                parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
                    page
                ),
            )
        )

@@ -995,7 +992,6 @@ class ConfluenceConnector(
                    if include_permissions
                    else None
                ),
                parent_hierarchy_raw_node_id=page_id,
            )
        )

@@ -781,5 +781,4 @@ def build_slim_document(
    return SlimDocument(
        id=onyx_document_id_from_drive_file(file),
        external_access=external_access,
        parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
    )

@@ -902,11 +902,6 @@ class JiraConnector(
                    external_access=self._get_project_permissions(
                        project_key, add_prefix=False
                    ),
                    parent_hierarchy_raw_node_id=(
                        self._get_parent_hierarchy_raw_node_id(issue, project_key)
                        if project_key
                        else None
                    ),
                )
            )
            current_offset += 1

@@ -385,7 +385,6 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
    id: str
    external_access: ExternalAccess | None = None
    parent_hierarchy_raw_node_id: str | None = None


class HierarchyNode(BaseModel):

@@ -772,7 +772,6 @@ def _convert_driveitem_to_slim_document(
    drive_name: str,
    ctx: ClientContext,
    graph_client: GraphClient,
    parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
    if driveitem.id is None:
        raise ValueError("DriveItem ID is required")
@@ -788,15 +787,11 @@ def _convert_driveitem_to_slim_document(
    return SlimDocument(
        id=driveitem.id,
        external_access=external_access,
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )


def _convert_sitepage_to_slim_document(
    site_page: dict[str, Any],
    ctx: ClientContext | None,
    graph_client: GraphClient,
    parent_hierarchy_raw_node_id: str | None = None,
    site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
) -> SlimDocument:
    """Convert a SharePoint site page to a SlimDocument object."""
    if site_page.get("id") is None:
@@ -813,7 +808,6 @@ def _convert_sitepage_to_slim_document(
    return SlimDocument(
        id=id,
        external_access=external_access,
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )


@@ -1600,22 +1594,12 @@ class SharepointConnector(
                    )
                )

                parent_hierarchy_url: str | None = None
                if drive_web_url:
                    parent_hierarchy_url = self._get_parent_hierarchy_url(
                        site_url, drive_web_url, drive_name, driveitem
                    )

                try:
                    logger.debug(f"Processing: {driveitem.web_url}")
                    ctx = self._create_rest_client_context(site_descriptor.url)
                    doc_batch.append(
                        _convert_driveitem_to_slim_document(
                            driveitem,
                            drive_name,
                            ctx,
                            self.graph_client,
                            parent_hierarchy_raw_node_id=parent_hierarchy_url,
                            driveitem, drive_name, ctx, self.graph_client
                        )
                    )
                except Exception as e:
@@ -1635,10 +1619,7 @@ class SharepointConnector(
                ctx = self._create_rest_client_context(site_descriptor.url)
                doc_batch.append(
                    _convert_sitepage_to_slim_document(
                        site_page,
                        ctx,
                        self.graph_client,
                        parent_hierarchy_raw_node_id=site_descriptor.url,
                        site_page, ctx, self.graph_client
                    )
                )
                if len(doc_batch) >= SLIM_BATCH_SIZE:

@@ -565,7 +565,6 @@ def _get_all_doc_ids(
                        channel_id=channel_id, thread_ts=message["ts"]
                    ),
                    external_access=external_access,
                    parent_hierarchy_raw_node_id=channel_id,
                )
            )

@@ -13,7 +13,6 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
@@ -247,7 +246,6 @@ def insert_document_set(
        description=document_set_creation_request.description,
        user_id=user_id,
        is_public=document_set_creation_request.is_public,
        is_up_to_date=DISABLE_VECTOR_DB,
        time_last_modified_by_user=func.now(),
    )
    db_session.add(new_document_set_row)
@@ -338,8 +336,7 @@ def update_document_set(
        )

        document_set_row.description = document_set_update_request.description
        if not DISABLE_VECTOR_DB:
            document_set_row.is_up_to_date = False
        document_set_row.is_up_to_date = False
        document_set_row.is_public = document_set_update_request.is_public
        document_set_row.time_last_modified_by_user = func.now()
        versioned_private_doc_set_fn = fetch_versioned_implementation(

@@ -1,7 +1,5 @@
"""CRUD operations for HierarchyNode."""

from collections import defaultdict

from sqlalchemy import select
from sqlalchemy.orm import Session

@@ -527,53 +525,6 @@ def get_document_parent_hierarchy_node_ids(
    return {doc_id: parent_id for doc_id, parent_id in results}


def update_document_parent_hierarchy_nodes(
    db_session: Session,
    doc_parent_map: dict[str, int | None],
    commit: bool = True,
) -> int:
    """Bulk-update Document.parent_hierarchy_node_id for multiple documents.

    Only updates rows whose current value differs from the desired value to
    avoid unnecessary writes.

    Args:
        db_session: SQLAlchemy session
        doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
        commit: Whether to commit the transaction

    Returns:
        Number of documents actually updated
    """
    if not doc_parent_map:
        return 0

    doc_ids = list(doc_parent_map.keys())
    existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)

    by_parent: dict[int | None, list[str]] = defaultdict(list)
    for doc_id, desired_parent_id in doc_parent_map.items():
        current = existing.get(doc_id)
        if current == desired_parent_id or doc_id not in existing:
            continue
        by_parent[desired_parent_id].append(doc_id)

    updated = 0
    for desired_parent_id, ids in by_parent.items():
        db_session.query(Document).filter(Document.id.in_(ids)).update(
            {Document.parent_hierarchy_node_id: desired_parent_id},
            synchronize_session=False,
        )
        updated += len(ids)

    if commit:
        db_session.commit()
    elif updated:
        db_session.flush()

    return updated


def update_hierarchy_node_permissions(
    db_session: Session,
    raw_node_id: str,

@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
    latest_settings = result.scalars().first()

    if not latest_settings:
        raise RuntimeError("No search settings specified; DB is not in a valid state.")
        raise RuntimeError("No search settings specified, DB is not in a valid state")
    return latest_settings

@@ -32,6 +32,9 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
    Determines whether to enable multipass and large chunks by examining
    the current search settings and the embedder configuration.
    """
    if not search_settings:
        return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)

    multipass = should_use_multipass(search_settings)
    enable_large_chunks = SearchSettings.can_use_large_chunks(
        multipass, search_settings.model_name, search_settings.provider_type

@@ -26,10 +26,11 @@ def get_default_document_index(
    To be used for retrieval only. Indexing should be done through both indices
    until Vespa is deprecated.

    Pre-existing docstring for this function, although secondary indices are not
    currently supported:
    Primary index is the index that is used for querying/updating etc. Secondary
    index is for when both the currently used index and the upcoming index both
    need to be updated. Updates are applied to both indices.
    WARNING: In that case, get_all_document_indices should be used.
    need to be updated, updates are applied to both indices.
    """
    if DISABLE_VECTOR_DB:
        return DisabledDocumentIndex(
@@ -50,26 +51,11 @@ def get_default_document_index(
    opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
    if opensearch_retrieval_enabled:
        indexing_setting = IndexingSetting.from_db_model(search_settings)
        secondary_indexing_setting = (
            IndexingSetting.from_db_model(secondary_search_settings)
            if secondary_search_settings
            else None
        )
        return OpenSearchOldDocumentIndex(
            index_name=search_settings.index_name,
            embedding_dim=indexing_setting.final_embedding_dim,
            embedding_precision=indexing_setting.embedding_precision,
            secondary_index_name=secondary_index_name,
            secondary_embedding_dim=(
                secondary_indexing_setting.final_embedding_dim
                if secondary_indexing_setting
                else None
            ),
            secondary_embedding_precision=(
                secondary_indexing_setting.embedding_precision
                if secondary_indexing_setting
                else None
            ),
            large_chunks_enabled=search_settings.large_chunks_enabled,
            secondary_large_chunks_enabled=secondary_large_chunks_enabled,
            multitenant=MULTI_TENANT,
@@ -100,7 +86,8 @@ def get_all_document_indices(
    Used for indexing only. Until Vespa is deprecated we will index into both
    document indices. Retrieval is done through only one index however.

    Large chunks are not currently supported so we hardcode appropriate values.
    Large chunks and secondary indices are not currently supported so we
    hardcode appropriate values.

    NOTE: Make sure the Vespa index object is returned first. In the rare event
    that there is some conflict between indexing and the migration task, it is
@@ -136,36 +123,13 @@ def get_all_document_indices(
    opensearch_document_index: OpenSearchOldDocumentIndex | None = None
    if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
        indexing_setting = IndexingSetting.from_db_model(search_settings)
        secondary_indexing_setting = (
            IndexingSetting.from_db_model(secondary_search_settings)
            if secondary_search_settings
            else None
        )
        opensearch_document_index = OpenSearchOldDocumentIndex(
            index_name=search_settings.index_name,
            embedding_dim=indexing_setting.final_embedding_dim,
            embedding_precision=indexing_setting.embedding_precision,
            secondary_index_name=(
                secondary_search_settings.index_name
                if secondary_search_settings
                else None
            ),
            secondary_embedding_dim=(
                secondary_indexing_setting.final_embedding_dim
                if secondary_indexing_setting
                else None
            ),
            secondary_embedding_precision=(
                secondary_indexing_setting.embedding_precision
                if secondary_indexing_setting
                else None
            ),
            large_chunks_enabled=search_settings.large_chunks_enabled,
            secondary_large_chunks_enabled=(
                secondary_search_settings.large_chunks_enabled
                if secondary_search_settings
                else None
            ),
            secondary_index_name=None,
            large_chunks_enabled=False,
            secondary_large_chunks_enabled=None,
            multitenant=MULTI_TENANT,
            httpx_client=httpx_client,
        )

@@ -271,9 +271,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        embedding_dim: int,
        embedding_precision: EmbeddingPrecision,
        secondary_index_name: str | None,
        secondary_embedding_dim: int | None,
        secondary_embedding_precision: EmbeddingPrecision | None,
        # NOTE: We do not support large chunks right now.
        large_chunks_enabled: bool,  # noqa: ARG002
        secondary_large_chunks_enabled: bool | None,  # noqa: ARG002
        multitenant: bool = False,
@@ -289,25 +286,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
                f"Expected {MULTI_TENANT}, got {multitenant}."
            )
        tenant_id = get_current_tenant_id()
        tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
        self._real_index = OpenSearchDocumentIndex(
            tenant_state=tenant_state,
            tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
            index_name=index_name,
            embedding_dim=embedding_dim,
            embedding_precision=embedding_precision,
        )
        self._secondary_real_index: OpenSearchDocumentIndex | None = None
        if self.secondary_index_name:
            if secondary_embedding_dim is None or secondary_embedding_precision is None:
                raise ValueError(
                    "Bug: Secondary index embedding dimension and precision are not set."
                )
            self._secondary_real_index = OpenSearchDocumentIndex(
                tenant_state=tenant_state,
                index_name=self.secondary_index_name,
                embedding_dim=secondary_embedding_dim,
                embedding_precision=secondary_embedding_precision,
            )

    @staticmethod
    def register_multitenant_indices(
@@ -323,38 +307,19 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        self,
        primary_embedding_dim: int,
        primary_embedding_precision: EmbeddingPrecision,
        secondary_index_embedding_dim: int | None,
        secondary_index_embedding_precision: EmbeddingPrecision | None,
        secondary_index_embedding_dim: int | None,  # noqa: ARG002
        secondary_index_embedding_precision: EmbeddingPrecision | None,  # noqa: ARG002
    ) -> None:
        self._real_index.verify_and_create_index_if_necessary(
        # Only handle primary index for now, ignore secondary.
        return self._real_index.verify_and_create_index_if_necessary(
            primary_embedding_dim, primary_embedding_precision
        )
        if self.secondary_index_name:
            if (
                secondary_index_embedding_dim is None
                or secondary_index_embedding_precision is None
            ):
                raise ValueError(
                    "Bug: Secondary index embedding dimension and precision are not set."
                )
            assert (
                self._secondary_real_index is not None
            ), "Bug: Secondary index is not initialized."
            self._secondary_real_index.verify_and_create_index_if_necessary(
                secondary_index_embedding_dim, secondary_index_embedding_precision
            )

    def index(
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        index_batch_params: IndexBatchParams,
    ) -> set[OldDocumentInsertionRecord]:
        """
        NOTE: Do NOT consider the secondary index here. A separate indexing
        pipeline will be responsible for indexing to the secondary index. This
        design is not ideal and we should reconsider this when revamping index
        swapping.
        """
        # Convert IndexBatchParams to IndexingMetadata.
        chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
        for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -386,20 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        tenant_id: str,  # noqa: ARG002
        chunk_count: int | None,
    ) -> int:
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for deleting chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
        if self.secondary_index_name:
            assert (
                self._secondary_real_index is not None
            ), "Bug: Secondary index is not initialized."
            total_chunks_deleted += self._secondary_real_index.delete(
                doc_id, chunk_count
            )
        return total_chunks_deleted
        return self._real_index.delete(doc_id, chunk_count)

    def update_single(
        self,
@@ -410,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
        fields: VespaDocumentFields | None,
        user_fields: VespaDocumentUserFields | None,
    ) -> None:
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for updating chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        if fields is None and user_fields is None:
            logger.warning(
                f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -445,11 +392,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):

        try:
            self._real_index.update([update_request])
            if self.secondary_index_name:
                assert (
                    self._secondary_real_index is not None
                ), "Bug: Secondary index is not initialized."
                self._secondary_real_index.update([update_request])
        except NotFoundError:
            logger.exception(
                f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "

@@ -465,12 +465,6 @@ class VespaIndex(DocumentIndex):
        chunks: list[DocMetadataAwareIndexChunk],
        index_batch_params: IndexBatchParams,
    ) -> set[OldDocumentInsertionRecord]:
        """
        NOTE: Do NOT consider the secondary index here. A separate indexing
        pipeline will be responsible for indexing to the secondary index. This
        design is not ideal and we should reconsider this when revamping index
        swapping.
        """
        if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
            index_batch_params.doc_id_to_new_chunk_cnt
        ):
@@ -665,10 +659,6 @@ class VespaIndex(DocumentIndex):
        """Note: if the document id does not exist, the update will be a no-op and the
        function will complete with no errors or exceptions.
        Handle other exceptions if you wish to implement retry behavior

        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for updating chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        if fields is None and user_fields is None:
            logger.warning(
@@ -689,6 +679,13 @@ class VespaIndex(DocumentIndex):
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
            )

        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )

        project_ids: set[int] | None = None
        if user_fields is not None and user_fields.user_projects is not None:
            project_ids = set(user_fields.user_projects)
@@ -708,20 +705,7 @@ class VespaIndex(DocumentIndex):
            persona_ids=persona_ids,
        )

        indices = [self.index_name]
        if self.secondary_index_name:
            indices.append(self.secondary_index_name)

        for index_name in indices:
            vespa_document_index = VespaDocumentIndex(
                index_name=index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
                    index_name, False
                ),
                httpx_client=self.httpx_client,
            )
            vespa_document_index.update([update_request])
        vespa_document_index.update([update_request])

    def delete_single(
        self,
@@ -730,11 +714,6 @@ class VespaIndex(DocumentIndex):
        tenant_id: str,
        chunk_count: int | None,
    ) -> int:
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for deleting chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
@@ -747,25 +726,13 @@ class VespaIndex(DocumentIndex):
            raise ValueError(
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
            )
        indices = [self.index_name]
        if self.secondary_index_name:
            indices.append(self.secondary_index_name)

        total_chunks_deleted = 0
        for index_name in indices:
            vespa_document_index = VespaDocumentIndex(
                index_name=index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
                    index_name, False
                ),
                httpx_client=self.httpx_client,
            )
            total_chunks_deleted += vespa_document_index.delete(
                document_id=doc_id, chunk_count=chunk_count
            )

        return total_chunks_deleted
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)

def id_based_retrieval(
|
||||
self,
|
||||
|
||||
@@ -92,98 +92,6 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
    return [prompt.model_dump(exclude_none=True)]


def _normalize_content(raw: Any) -> str:
    """Normalize a message content field to a plain string.

    Content can be a string, None, or a list of content-block dicts
    (e.g. [{"type": "text", "text": "..."}]).
    """
    if raw is None:
        return ""
    if isinstance(raw, str):
        return raw
    if isinstance(raw, list):
        return "\n".join(
            block.get("text", "") if isinstance(block, dict) else str(block)
            for block in raw
        )
    return str(raw)


def _strip_tool_content_from_messages(
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Convert tool-related messages to plain text.

    Bedrock's Converse API requires toolConfig when messages contain
    toolUse/toolResult content blocks. When no tools are provided for the
    current request, we must convert any tool-related history into plain text
    to avoid the "toolConfig field must be defined" error.

    This is the same approach used by _OllamaHistoryMessageFormatter.
    """
    result: list[dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role")
        tool_calls = msg.get("tool_calls")

        if role == "assistant" and tool_calls:
            # Convert structured tool calls to text representation
            tool_call_lines = []
            for tc in tool_calls:
                func = tc.get("function", {})
                name = func.get("name", "unknown")
                args = func.get("arguments", "{}")
                tc_id = tc.get("id", "")
                tool_call_lines.append(
                    f"[Tool Call] name={name} id={tc_id} args={args}"
                )

            existing_content = _normalize_content(msg.get("content"))
            parts = (
                [existing_content] + tool_call_lines
                if existing_content
                else tool_call_lines
            )
            new_msg = {
                "role": "assistant",
                "content": "\n".join(parts),
            }
            result.append(new_msg)

        elif role == "tool":
            # Convert tool response to user message with text content
            tool_call_id = msg.get("tool_call_id", "")
            content = _normalize_content(msg.get("content"))
            tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
            # Merge into previous user message if it is also a converted
            # tool result to avoid consecutive user messages (Bedrock requires
            # strict user/assistant alternation).
            if (
                result
                and result[-1]["role"] == "user"
                and "[Tool Result]" in result[-1].get("content", "")
            ):
                result[-1]["content"] += "\n\n" + tool_result_text
            else:
                result.append({"role": "user", "content": tool_result_text})

        else:
            result.append(msg)

    return result


def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
    """Check if any messages contain tool-related content blocks."""
    for msg in messages:
        if msg.get("role") == "tool":
            return True
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            return True
    return False
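The hunk above removes the `_normalize_content` and `_messages_contain_tool_content` helpers. A minimal, self-contained sketch of the behavior they encode (the names below are illustrative re-implementations, not imports from the onyx modules):

```python
from typing import Any


def normalize_content(raw: Any) -> str:
    # None -> empty string; str passes through; a list of content blocks
    # (e.g. [{"type": "text", "text": "..."}]) is joined on newlines.
    if raw is None:
        return ""
    if isinstance(raw, str):
        return raw
    if isinstance(raw, list):
        return "\n".join(
            block.get("text", "") if isinstance(block, dict) else str(block)
            for block in raw
        )
    return str(raw)


def messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
    # A history has tool content if any message is a tool result or an
    # assistant turn that carries structured tool_calls.
    return any(
        msg.get("role") == "tool"
        or (msg.get("role") == "assistant" and bool(msg.get("tool_calls")))
        for msg in messages
    )
```

The detection function is what gates the Bedrock-specific stripping: only when it returns `True` and no tools are supplied does the history need to be flattened to plain text.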


def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
    normalized_model_name = model_name.lower()
    return any(
@@ -496,30 +404,13 @@ class LitellmLLM(LLM):
            else nullcontext()
        )
        with env_ctx:
            messages = _prompt_to_dicts(prompt)

            # Bedrock's Converse API requires toolConfig when messages
            # contain toolUse/toolResult content blocks. When no tools are
            # provided for this request but the history contains tool
            # content from previous turns, strip it to plain text.
            is_bedrock = self._model_provider in {
                LlmProviderNames.BEDROCK,
                LlmProviderNames.BEDROCK_CONVERSE,
            }
            if (
                is_bedrock
                and not tools
                and _messages_contain_tool_content(messages)
            ):
                messages = _strip_tool_content_from_messages(messages)

            response = litellm.completion(
                mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
                model=model,
                base_url=self._api_base or None,
                api_version=self._api_version or None,
                custom_llm_provider=self._custom_llm_provider or None,
                messages=messages,
                messages=_prompt_to_dicts(prompt),
                tools=tools,
                tool_choice=tool_choice,
                stream=stream,

@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
    message = _transform_outside_code_blocks(message, _sanitize_html)
    message = _convert_slack_links_to_markdown(message)
    normalized_message = _normalize_link_destinations(message)
    md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
    md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
    result = md(normalized_message)
    # With HTMLRenderer, result is always str (not AST list)
    assert isinstance(result, str)
@@ -146,11 +146,6 @@ class SlackRenderer(HTMLRenderer):

    SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}

    def __init__(self) -> None:
        super().__init__()
        self._table_headers: list[str] = []
        self._current_row_cells: list[str] = []

    def escape_special(self, text: str) -> str:
        for special, replacement in self.SPECIALS.items():
            text = text.replace(special, replacement)
@@ -223,48 +218,5 @@ class SlackRenderer(HTMLRenderer):
        # as literal &quot; text since Slack doesn't recognize that entity.
        return self.escape_special(text)

    # -- Table rendering (converts markdown tables to vertical cards) --

    def table_cell(
        self, text: str, align: str | None = None, head: bool = False  # noqa: ARG002
    ) -> str:
        if head:
            self._table_headers.append(text.strip())
        else:
            self._current_row_cells.append(text.strip())
        return ""

    def table_head(self, text: str) -> str:  # noqa: ARG002
        self._current_row_cells = []
        return ""

    def table_row(self, text: str) -> str:  # noqa: ARG002
        cells = self._current_row_cells
        self._current_row_cells = []
        # First column becomes the bold title, remaining columns are bulleted fields
        lines: list[str] = []
        if cells:
            title = cells[0]
            if title:
                # Avoid double-wrapping if cell already contains bold markup
                if title.startswith("*") and title.endswith("*") and len(title) > 1:
                    lines.append(title)
                else:
                    lines.append(f"*{title}*")
            for i, cell in enumerate(cells[1:], start=1):
                if i < len(self._table_headers):
                    lines.append(f"  • {self._table_headers[i]}: {cell}")
                else:
                    lines.append(f"  • {cell}")
        return "\n".join(lines) + "\n\n"

    def table_body(self, text: str) -> str:
        return text

    def table(self, text: str) -> str:
        self._table_headers = []
        self._current_row_cells = []
        return text + "\n"

    def paragraph(self, text: str) -> str:
        return f"{text}\n\n"

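The `table_row` logic in the renderer above can be sketched as a standalone function. `table_row_to_card` and its explicit `headers`/`cells` parameters are illustrative stand-ins for the renderer's internal `_table_headers`/`_current_row_cells` state:

```python
def table_row_to_card(headers: list[str], cells: list[str]) -> str:
    # First column becomes a bold title line; remaining cells become
    # bulleted "Header: value" fields, falling back to the bare value
    # when no header exists for that column index.
    lines: list[str] = []
    if cells:
        title = cells[0]
        if title:
            # Avoid double-wrapping when the cell is already bold markup.
            if title.startswith("*") and title.endswith("*") and len(title) > 1:
                lines.append(title)
            else:
                lines.append(f"*{title}*")
        for i, cell in enumerate(cells[1:], start=1):
            if i < len(headers):
                lines.append(f"  • {headers[i]}: {cell}")
            else:
                lines.append(f"  • {cell}")
    return "\n".join(lines)
```

This "vertical card" shape exists because Slack's mrkdwn has no native table rendering, so each row is flattened into a title plus bulleted fields.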
@@ -7424,9 +7424,9 @@
      }
    },
    "node_modules/hono": {
      "version": "4.12.5",
      "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
      "integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
      "version": "4.11.7",
      "resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
      "integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
      "license": "MIT",
      "engines": {
        "node": ">=16.9.0"

@@ -11,7 +11,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import delete_document_set as db_delete_document_set
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import insert_document_set
@@ -143,10 +142,7 @@ def delete_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    if DISABLE_VECTOR_DB:
        db_session.refresh(document_set)
        db_delete_document_set(document_set, db_session)
    else:
    if not DISABLE_VECTOR_DB:
        client_app.send_task(
            OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
            kwargs={"tenant_id": tenant_id},

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
@@ -10,8 +11,6 @@ from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.models import User
from onyx.db.search_settings import get_all_search_settings
from onyx.db.search_settings import get_current_db_embedding_provider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.indexing.models import EmbeddingModelDetail
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
@@ -60,7 +59,7 @@ def test_embedding_configuration(
    except Exception as e:
        error_msg = "An error occurred while testing your embedding model. Please check your configuration."
        logger.error(f"{error_msg} Error message: {e}", exc_info=True)
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
        raise HTTPException(status_code=400, detail=error_msg)


@admin_router.get("", response_model=list[EmbeddingModelDetail])
@@ -94,9 +93,8 @@ def delete_embedding_provider(
        embedding_provider is not None
        and provider_type == embedding_provider.provider_type
    ):
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "You can't delete a currently active model",
        raise HTTPException(
            status_code=400, detail="You can't delete a currently active model"
        )

    remove_embedding_provider(db_session, provider_type=provider_type)

@@ -11,6 +11,7 @@ from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import ValidationError
from sqlalchemy.orm import Session
@@ -37,8 +38,6 @@ from onyx.db.llm import upsert_llm_provider
from onyx.db.llm import validate_persona_ids_exist
from onyx.db.models import User
from onyx.db.persona import user_can_access_persona
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
@@ -187,7 +186,7 @@ def _validate_llm_provider_change(
    Only enforced in MULTI_TENANT mode.

    Raises:
        OnyxError: If api_base or custom_config changed without changing API key
        HTTPException: If api_base or custom_config changed without changing API key
    """
    if not MULTI_TENANT or api_key_changed:
        return
@@ -201,9 +200,9 @@
    )

    if api_base_changed or custom_config_changed:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "API base and/or custom config cannot be changed without changing the API key",
        raise HTTPException(
            status_code=400,
            detail="API base and/or custom config cannot be changed without changing the API key",
        )

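The recurring change in these hunks swaps `OnyxError(OnyxErrorCode.…)` raises for plain `HTTPException`s. A minimal sketch of the error-code-to-status mapping the paired lines imply; the `OnyxError`/`OnyxErrorCode` shapes and the status table here are assumptions inferred from the diff (which itself maps some codes, e.g. `BAD_GATEWAY`, to more than one status), not the real onyx classes:

```python
from enum import Enum


class OnyxErrorCode(str, Enum):
    # Codes mirrored from the diff; membership is illustrative only.
    VALIDATION_ERROR = "VALIDATION_ERROR"
    NOT_FOUND = "NOT_FOUND"
    DUPLICATE_RESOURCE = "DUPLICATE_RESOURCE"
    BAD_GATEWAY = "BAD_GATEWAY"
    INTERNAL_ERROR = "INTERNAL_ERROR"


class OnyxError(Exception):
    def __init__(self, code: OnyxErrorCode, message: str) -> None:
        super().__init__(message)
        self.code = code
        self.message = message


# Hypothetical mapping, chosen to match the HTTPException counterparts
# visible in the diff where they are unambiguous.
_STATUS_BY_CODE = {
    OnyxErrorCode.VALIDATION_ERROR: 400,
    OnyxErrorCode.NOT_FOUND: 404,
    OnyxErrorCode.DUPLICATE_RESOURCE: 400,
    OnyxErrorCode.BAD_GATEWAY: 502,
    OnyxErrorCode.INTERNAL_ERROR: 500,
}


def to_http_status(err: OnyxError) -> int:
    # Default to 500 for codes with no explicit mapping.
    return _STATUS_BY_CODE.get(err.code, 500)
```

A central mapping like this is the usual way to keep route handlers free of status-code literals while domain code raises typed errors.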
@@ -223,7 +222,7 @@ def fetch_llm_provider_options(
    for well_known_llm in well_known_llms:
        if well_known_llm.name == provider_name:
            return well_known_llm
    raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found")
    raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")


@admin_router.post("/test")
@@ -282,7 +281,7 @@ def test_llm_configuration(
    error_msg = test_llm(llm)

    if error_msg:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
        raise HTTPException(status_code=400, detail=error_msg)


@admin_router.post("/test/default")
@@ -293,11 +292,11 @@ def test_default_provider(
        llm = get_default_llm()
    except ValueError:
        logger.exception("Failed to fetch default LLM Provider")
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup")
        raise HTTPException(status_code=400, detail="No LLM Provider setup")

    error = test_llm(llm)
    if error:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error))
        raise HTTPException(status_code=400, detail=str(error))


@admin_router.get("/provider")
@@ -363,31 +362,35 @@ def put_llm_provider(
    # Check name constraints
    # TODO: Once port from name to id is complete, unique name will no longer be required
    if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "Renaming providers is not currently supported",
        raise HTTPException(
            status_code=400,
            detail="Renaming providers is not currently supported",
        )

    found_provider = fetch_existing_llm_provider(
        name=llm_provider_upsert_request.name, db_session=db_session
    )
    if found_provider is not None and found_provider is not existing_provider:
        raise OnyxError(
            OnyxErrorCode.DUPLICATE_RESOURCE,
            f"Provider with name={llm_provider_upsert_request.name} already exists",
        raise HTTPException(
            status_code=400,
            detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
        )

    if existing_provider and is_creation:
        raise OnyxError(
            OnyxErrorCode.DUPLICATE_RESOURCE,
            f"LLM Provider with name {llm_provider_upsert_request.name} and "
            f"id={llm_provider_upsert_request.id} already exists",
        raise HTTPException(
            status_code=400,
            detail=(
                f"LLM Provider with name {llm_provider_upsert_request.name} and "
                f"id={llm_provider_upsert_request.id} already exists"
            ),
        )
    elif not existing_provider and not is_creation:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"LLM Provider with name {llm_provider_upsert_request.name} and "
            f"id={llm_provider_upsert_request.id} does not exist",
        raise HTTPException(
            status_code=400,
            detail=(
                f"LLM Provider with name {llm_provider_upsert_request.name} and "
                f"id={llm_provider_upsert_request.id} does not exist"
            ),
        )

    # SSRF Protection: Validate api_base and custom_config match stored values
@@ -412,9 +415,9 @@
        db_session, persona_ids
    )
    if missing_personas:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
        raise HTTPException(
            status_code=400,
            detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
        )
    # Remove duplicates while preserving order
    seen: set[int] = set()
@@ -470,7 +473,7 @@
        return result
    except ValueError as e:
        logger.exception("Failed to upsert LLM Provider")
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
        raise HTTPException(status_code=400, detail=str(e))


@admin_router.delete("/provider/{provider_id}")
@@ -480,19 +483,19 @@ def delete_llm_provider(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    if not force:
        model = fetch_default_llm_model(db_session)

        if model and model.llm_provider_id == provider_id:
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                "Cannot delete the default LLM provider",
            )

    try:
        if not force:
            model = fetch_default_llm_model(db_session)

            if model and model.llm_provider_id == provider_id:
                raise HTTPException(
                    status_code=400,
                    detail="Cannot delete the default LLM provider",
                )

        remove_llm_provider(db_session, provider_id)
    except ValueError as e:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
        raise HTTPException(status_code=404, detail=str(e))


@admin_router.post("/default")
@@ -532,9 +535,9 @@ def get_auto_config(
    """
    config = fetch_llm_recommendations_from_github()
    if not config:
        raise OnyxError(
            OnyxErrorCode.BAD_GATEWAY,
            "Failed to fetch configuration from GitHub",
        raise HTTPException(
            status_code=502,
            detail="Failed to fetch configuration from GitHub",
        )
    return config.model_dump()

@@ -691,13 +694,13 @@ def list_llm_providers_for_persona(

    persona = fetch_persona_with_groups(db_session, persona_id)
    if not persona:
        raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found")
        raise HTTPException(status_code=404, detail="Persona not found")

    # Verify user has access to this persona
    if not user_can_access_persona(db_session, persona_id, user, get_editable=False):
        raise OnyxError(
            OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
            "You don't have access to this assistant",
        raise HTTPException(
            status_code=403,
            detail="You don't have access to this assistant",
        )

    is_admin = user.role == UserRole.ADMIN
@@ -851,9 +854,9 @@ def get_bedrock_available_models(
    try:
        bedrock = session.client("bedrock")
    except Exception as e:
        raise OnyxError(
            OnyxErrorCode.CREDENTIAL_INVALID,
            f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
        raise HTTPException(
            status_code=400,
            detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
        )

    # Build model info dict from foundation models (modelId -> metadata)
@@ -972,14 +975,14 @@ def get_bedrock_available_models(
        return results

    except (ClientError, NoCredentialsError, BotoCoreError) as e:
        raise OnyxError(
            OnyxErrorCode.CREDENTIAL_INVALID,
            f"Failed to connect to AWS Bedrock: {e}",
        raise HTTPException(
            status_code=400,
            detail=f"Failed to connect to AWS Bedrock: {e}",
        )
    except Exception as e:
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            f"Unexpected error fetching Bedrock models: {e}",
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error fetching Bedrock models: {e}",
        )


@@ -991,9 +994,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]:
        response.raise_for_status()
        response_json = response.json()
    except Exception as e:
        raise OnyxError(
            OnyxErrorCode.BAD_GATEWAY,
            f"Failed to fetch Ollama models: {e}",
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch Ollama models: {e}",
        )

    models = response_json.get("models", [])
@@ -1010,9 +1013,9 @@ def get_ollama_available_models(

    cleaned_api_base = request.api_base.strip().rstrip("/")
    if not cleaned_api_base:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "API base URL is required to fetch Ollama models.",
        raise HTTPException(
            status_code=400,
            detail="API base URL is required to fetch Ollama models.",
        )

    # NOTE: most people run Ollama locally, so we don't disallow internal URLs
@@ -1021,9 +1024,9 @@
    # with the same response format
    model_names = _get_ollama_available_model_names(cleaned_api_base)
    if not model_names:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No models found from your Ollama server",
        raise HTTPException(
            status_code=400,
            detail="No models found from your Ollama server",
        )

    all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
@@ -1125,9 +1128,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
        response.raise_for_status()
        return response.json()
    except Exception as e:
        raise OnyxError(
            OnyxErrorCode.BAD_GATEWAY,
            f"Failed to fetch OpenRouter models: {e}",
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch OpenRouter models: {e}",
        )


@@ -1148,9 +1151,9 @@ def get_openrouter_available_models(

    data = response_json.get("data", [])
    if not isinstance(data, list) or len(data) == 0:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No models found from your OpenRouter endpoint",
        raise HTTPException(
            status_code=400,
            detail="No models found from your OpenRouter endpoint",
        )

    results: list[OpenRouterFinalModelResponse] = []
@@ -1185,9 +1188,8 @@
    )

    if not results:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No compatible models found from OpenRouter",
        raise HTTPException(
            status_code=400, detail="No compatible models found from OpenRouter"
        )

    sorted_results = sorted(results, key=lambda m: m.name.lower())

@@ -6,11 +6,8 @@ from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
@@ -18,25 +15,20 @@ from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import delete_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.factory import get_default_document_index
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import update_unstructured_api_key
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT

router = APIRouter(prefix="/search-settings")
@@ -49,99 +41,110 @@ def set_new_search_settings(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),  # noqa: ARG001
) -> IdReturn:
    """Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
    Gives an error if the same model name is used as the current or secondary index
    """
    """
    Creates a new SearchSettings row and cancels the previous secondary indexing
    if any exists.
    """
    if search_settings_new.index_name:
        logger.warning("Index name was specified by request, this is not suggested")

    # Disallow contextual RAG for cloud deployments.
    if MULTI_TENANT and search_settings_new.enable_contextual_rag:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Contextual RAG disabled in Onyx Cloud",
        )

    # Validate cloud provider exists or create new LiteLLM provider.
    if search_settings_new.provider_type is not None:
        cloud_provider = get_embedding_provider_from_provider_type(
            db_session, provider_type=search_settings_new.provider_type
        )

        if cloud_provider is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
            )

    validate_contextual_rag_model(
        provider_name=search_settings_new.contextual_rag_llm_provider,
        model_name=search_settings_new.contextual_rag_llm_name,
        db_session=db_session,
    # TODO(andrei): Re-enable.
    # NOTE Enable integration external dependency tests in test_search_settings.py
    # when this is reenabled. They are currently skipped
    logger.error("Setting new search settings is temporarily disabled.")
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="Setting new search settings is temporarily disabled.",
    )
    # if search_settings_new.index_name:
    #     logger.warning("Index name was specified by request, this is not suggested")

    search_settings = get_current_search_settings(db_session)
    # # Disallow contextual RAG for cloud deployments
    # if MULTI_TENANT and search_settings_new.enable_contextual_rag:
    #     raise HTTPException(
    #         status_code=status.HTTP_400_BAD_REQUEST,
    #         detail="Contextual RAG disabled in Onyx Cloud",
    #     )

    if search_settings_new.index_name is None:
        # We define index name here.
        index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
        if (
            search_settings_new.model_name == search_settings.model_name
            and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
        ):
            index_name += ALT_INDEX_SUFFIX
        search_values = search_settings_new.model_dump()
        search_values["index_name"] = index_name
        new_search_settings_request = SavedSearchSettings(**search_values)
    else:
        new_search_settings_request = SavedSearchSettings(
            **search_settings_new.model_dump()
        )
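The index-naming branch above can be sketched in isolation. `clean_model_name` and the `ALT_INDEX_SUFFIX` value below are simplified stand-ins for the real helpers in `onyx.natural_language_processing.search_nlp_models` and `shared_configs.configs`:

```python
import re

# Assumed placeholder value; the real constant lives in shared_configs.configs.
ALT_INDEX_SUFFIX = "__danswer_alt_index"


def clean_model_name(model_name: str) -> str:
    # Simplified stand-in: lowercase and replace unsafe characters.
    return re.sub(r"[^a-z0-9_]", "_", model_name.lower())


def derive_index_name(new_model: str, current_model: str, current_index: str) -> str:
    # Mirrors the hunk above: base the index name on the cleaned model name,
    # and append the alternate suffix when re-indexing with the same model so
    # the new index does not collide with the currently live one.
    index_name = f"danswer_chunk_{clean_model_name(new_model)}"
    if new_model == current_model and not current_index.endswith(ALT_INDEX_SUFFIX):
        index_name += ALT_INDEX_SUFFIX
    return index_name
```

The suffix dance allows swapping between two physical indices for the same model name, which is what the surrounding index-swap machinery relies on.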
    # # Validate cloud provider exists or create new LiteLLM provider
    # if search_settings_new.provider_type is not None:
    #     cloud_provider = get_embedding_provider_from_provider_type(
    #         db_session, provider_type=search_settings_new.provider_type
    #     )

    secondary_search_settings = get_secondary_search_settings(db_session)
    # if cloud_provider is None:
    #     raise HTTPException(
    #         status_code=status.HTTP_400_BAD_REQUEST,
    #         detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
    #     )

    if secondary_search_settings:
        # Cancel any background indexing jobs.
        expire_index_attempts(
            search_settings_id=secondary_search_settings.id, db_session=db_session
        )
    # validate_contextual_rag_model(
    #     provider_name=search_settings_new.contextual_rag_llm_provider,
    #     model_name=search_settings_new.contextual_rag_llm_name,
    #     db_session=db_session,
    # )

        # Mark previous model as a past model directly.
        update_search_settings_status(
            search_settings=secondary_search_settings,
            new_status=IndexModelStatus.PAST,
            db_session=db_session,
        )
    # search_settings = get_current_search_settings(db_session)

    new_search_settings = create_search_settings(
        search_settings=new_search_settings_request, db_session=db_session
    )
    # if search_settings_new.index_name is None:
    #     # We define index name here
    #     index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
    #     if (
    #         search_settings_new.model_name == search_settings.model_name
    #         and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
    #     ):
    #         index_name += ALT_INDEX_SUFFIX
    #     search_values = search_settings_new.model_dump()
    #     search_values["index_name"] = index_name
    #     new_search_settings_request = SavedSearchSettings(**search_values)
    # else:
    #     new_search_settings_request = SavedSearchSettings(
    #         **search_settings_new.model_dump()
    #     )

    # Ensure the document indices have the new index immediately.
    document_indices = get_all_document_indices(search_settings, new_search_settings)
    for document_index in document_indices:
        document_index.ensure_indices_exist(
            primary_embedding_dim=search_settings.final_embedding_dim,
            primary_embedding_precision=search_settings.embedding_precision,
            secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
            secondary_index_embedding_precision=new_search_settings.embedding_precision,
        )
    # secondary_search_settings = get_secondary_search_settings(db_session)

    # Pause index attempts for the currently in-use index to preserve resources.
    if DISABLE_INDEX_UPDATE_ON_SWAP:
        expire_index_attempts(
            search_settings_id=search_settings.id, db_session=db_session
        )
        for cc_pair in get_connector_credential_pairs(db_session):
            resync_cc_pair(
                cc_pair=cc_pair,
                search_settings_id=new_search_settings.id,
                db_session=db_session,
            )
    # if secondary_search_settings:
    #     # Cancel any background indexing jobs
    #     expire_index_attempts(
    #         search_settings_id=secondary_search_settings.id, db_session=db_session
    #     )

    db_session.commit()
    return IdReturn(id=new_search_settings.id)
    # # Mark previous model as a past model directly
    # update_search_settings_status(
    #     search_settings=secondary_search_settings,
    #     new_status=IndexModelStatus.PAST,
    #     db_session=db_session,
    # )

    # new_search_settings = create_search_settings(
    #     search_settings=new_search_settings_request, db_session=db_session
    # )

    # # Ensure Vespa has the new index immediately
    # get_multipass_config(search_settings)
||||
# get_multipass_config(new_search_settings)
|
||||
# document_index = get_default_document_index(
|
||||
# search_settings, new_search_settings, db_session
|
||||
# )
|
||||
|
||||
# document_index.ensure_indices_exist(
|
||||
# primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
# primary_embedding_precision=search_settings.embedding_precision,
|
||||
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
# )
|
||||
|
||||
# # Pause index attempts for the currently in use index to preserve resources
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=search_settings.id, db_session=db_session
|
||||
# )
|
||||
# for cc_pair in get_connector_credential_pairs(db_session):
|
||||
# resync_cc_pair(
|
||||
# cc_pair=cc_pair,
|
||||
# search_settings_id=new_search_settings.id,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# db_session.commit()
|
||||
# return IdReturn(id=new_search_settings.id)
|
||||
|
||||
|
||||
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
|
||||
|
||||
@@ -60,11 +60,9 @@ class Settings(BaseModel):
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None

# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
ee_features_enabled: bool = False

temperature_override_enabled: bool | None = False

@@ -1,5 +1,3 @@
from __future__ import annotations

import json
import time
from collections.abc import Generator
@@ -86,19 +84,6 @@ class CodeInterpreterClient:
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
self._closed = False

def __enter__(self) -> CodeInterpreterClient:
return self

def __exit__(self, *args: object) -> None:
self.close()

def close(self) -> None:
if self._closed:
return
self.session.close()
self._closed = True

def _build_payload(
self,
@@ -118,13 +103,7 @@ class CodeInterpreterClient:
return payload

def health(self, use_cache: bool = False) -> bool:
"""Check if the Code Interpreter service is healthy

Args:
use_cache: When True, return a cached result if available and
within the TTL window. The cache is always populated
after a live request regardless of this flag.
"""
"""Check if the Code Interpreter service is healthy"""
if use_cache:
cached = _health_cache.get(self.base_url)
if cached is not None:
@@ -192,11 +171,8 @@ class CodeInterpreterClient:
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return

try:
response.raise_for_status()
yield from self._parse_sse(response)
finally:
response.close()
response.raise_for_status()
yield from self._parse_sse(response)

def _parse_sse(
self, response: requests.Response

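The `health(use_cache=...)` docstring in the hunk above describes a TTL-windowed cache that is always repopulated after a live probe. A minimal sketch of that pattern, with hypothetical names (`_HEALTH_CACHE`, `cached_health`, `probe`) standing in for the module-level `_health_cache` and the real HTTP request:

```python
import time

# Hypothetical stand-in for the diff's module-level `_health_cache`:
# maps base_url -> (result, timestamp of the live probe).
_HEALTH_CACHE: dict[str, tuple[bool, float]] = {}
_TTL_SECONDS = 30.0


def cached_health(base_url: str, probe, use_cache: bool = False, now=time.monotonic) -> bool:
    """Return probe(base_url); when use_cache=True, serve a cached result
    if it is still within the TTL window. The cache is repopulated after
    every live probe regardless of the flag."""
    if use_cache:
        entry = _HEALTH_CACHE.get(base_url)
        if entry is not None and now() - entry[1] < _TTL_SECONDS:
            return entry[0]
    result = bool(probe(base_url))  # live request in the real client
    _HEALTH_CACHE[base_url] = (result, now())
    return result
```

With `use_cache=True`, a second call inside the TTL window skips the live probe entirely.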
@@ -111,8 +111,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
if not server.server_enabled:
return False

with CodeInterpreterClient() as client:
return client.health(use_cache=True)
client = CodeInterpreterClient()
return client.health(use_cache=True)

def tool_definition(self) -> dict:
return {
@@ -176,203 +176,196 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
)
)

# Create Code Interpreter client — context manager ensures
# session.close() is called on every exit path.
with CodeInterpreterClient() as client:
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)

# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})

logger.info(f"Staged file for Python execution: {file_name}")

except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
# Create Code Interpreter client
client = CodeInterpreterClient()

# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
logger.debug(f"Executing code: {code}")
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)

# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})

for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=(
event.data if event.stream == "stdout" else ""
),
stderr=(
event.data if event.stream == "stderr" else ""
),
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
logger.info(f"Staged file for Python execution: {file_name}")

if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")

full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
try:
logger.debug(f"Executing code: {code}")

# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None

# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()

for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue

try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)

# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"

# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)

generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)

# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)

except Exception as e:
logger.error(
f"Failed to handle generated file "
f"{workspace_file.path}: {e}"
)

# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated "
f"file {ci_file_id}: {e}"
)

# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged "
f"file {file_mapping['file_id']}: {e}"
)

# Emit file_ids once files are processed
if generated_file_ids:
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")

# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=(None if result_event.exit_code == 0 else truncated_stderr),
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)

# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)

return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)

except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()

# Emit error delta
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue

try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)

# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"

# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)

generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)

# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)

except Exception as e:
logger.error(
f"Failed to handle generated file {workspace_file.path}: {e}"
)

# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)

# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)

# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
obj=PythonToolDelta(file_ids=generated_file_ids),
)
)

# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
)

adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()

return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)

except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)

# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
)
)

# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)

adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()

return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)

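The hunk above replaces a bare `client = CodeInterpreterClient()` with `with CodeInterpreterClient() as client:`, and the new comment explains why: the context manager guarantees `session.close()` runs on every exit path. A minimal sketch (hypothetical `SessionClient` class, not the real onyx client) of that `__enter__`/`__exit__`/idempotent-`close` pattern:

```python
class SessionClient:
    """Toy model of a client holding a closable session resource."""

    def __init__(self) -> None:
        self.closed = False

    def __enter__(self) -> "SessionClient":
        return self

    def __exit__(self, *args: object) -> None:
        # Runs on normal exit AND when an exception propagates out of the block
        self.close()

    def close(self) -> None:
        if self.closed:
            return  # idempotent: safe to call more than once
        self.closed = True


client_ref = None
try:
    with SessionClient() as client:
        client_ref = client
        raise RuntimeError("execution failed mid-block")
except RuntimeError:
    pass
# close() still ran despite the exception
```

This is exactly the leak the refactor closes: without the `with` block, an exception raised during execution would skip any manual `close()` call.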
@@ -596,7 +596,7 @@ mypy-extensions==1.0.0
# typing-inspect
nest-asyncio==1.6.0
# via onyx
nltk==3.9.3
nltk==3.9.1
# via unstructured
numpy==2.4.1
# via

@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.3
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via

@@ -16,6 +16,10 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:


def run_jobs() -> None:
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
use_lightweight = True

# command setup
cmd_worker_primary = [
"celery",
"-A",
@@ -70,48 +74,6 @@ def run_jobs() -> None:
"--queues=connector_doc_fetching",
]

cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
]

cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]

cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,user_file_delete",
]

cmd_beat = [
"celery",
"-A",
@@ -120,31 +82,144 @@ def run_jobs() -> None:
"--loglevel=INFO",
]

all_workers = [
("PRIMARY", cmd_worker_primary),
("LIGHT", cmd_worker_light),
("DOCPROCESSING", cmd_worker_docprocessing),
("DOCFETCHING", cmd_worker_docfetching),
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
("BEAT", cmd_beat),
]
# Prepare background worker commands based on mode
if use_lightweight:
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
cmd_worker_background = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration",
]
background_workers = [("BACKGROUND", cmd_worker_background)]
else:
print("Starting workers in STANDARD mode (separate background workers)")
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete",
]
background_workers = [
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
]

processes = []
for name, cmd in all_workers:
# spawn processes
worker_primary_process = subprocess.Popen(
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)

worker_light_process = subprocess.Popen(
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)

worker_docprocessing_process = subprocess.Popen(
cmd_worker_docprocessing,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)

worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)

beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)

# Spawn background worker processes based on mode
background_processes = []
for name, cmd in background_workers:
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
processes.append((name, process))
background_processes.append((name, process))

threads = []
for name, process in processes:
# monitor threads
worker_primary_thread = threading.Thread(
target=monitor_process, args=("PRIMARY", worker_primary_process)
)
worker_light_thread = threading.Thread(
target=monitor_process, args=("LIGHT", worker_light_process)
)
worker_docprocessing_thread = threading.Thread(
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))

# Create monitor threads for background workers
background_threads = []
for name, process in background_processes:
thread = threading.Thread(target=monitor_process, args=(name, process))
threads.append(thread)
background_threads.append(thread)

# Start all threads
worker_primary_thread.start()
worker_light_thread.start()
worker_docprocessing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()

for thread in background_threads:
thread.start()

for thread in threads:
# Wait for all threads
worker_primary_thread.join()
worker_light_thread.join()
worker_docprocessing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()

for thread in background_threads:
thread.join()


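The diff above refactors a block of copy-pasted `subprocess.Popen` and `threading.Thread` calls into loops over `(name, cmd)` pairs. A minimal runnable sketch of that loop-based pattern, using placeholder commands (`true`) instead of the real celery invocations:

```python
import subprocess
import threading

results: dict[str, int] = {}


def monitor_process(name: str, process: subprocess.Popen) -> None:
    # Block until the child exits, then record its return code
    process.wait()
    results[name] = process.returncode


# Placeholder commands standing in for the celery worker command lists
workers = [("PRIMARY", ["true"]), ("BACKGROUND", ["true"])]

# Spawn every process in one loop...
processes = [(name, subprocess.Popen(cmd)) for name, cmd in workers]

# ...then monitor each one from its own thread
threads = [
    threading.Thread(target=monitor_process, args=(name, proc))
    for name, proc in processes
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The same loops work for any number of workers, which is what lets the diff switch between a single consolidated background worker and several separate ones by just changing the `(name, cmd)` list.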
@@ -1,20 +1,10 @@
#!/bin/bash
set -e

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"

stop_and_remove_containers() {
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true
}

cleanup() {
echo "Error occurred. Cleaning up..."
stop_and_remove_containers
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}

# Trap errors and output a message, then cleanup
@@ -22,26 +12,16 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR

# Usage of the script with optional volume arguments
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
# [minio_volume] [--keep-opensearch-data]

KEEP_OPENSEARCH_DATA=false
POSITIONAL_ARGS=()
for arg in "$@"; do
if [[ "$arg" == "--keep-opensearch-data" ]]; then
KEEP_OPENSEARCH_DATA=true
else
POSITIONAL_ARGS+=("$arg")
fi
done

VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
VESPA_VOLUME=${1:-""}  # Default is empty if not provided
POSTGRES_VOLUME=${2:-""}  # Default is empty if not provided
REDIS_VOLUME=${3:-""}  # Default is empty if not provided
MINIO_VOLUME=${4:-""}  # Default is empty if not provided

# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
stop_and_remove_containers
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true

# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -59,29 +39,6 @@ else
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
fi

# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
# .vscode/.env so existing dev setups that stored it there aren't silently
# broken.
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
set -a
# shellcheck source=/dev/null
source "$VSCODE_ENV"
set +a
fi

# Start the OpenSearch container using the same service from docker-compose that
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
# restarts, else the volume is deleted so the container starts fresh.
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
echo "Deleting opensearch-data volume..."
docker volume rm onyx_opensearch-data 2>/dev/null || true
fi
echo "Starting OpenSearch container..."
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch

# Start the Redis container with optional volume
echo "Starting Redis container..."
if [[ -n "$REDIS_VOLUME" ]]; then
@@ -103,6 +60,7 @@ echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api

# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
cd "$PARENT_DIR"

10
backend/scripts/restart_opensearch_container.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
|
||||
source "$(dirname "$0")/../../.vscode/.env"
|
||||
|
||||
cd "$(dirname "$0")/../../deployment/docker_compose"
|
||||
|
||||
# Start OpenSearch.
|
||||
echo "Forcefully starting fresh OpenSearch container..."
|
||||
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch
|
||||
@@ -1,5 +1,23 @@
#!/bin/sh
# Entrypoint script for supervisord
# Entrypoint script for supervisord that sets environment variables
# for controlling which celery workers to start

# Default to lightweight mode if not set
if [ -z "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" ]; then
    export USE_LIGHTWEIGHT_BACKGROUND_WORKER="true"
fi

# Set the complementary variable for supervisord
# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax
if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then
    export USE_SEPARATE_BACKGROUND_WORKERS="false"
else
    export USE_SEPARATE_BACKGROUND_WORKERS="true"
fi

echo "Worker mode configuration:"
echo "  USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER"
echo "  USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS"

# Launch supervisord with environment variables available
exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf
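The flag derivation in the entrypoint above can be written as a pure function; a minimal sketch (not Onyx code) of the default-then-complement logic:

```python
def compute_separate_workers(env: dict) -> str:
    # Mirrors the entrypoint: default to lightweight mode when the variable
    # is unset or empty, then derive the complementary flag that supervisord
    # reads via %(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s.
    lightweight = env.get("USE_LIGHTWEIGHT_BACKGROUND_WORKER") or "true"
    return "false" if lightweight == "true" else "true"

print(compute_separate_workers({}))  # false (lightweight default)
print(compute_separate_workers({"USE_LIGHTWEIGHT_BACKGROUND_WORKER": "false"}))  # true
```

The complement exists only because supervisord can interpolate an environment variable but cannot negate one inside the config file.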

@@ -39,6 +39,7 @@ autorestart=true
startsecs=10
stopasgroup=true

# Standard mode: Light worker for fast operations
# NOTE: only allowing configuration here and not in the other celery workers,
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
# user group syncing, deletion, etc.)
@@ -53,7 +54,26 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s

# Lightweight mode: single consolidated background worker
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default)
# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing
[program:celery_worker_background]
command=celery -A onyx.background.celery.versioned_apps.background worker
    --loglevel=INFO
    --hostname=background@%%n
    -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,opensearch_migration
stdout_logfile=/var/log/celery_worker_background.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s

# Standard mode: separate workers for different background tasks
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
[program:celery_worker_heavy]
command=celery -A onyx.background.celery.versioned_apps.heavy worker
    --loglevel=INFO
@@ -65,7 +85,9 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s

# Standard mode: Document processing worker
[program:celery_worker_docprocessing]
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
    --loglevel=INFO
@@ -77,6 +99,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s

[program:celery_worker_user_file_processing]
command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker
@@ -89,7 +112,9 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s

# Standard mode: Document fetching worker
[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
    --loglevel=INFO
@@ -101,6 +126,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s

[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
@@ -113,6 +139,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s


# Job scheduler for periodic tasks
@@ -170,6 +197,7 @@ command=tail -qF
    /var/log/celery_beat.log
    /var/log/celery_worker_primary.log
    /var/log/celery_worker_light.log
    /var/log/celery_worker_background.log
    /var/log/celery_worker_heavy.log
    /var/log/celery_worker_docprocessing.log
    /var/log/celery_worker_monitoring.log
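The `%(ENV_...)s` placeholders in the `autostart` lines above are Python %-style interpolation: supervisord exposes each environment variable under an `ENV_`-prefixed key and expands config values against that mapping. A minimal sketch of the expansion (a simplification of what supervisord does internally):

```python
import os

# supervisord builds a dict of ENV_-prefixed environment variables and
# %-interpolates config values against it; sketching that expansion here.
os.environ["USE_SEPARATE_BACKGROUND_WORKERS"] = "false"
env = {f"ENV_{key}": value for key, value in os.environ.items()}
expanded = "autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s" % env
print(expanded)  # autostart=false
```

This is why the entrypoint exports both flags before `exec`ing supervisord: the values must be in the environment at config-read time.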

@@ -5,8 +5,6 @@ Verifies that:
1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs
2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch
3. Upserting is idempotent (running twice doesn't duplicate nodes)
4. Document-to-hierarchy-node linkage is updated during pruning
5. link_hierarchy_nodes_to_documents links nodes that are also documents

Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
combined with a real PostgreSQL database for verifying persistence.
@@ -29,13 +27,9 @@ from onyx.db.enums import HierarchyNodeType
from onyx.db.hierarchy import ensure_source_node_exists
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import Document as DbDocument
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.kg.models import KGStage

# ---------------------------------------------------------------------------
# Constants
@@ -95,18 +89,8 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]:
    ]


DOC_PARENT_MAP = {
    "msg-001": CHANNEL_A_ID,
    "msg-002": CHANNEL_A_ID,
    "msg-003": CHANNEL_B_ID,
}


def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]:
    return [
        SlimDocument(id=doc_id, parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id))
        for doc_id in SLIM_DOC_IDS
    ]
    return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS]


class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
@@ -142,31 +126,14 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
# ---------------------------------------------------------------------------


def _cleanup_test_data(db_session: Session) -> None:
    """Remove all test hierarchy nodes and documents to isolate tests."""
    for doc_id in SLIM_DOC_IDS:
        db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
def _cleanup_test_hierarchy_nodes(db_session: Session) -> None:
    """Remove all hierarchy nodes for TEST_SOURCE to isolate tests."""
    db_session.query(DBHierarchyNode).filter(
        DBHierarchyNode.source == TEST_SOURCE
    ).delete()
    db_session.commit()


def _create_test_documents(db_session: Session) -> list[DbDocument]:
    """Insert minimal Document rows for our test doc IDs."""
    docs = []
    for doc_id in SLIM_DOC_IDS:
        doc = DbDocument(
            id=doc_id,
            semantic_id=doc_id,
            kg_stage=KGStage.NOT_STARTED,
        )
        db_session.add(doc)
        docs.append(doc)
    db_session.commit()
    return docs


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -180,14 +147,14 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None:  # noqa:
    result = extract_ids_from_runnable_connector(connector, callback=None)

    # Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
    # (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
    # (hierarchy node IDs are added to doc_ids so they aren't pruned)
    expected_ids = {
        CHANNEL_A_ID,
        CHANNEL_B_ID,
        CHANNEL_C_ID,
        *SLIM_DOC_IDS,
    }
    assert result.raw_id_to_parent.keys() == expected_ids
    assert result.doc_ids == expected_ids

    # Hierarchy nodes should be the 3 channels
    assert len(result.hierarchy_nodes) == 3
@@ -198,7 +165,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None:  # noqa:

def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None:
    """Full flow: extract hierarchy nodes from mock connector, upsert to Postgres,
    then verify the DB state (node count, parent relationships, permissions)."""
    _cleanup_test_data(db_session)
    _cleanup_test_hierarchy_nodes(db_session)

    # Step 1: ensure the SOURCE node exists (mirrors what the pruning task does)
    source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -263,7 +230,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
) -> None:
    """When the connector's access type is PUBLIC, all hierarchy nodes must be
    marked is_public=True regardless of their external_access settings."""
    _cleanup_test_data(db_session)
    _cleanup_test_hierarchy_nodes(db_session)

    ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)

@@ -290,7 +257,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
    """Upserting the same hierarchy nodes twice must not create duplicates.
    The second call should update existing rows in place."""
    _cleanup_test_data(db_session)
    _cleanup_test_hierarchy_nodes(db_session)

    ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)

@@ -328,7 +295,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:

def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None:
    """Upserting a hierarchy node with changed fields should update the existing row."""
    _cleanup_test_data(db_session)
    _cleanup_test_hierarchy_nodes(db_session)

    ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)

@@ -375,193 +342,3 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No
    assert db_node.is_public is True
    assert db_node.external_user_emails is not None
    assert set(db_node.external_user_emails) == {"new_user@example.com"}


# ---------------------------------------------------------------------------
# Document-to-hierarchy-node linkage tests
# ---------------------------------------------------------------------------


def test_extraction_preserves_parent_hierarchy_raw_node_id(
    db_session: Session,  # noqa: ARG001
) -> None:
    """extract_ids_from_runnable_connector should carry the
    parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict."""
    connector = MockSlimConnectorWithPermSync()
    result = extract_ids_from_runnable_connector(connector, callback=None)

    for doc_id, expected_parent in DOC_PARENT_MAP.items():
        assert (
            result.raw_id_to_parent[doc_id] == expected_parent
        ), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"

    # Hierarchy node entries have None parent (they aren't documents)
    for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
        assert result.raw_id_to_parent[channel_id] is None


def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
    """update_document_parent_hierarchy_nodes should set
    Document.parent_hierarchy_node_id for each document in the mapping."""
    _cleanup_test_data(db_session)

    source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
    upserted = upsert_hierarchy_nodes_batch(
        db_session=db_session,
        nodes=_make_hierarchy_nodes(),
        source=TEST_SOURCE,
        commit=True,
        is_connector_public=False,
    )
    node_id_by_raw = {n.raw_node_id: n.id for n in upserted}

    # Create documents with no parent set
    docs = _create_test_documents(db_session)
    for doc in docs:
        assert doc.parent_hierarchy_node_id is None

    # Build resolved map (same logic as _resolve_and_update_document_parents)
    resolved: dict[str, int | None] = {}
    for doc_id, raw_parent in DOC_PARENT_MAP.items():
        resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id)

    updated = update_document_parent_hierarchy_nodes(
        db_session=db_session,
        doc_parent_map=resolved,
        commit=True,
    )
    assert updated == len(SLIM_DOC_IDS)

    # Verify each document now points to the correct hierarchy node
    db_session.expire_all()
    for doc_id, raw_parent in DOC_PARENT_MAP.items():
        tmp_doc = db_session.get(DbDocument, doc_id)
        assert tmp_doc is not None
        doc = tmp_doc
        expected_node_id = node_id_by_raw[raw_parent]
        assert (
            doc.parent_hierarchy_node_id == expected_node_id
        ), f"Document {doc_id} should point to node for {raw_parent}"


def test_update_document_parent_is_idempotent(db_session: Session) -> None:
    """Running update_document_parent_hierarchy_nodes a second time with the
    same mapping should update zero rows."""
    _cleanup_test_data(db_session)

    ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
    upserted = upsert_hierarchy_nodes_batch(
        db_session=db_session,
        nodes=_make_hierarchy_nodes(),
        source=TEST_SOURCE,
        commit=True,
        is_connector_public=False,
    )
    node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
    _create_test_documents(db_session)

    resolved: dict[str, int | None] = {
        doc_id: node_id_by_raw[raw_parent]
        for doc_id, raw_parent in DOC_PARENT_MAP.items()
    }

    first_updated = update_document_parent_hierarchy_nodes(
        db_session=db_session,
        doc_parent_map=resolved,
        commit=True,
    )
    assert first_updated == len(SLIM_DOC_IDS)

    second_updated = update_document_parent_hierarchy_nodes(
        db_session=db_session,
        doc_parent_map=resolved,
        commit=True,
    )
    assert second_updated == 0


def test_link_hierarchy_nodes_to_documents_for_confluence(
    db_session: Session,
) -> None:
    """For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence),
    link_hierarchy_nodes_to_documents should set HierarchyNode.document_id
    when a hierarchy node's raw_node_id matches a document ID."""
    _cleanup_test_data(db_session)
    confluence_source = DocumentSource.CONFLUENCE

    # Clean up any existing Confluence hierarchy nodes
    db_session.query(DBHierarchyNode).filter(
        DBHierarchyNode.source == confluence_source
    ).delete()
    db_session.commit()

    ensure_source_node_exists(db_session, confluence_source, commit=True)

    # Create a hierarchy node whose raw_node_id matches a document ID
    page_node_id = "confluence-page-123"
    nodes = [
        PydanticHierarchyNode(
            raw_node_id=page_node_id,
            raw_parent_id=None,
            display_name="Test Page",
            link="https://wiki.example.com/page/123",
            node_type=HierarchyNodeType.PAGE,
        ),
    ]
    upsert_hierarchy_nodes_batch(
        db_session=db_session,
        nodes=nodes,
        source=confluence_source,
        commit=True,
        is_connector_public=False,
    )

    # Verify the node exists but has no document_id yet
    db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
    assert db_node is not None
    assert db_node.document_id is None

    # Create a document with the same ID as the hierarchy node
    doc = DbDocument(
        id=page_node_id,
        semantic_id="Test Page",
        kg_stage=KGStage.NOT_STARTED,
    )
    db_session.add(doc)
    db_session.commit()

    # Link nodes to documents
    linked = link_hierarchy_nodes_to_documents(
        db_session=db_session,
        document_ids=[page_node_id],
        source=confluence_source,
        commit=True,
    )
    assert linked == 1

    # Verify the hierarchy node now has document_id set
    db_session.expire_all()
    db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
    assert db_node is not None
    assert db_node.document_id == page_node_id

    # Cleanup
    db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete()
    db_session.query(DBHierarchyNode).filter(
        DBHierarchyNode.source == confluence_source
    ).delete()
    db_session.commit()


def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
    db_session: Session,
) -> None:
    """link_hierarchy_nodes_to_documents should return 0 for sources that
    don't support hierarchy-node-as-document (e.g. Slack, Google Drive)."""
    linked = link_hierarchy_nodes_to_documents(
        db_session=db_session,
        document_ids=SLIM_DOC_IDS,
        source=TEST_SOURCE,  # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS
        commit=False,
    )
    assert linked == 0
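The idempotency contract these tests rely on can be illustrated with a minimal in-memory upsert keyed by raw_node_id (names are illustrative, not Onyx's API):

```python
def upsert_nodes(store: dict, nodes: list) -> int:
    """Insert or update nodes keyed by raw_node_id; applying the same
    batch twice never grows the store."""
    for node in nodes:
        store[node["raw_node_id"]] = node
    return len(store)

store: dict = {}
batch = [
    {"raw_node_id": "channel-a", "display_name": "A"},
    {"raw_node_id": "channel-b", "display_name": "B"},
]
print(upsert_nodes(store, batch))  # 2
print(upsert_nodes(store, batch))  # still 2: the second call updates in place
```

The real implementation does the same thing against Postgres, which is why the tests can assert an unchanged node count after a second upsert.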

@@ -11,6 +11,7 @@ from unittest.mock import patch
from uuid import uuid4

import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.db.enums import LLMModelFlowType
@@ -19,8 +20,6 @@ from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import UserRole
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.server.manage.llm.api import (
@@ -123,16 +122,16 @@ class TestLLMConfigurationEndpoint:
        finally:
            db_session.rollback()

    def test_failed_llm_test_raises_onyx_error(
    def test_failed_llm_test_raises_http_exception(
        self,
        db_session: Session,
        provider_name: str,  # noqa: ARG002
    ) -> None:
        """
        Test that a failed LLM test raises an OnyxError with VALIDATION_ERROR.
        Test that a failed LLM test raises an HTTPException with status 400.

        When test_llm returns an error message, the endpoint should raise
        an OnyxError with the error details.
        an HTTPException with the error details.
        """
        error_message = "Invalid API key: Authentication failed"

@@ -144,7 +143,7 @@ class TestLLMConfigurationEndpoint:
            with patch(
                "onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
            ):
                with pytest.raises(OnyxError) as exc_info:
                with pytest.raises(HTTPException) as exc_info:
                    run_test_llm_configuration(
                        test_llm_request=LLMTestRequest(
                            provider=LlmProviderNames.OPENAI,
@@ -157,8 +156,9 @@ class TestLLMConfigurationEndpoint:
                        db_session=db_session,
                    )

                assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
                assert exc_info.value.message == error_message
                # Verify the exception details
                assert exc_info.value.status_code == 400
                assert exc_info.value.detail == error_message

        finally:
            db_session.rollback()
@@ -536,11 +536,11 @@ class TestDefaultProviderEndpoint:
            remove_llm_provider(db_session, provider.id)

            # Now run_test_default_provider should fail
            with pytest.raises(OnyxError) as exc_info:
            with pytest.raises(HTTPException) as exc_info:
                run_test_default_provider(_=_create_mock_admin())

            assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
            assert "No LLM Provider setup" in exc_info.value.message
            assert exc_info.value.status_code == 400
            assert "No LLM Provider setup" in exc_info.value.detail

        finally:
            db_session.rollback()
@@ -581,11 +581,11 @@ class TestDefaultProviderEndpoint:
            with patch(
                "onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
            ):
                with pytest.raises(OnyxError) as exc_info:
                with pytest.raises(HTTPException) as exc_info:
                    run_test_default_provider(_=_create_mock_admin())

                assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
                assert exc_info.value.message == error_message
                assert exc_info.value.status_code == 400
                assert exc_info.value.detail == error_message

        finally:
            db_session.rollback()

@@ -16,14 +16,13 @@ from unittest.mock import patch
from uuid import uuid4

import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import UserRole
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.api import _mask_string
from onyx.server.manage.llm.api import put_llm_provider
@@ -101,7 +100,7 @@ class TestLLMProviderChanges:
            api_base="https://attacker.example.com",
        )

        with pytest.raises(OnyxError) as exc_info:
        with pytest.raises(HTTPException) as exc_info:
            put_llm_provider(
                llm_provider_upsert_request=update_request,
                is_creation=False,
@@ -109,9 +108,9 @@ class TestLLMProviderChanges:
                db_session=db_session,
            )

        assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
        assert exc_info.value.status_code == 400
        assert "cannot be changed without changing the API key" in str(
            exc_info.value.message
            exc_info.value.detail
        )
    finally:
        _cleanup_provider(db_session, provider_name)
@@ -237,7 +236,7 @@ class TestLLMProviderChanges:
            api_base=None,
        )

        with pytest.raises(OnyxError) as exc_info:
        with pytest.raises(HTTPException) as exc_info:
            put_llm_provider(
                llm_provider_upsert_request=update_request,
                is_creation=False,
@@ -245,9 +244,9 @@ class TestLLMProviderChanges:
                db_session=db_session,
            )

        assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
        assert exc_info.value.status_code == 400
        assert "cannot be changed without changing the API key" in str(
            exc_info.value.message
            exc_info.value.detail
        )
    finally:
        _cleanup_provider(db_session, provider_name)
@@ -340,7 +339,7 @@ class TestLLMProviderChanges:
            custom_config_changed=True,
        )

        with pytest.raises(OnyxError) as exc_info:
        with pytest.raises(HTTPException) as exc_info:
            put_llm_provider(
                llm_provider_upsert_request=update_request,
                is_creation=False,
@@ -348,9 +347,9 @@ class TestLLMProviderChanges:
                db_session=db_session,
            )

        assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
        assert exc_info.value.status_code == 400
        assert "cannot be changed without changing the API key" in str(
            exc_info.value.message
            exc_info.value.detail
        )
    finally:
        _cleanup_provider(db_session, provider_name)
@@ -376,7 +375,7 @@ class TestLLMProviderChanges:
            custom_config_changed=True,
        )

        with pytest.raises(OnyxError) as exc_info:
        with pytest.raises(HTTPException) as exc_info:
            put_llm_provider(
                llm_provider_upsert_request=update_request,
                is_creation=False,
@@ -384,9 +383,9 @@ class TestLLMProviderChanges:
                db_session=db_session,
            )

        assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
        assert exc_info.value.status_code == 400
        assert "cannot be changed without changing the API key" in str(
            exc_info.value.message
            exc_info.value.detail
        )
    finally:
        _cleanup_provider(db_session, provider_name)

@@ -11,7 +11,6 @@ from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_default_contextual_rag_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
@@ -38,8 +37,6 @@ def _create_llm_provider_and_model(
    model_name: str,
) -> None:
    """Insert an LLM provider with a single visible model configuration."""
    if fetch_existing_llm_provider(name=provider_name, db_session=db_session):
        return
    upsert_llm_provider(
        LLMProviderUpsertRequest(
            name=provider_name,
@@ -149,8 +146,8 @@ def baseline_search_settings(
    )


@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -158,7 +155,6 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
    mock_index_handler: MagicMock,
    mock_get_llm: MagicMock,
    mock_get_doc_index: MagicMock,  # noqa: ARG001
    mock_get_all_doc_indices_search_settings: MagicMock,  # noqa: ARG001
    mock_get_all_doc_indices: MagicMock,
    baseline_search_settings: None,  # noqa: ARG001
    db_session: Session,
@@ -200,8 +196,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
    )


@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -209,7 +205,6 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
    mock_index_handler: MagicMock,
    mock_get_llm: MagicMock,
    mock_get_doc_index: MagicMock,  # noqa: ARG001
    mock_get_all_doc_indices_search_settings: MagicMock,  # noqa: ARG001
    mock_get_all_doc_indices: MagicMock,
    baseline_search_settings: None,  # noqa: ARG001
    db_session: Session,
@@ -271,7 +266,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
    )


@patch("onyx.server.manage.search_settings.get_all_document_indices")
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -279,7 +274,6 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
    mock_index_handler: MagicMock,
    mock_get_llm: MagicMock,
    mock_get_doc_index: MagicMock,  # noqa: ARG001
    mock_get_all_doc_indices_search_settings: MagicMock,  # noqa: ARG001
    baseline_search_settings: None,  # noqa: ARG001
    db_session: Session,
) -> None:

@@ -427,7 +427,7 @@ def test_delete_default_llm_provider_rejected(reset: None) -> None:  # noqa: ARG
        headers=admin_user.headers,
    )
    assert delete_response.status_code == 400
    assert "Cannot delete the default LLM provider" in delete_response.json()["message"]
    assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]

    # Verify provider still exists
    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
@@ -673,8 +673,8 @@ def test_duplicate_provider_name_rejected(reset: None) -> None:  # noqa: ARG001
        headers=admin_user.headers,
        json=base_payload,
    )
    assert response.status_code == 409
    assert "already exists" in response.json()["message"]
    assert response.status_code == 400
    assert "already exists" in response.json()["detail"]


def test_rename_provider_rejected(reset: None) -> None:  # noqa: ARG001
@@ -711,7 +711,7 @@ def test_rename_provider_rejected(reset: None) -> None:  # noqa: ARG001
        json=update_payload,
    )
    assert response.status_code == 400
    assert "not currently supported" in response.json()["message"]
    assert "not currently supported" in response.json()["detail"]

    # Verify no duplicate was created — only the original provider should exist
    provider = _get_provider_by_id(admin_user, provider_id)

@@ -69,7 +69,7 @@ def test_unauthorized_persona_access_returns_403(

    # Should return 403 Forbidden
    assert response.status_code == 403
    assert "don't have access to this assistant" in response.json()["message"]
    assert "don't have access to this assistant" in response.json()["detail"]


def test_authorized_persona_access_returns_filtered_providers(
@@ -245,4 +245,4 @@ def test_nonexistent_persona_returns_404(

    # Should return 404
    assert response.status_code == 404
    assert "Persona not found" in response.json()["message"]
    assert "Persona not found" in response.json()["detail"]

@@ -42,78 +42,6 @@ class NightlyProviderConfig(BaseModel):
    strict: bool


def _stringify_custom_config_value(value: object) -> str:
    if isinstance(value, str):
        return value
    if isinstance(value, (dict, list)):
        return json.dumps(value)
    return str(value)


def _looks_like_vertex_credentials_payload(
    raw_custom_config: dict[object, object],
) -> bool:
    normalized_keys = {str(key).strip().lower() for key in raw_custom_config}
    provider_specific_keys = {
        "vertex_credentials",
        "credentials_file",
        "vertex_credentials_file",
        "google_application_credentials",
        "vertex_location",
        "location",
        "vertex_region",
        "region",
    }
    if normalized_keys & provider_specific_keys:
        return False

    normalized_type = str(raw_custom_config.get("type", "")).strip().lower()
    if normalized_type not in {"service_account", "external_account"}:
        return False

    # Service account JSON usually includes private_key/client_email, while external
    # account JSON includes credential_source. Either shape should be accepted.
    has_service_account_markers = any(
        key in normalized_keys for key in {"private_key", "client_email"}
    )
    has_external_account_markers = "credential_source" in normalized_keys
    return has_service_account_markers or has_external_account_markers


def _normalize_custom_config(
    provider: str, raw_custom_config: dict[object, object]
) -> dict[str, str]:
    if provider == "vertex_ai" and _looks_like_vertex_credentials_payload(
        raw_custom_config
    ):
        return {"vertex_credentials": json.dumps(raw_custom_config)}

    normalized: dict[str, str] = {}
    for raw_key, raw_value in raw_custom_config.items():
        key = str(raw_key).strip()
        key_lower = key.lower()

        if provider == "vertex_ai":
            if key_lower in {
                "vertex_credentials",
                "credentials_file",
                "vertex_credentials_file",
                "google_application_credentials",
            }:
                key = "vertex_credentials"
            elif key_lower in {
                "vertex_location",
                "location",
                "vertex_region",
                "region",
            }:
                key = "vertex_location"

        normalized[key] = _stringify_custom_config_value(raw_value)

    return normalized
|
||||
|
||||
|
||||
def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
value = os.environ.get(env_var)
|
||||
if value is None:
|
||||
@@ -152,9 +80,7 @@ def _load_provider_config() -> NightlyProviderConfig:
    parsed = json.loads(custom_config_json)
    if not isinstance(parsed, dict):
        raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
    custom_config = _normalize_custom_config(
        provider=provider, raw_custom_config=parsed
    )
    custom_config = {str(key): str(value) for key, value in parsed.items()}

    if provider == "ollama_chat" and api_key and not custom_config:
        custom_config = {"OLLAMA_API_KEY": api_key}
@@ -222,23 +148,6 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
            ),
        )

    if config.provider == "vertex_ai":
        has_vertex_credentials = bool(
            config.custom_config and config.custom_config.get("vertex_credentials")
        )
        if not has_vertex_credentials:
            configured_keys = (
                sorted(config.custom_config.keys()) if config.custom_config else []
            )
            _skip_or_fail(
                strict=config.strict,
                message=(
                    f"{_ENV_CUSTOM_CONFIG_JSON} must include 'vertex_credentials' "
                    f"for provider '{config.provider}'. "
                    f"Found keys: {configured_keys}"
                ),
            )

def _assert_integration_mode_enabled() -> None:
    assert (
@@ -284,7 +193,6 @@ def _create_provider_payload(
    return {
        "name": provider_name,
        "provider": provider,
        "model": model_name,
        "api_key": api_key,
        "api_base": api_base,
        "api_version": api_version,
@@ -300,23 +208,24 @@
    }


def _ensure_provider_is_default(
    provider_id: int, model_name: str, admin_user: DATestUser
) -> None:
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
    list_response = requests.get(
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
    )
    list_response.raise_for_status()
    default_text = list_response.json().get("default_text")
    assert default_text is not None, "Expected a default provider after setting default"
    assert default_text.get("provider_id") == provider_id, (
        f"Expected provider {provider_id} to be default, "
        f"found {default_text.get('provider_id')}"
    providers = list_response.json()

    current_default = next(
        (provider for provider in providers if provider.get("is_default_provider")),
        None,
    )
    assert (
        default_text.get("model_name") == model_name
    ), f"Expected default model {model_name}, found {default_text.get('model_name')}"
        current_default is not None
    ), "Expected a default provider after setting provider as default"
    assert (
        current_default["id"] == provider_id
    ), f"Expected provider {provider_id} to be default, found {current_default['id']}"


def _run_chat_assertions(
@@ -417,9 +326,8 @@ def _create_and_test_provider_for_model(

    try:
        set_default_response = requests.post(
            f"{API_SERVER_URL}/admin/llm/default",
            f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
            headers=admin_user.headers,
            json={"provider_id": provider_id, "model_name": model_name},
        )
        assert set_default_response.status_code == 200, (
            f"Setting default provider failed for provider={config.provider} "
@@ -427,9 +335,7 @@ def _create_and_test_provider_for_model(
            f"{set_default_response.text}"
        )

        _ensure_provider_is_default(
            provider_id=provider_id, model_name=model_name, admin_user=admin_user
        )
        _ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
        _run_chat_assertions(
            admin_user=admin_user,
            search_tool_id=search_tool_id,

@@ -1,3 +1,4 @@
import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
@@ -364,6 +365,7 @@ def test_update_contextual_rag_missing_model_name(
    assert "Provider name and model name are required" in response.json()["detail"]


@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_with_contextual_rag(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
@@ -392,6 +394,7 @@ def test_set_new_search_settings_with_contextual_rag(
    _cancel_new_embedding(admin_user)


@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_without_contextual_rag(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
@@ -416,6 +419,7 @@ def test_set_new_search_settings_without_contextual_rag(
    _cancel_new_embedding(admin_user)


@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_then_update_inference_settings(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
@@ -453,6 +457,7 @@ def test_set_new_then_update_inference_settings(
    _cancel_new_embedding(admin_user)


@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_replaces_previous_secondary(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,

@@ -281,10 +281,9 @@ class TestApplyLicenseStatusToSettings:
    }


class TestSettingsDefaults:
    """Verify Settings model defaults for CE deployments."""
class TestSettingsDefaultEEDisabled:
    """Verify the Settings model defaults ee_features_enabled to False."""

    def test_default_ee_features_disabled(self) -> None:
        """CE default: ee_features_enabled is False."""
        settings = Settings()
        assert settings.ee_features_enabled is False

@@ -2,6 +2,7 @@

import pytest

from onyx.chat.llm_loop import _should_keep_bedrock_tool_definitions
from onyx.chat.llm_loop import _try_fallback_tool_extraction
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.models import ChatLoadedFile
@@ -13,11 +14,22 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff


class _StubConfig:
    def __init__(self, model_provider: str) -> None:
        self.model_provider = model_provider


class _StubLLM:
    def __init__(self, model_provider: str) -> None:
        self.config = _StubConfig(model_provider=model_provider)


def create_message(
    content: str, message_type: MessageType, token_count: int | None = None
) -> ChatMessageSimple:
@@ -934,6 +946,37 @@ class TestForgottenFileMetadata:
        assert "moby_dick.txt" in forgotten.message


class TestBedrockToolConfigGuard:
    def test_bedrock_with_tool_history_keeps_tool_definitions(self) -> None:
        llm = _StubLLM(LlmProviderNames.BEDROCK)
        history = [
            create_message("Question", MessageType.USER, 5),
            create_assistant_with_tool_call("tc_1", "search", 5),
            create_tool_response("tc_1", "Tool output", 5),
        ]

        assert _should_keep_bedrock_tool_definitions(llm, history) is True

    def test_bedrock_without_tool_history_does_not_keep_tool_definitions(self) -> None:
        llm = _StubLLM(LlmProviderNames.BEDROCK)
        history = [
            create_message("Question", MessageType.USER, 5),
            create_message("Answer", MessageType.ASSISTANT, 5),
        ]

        assert _should_keep_bedrock_tool_definitions(llm, history) is False

    def test_non_bedrock_with_tool_history_does_not_keep_tool_definitions(self) -> None:
        llm = _StubLLM(LlmProviderNames.OPENAI)
        history = [
            create_message("Question", MessageType.USER, 5),
            create_assistant_with_tool_call("tc_1", "search", 5),
            create_tool_response("tc_1", "Tool output", 5),
        ]

        assert _should_keep_bedrock_tool_definitions(llm, history) is False

class TestFallbackToolExtraction:
    def _tool_defs(self) -> list[dict]:
        return [

@@ -1214,218 +1214,3 @@ def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:

    # The env lock context manager should never have been called
    mock_env_lock.assert_not_called()


# ---- Tests for Bedrock tool content stripping ----


def test_messages_contain_tool_content_with_tool_role() -> None:
    from onyx.llm.multi_llm import _messages_contain_tool_content

    messages: list[dict[str, Any]] = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "I'll search for that."},
        {"role": "tool", "content": "search results", "tool_call_id": "tc_1"},
    ]
    assert _messages_contain_tool_content(messages) is True


def test_messages_contain_tool_content_with_tool_calls() -> None:
    from onyx.llm.multi_llm import _messages_contain_tool_content

    messages: list[dict[str, Any]] = [
        {"role": "user", "content": "Hello"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "tc_1",
                    "type": "function",
                    "function": {"name": "search", "arguments": "{}"},
                }
            ],
        },
    ]
    assert _messages_contain_tool_content(messages) is True


def test_messages_contain_tool_content_without_tools() -> None:
    from onyx.llm.multi_llm import _messages_contain_tool_content

    messages: list[dict[str, Any]] = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    assert _messages_contain_tool_content(messages) is False


def test_strip_tool_content_converts_assistant_tool_calls_to_text() -> None:
    from onyx.llm.multi_llm import _strip_tool_content_from_messages

    messages: list[dict[str, Any]] = [
        {"role": "user", "content": "Search for cats"},
        {
            "role": "assistant",
            "content": "Let me search.",
            "tool_calls": [
                {
                    "id": "tc_1",
                    "type": "function",
                    "function": {
                        "name": "search",
                        "arguments": '{"query": "cats"}',
                    },
                }
            ],
        },
        {
            "role": "tool",
            "content": "Found 3 results about cats.",
            "tool_call_id": "tc_1",
        },
        {"role": "assistant", "content": "Here are the results."},
    ]

    result = _strip_tool_content_from_messages(messages)

    assert len(result) == 4

    # First message unchanged
    assert result[0] == {"role": "user", "content": "Search for cats"}

    # Assistant with tool calls → plain text
    assert result[1]["role"] == "assistant"
    assert "tool_calls" not in result[1]
    assert "Let me search." in result[1]["content"]
    assert "[Tool Call]" in result[1]["content"]
    assert "search" in result[1]["content"]
    assert "tc_1" in result[1]["content"]

    # Tool response → user message
    assert result[2]["role"] == "user"
    assert "[Tool Result]" in result[2]["content"]
    assert "tc_1" in result[2]["content"]
    assert "Found 3 results about cats." in result[2]["content"]

    # Final assistant message unchanged
    assert result[3] == {"role": "assistant", "content": "Here are the results."}


def test_strip_tool_content_handles_assistant_with_no_text_content() -> None:
    from onyx.llm.multi_llm import _strip_tool_content_from_messages

    messages: list[dict[str, Any]] = [
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "tc_1",
                    "type": "function",
                    "function": {"name": "search", "arguments": "{}"},
                }
            ],
        },
    ]

    result = _strip_tool_content_from_messages(messages)
    assert result[0]["role"] == "assistant"
    assert "[Tool Call]" in result[0]["content"]
    assert "tool_calls" not in result[0]


def test_strip_tool_content_passes_through_non_tool_messages() -> None:
    from onyx.llm.multi_llm import _strip_tool_content_from_messages

    messages: list[dict[str, Any]] = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi!"},
    ]

    result = _strip_tool_content_from_messages(messages)
    assert result == messages


def test_strip_tool_content_handles_list_content_blocks() -> None:
    from onyx.llm.multi_llm import _strip_tool_content_from_messages

    messages: list[dict[str, Any]] = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Searching now."}],
            "tool_calls": [
                {
                    "id": "tc_1",
                    "type": "function",
                    "function": {"name": "search", "arguments": "{}"},
                }
            ],
        },
        {
            "role": "tool",
            "content": [
                {"type": "text", "text": "result A"},
                {"type": "text", "text": "result B"},
            ],
            "tool_call_id": "tc_1",
        },
    ]

    result = _strip_tool_content_from_messages(messages)

    # Assistant: list content flattened + tool call appended
    assert result[0]["role"] == "assistant"
    assert "Searching now." in result[0]["content"]
    assert "[Tool Call]" in result[0]["content"]
    assert isinstance(result[0]["content"], str)

    # Tool: list content flattened into user message
    assert result[1]["role"] == "user"
    assert "result A" in result[1]["content"]
    assert "result B" in result[1]["content"]
    assert isinstance(result[1]["content"], str)


def test_strip_tool_content_merges_consecutive_tool_results() -> None:
    """Bedrock requires strict user/assistant alternation. Multiple parallel
    tool results must be merged into a single user message."""
    from onyx.llm.multi_llm import _strip_tool_content_from_messages

    messages: list[dict[str, Any]] = [
        {"role": "user", "content": "weather and news?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "tc_1",
                    "type": "function",
                    "function": {"name": "search_weather", "arguments": "{}"},
                },
                {
                    "id": "tc_2",
                    "type": "function",
                    "function": {"name": "search_news", "arguments": "{}"},
                },
            ],
        },
        {"role": "tool", "content": "sunny 72F", "tool_call_id": "tc_1"},
        {"role": "tool", "content": "headline news", "tool_call_id": "tc_2"},
        {"role": "assistant", "content": "Here are the results."},
    ]

    result = _strip_tool_content_from_messages(messages)

    # user, assistant (flattened), user (merged tool results), assistant
    assert len(result) == 4
    roles = [m["role"] for m in result]
    assert roles == ["user", "assistant", "user", "assistant"]

    # Both tool results merged into one user message
    merged = result[2]["content"]
    assert "tc_1" in merged
    assert "sunny 72F" in merged
    assert "tc_2" in merged
    assert "headline news" in merged

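The implementation under test is not included in this diff. A minimal sketch of a transform that would satisfy the stripping tests above, reconstructed from the test expectations alone (the `[Tool Call]`/`[Tool Result]` line formats here are assumptions, not Onyx's actual `_strip_tool_content_from_messages`):

```python
from typing import Any


def _flatten(content: Any) -> str:
    # Content may be a plain string or a list of {"type": "text", ...} blocks.
    if isinstance(content, list):
        return "\n".join(b.get("text", "") for b in content if isinstance(b, dict))
    return content or ""


def strip_tool_content_from_messages(
    messages: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    result: list[dict[str, Any]] = []
    merging = False  # True while the last emitted message is a merged tool-result turn
    for msg in messages:
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            # Flatten the assistant turn: any text first, then one line per tool call.
            parts = [t for t in [_flatten(msg.get("content"))] if t]
            for tc in msg["tool_calls"]:
                fn = tc.get("function", {})
                parts.append(
                    f"[Tool Call] id={tc.get('id')} name={fn.get('name')} "
                    f"args={fn.get('arguments')}"
                )
            result.append({"role": "assistant", "content": "\n".join(parts)})
            merging = False
        elif msg.get("role") == "tool":
            text = (
                f"[Tool Result] id={msg.get('tool_call_id')}\n"
                f"{_flatten(msg.get('content'))}"
            )
            if merging:
                # Merge consecutive tool results into one user turn, since
                # Bedrock requires strict user/assistant alternation.
                result[-1]["content"] += "\n" + text
            else:
                result.append({"role": "user", "content": text})
                merging = True
        else:
            result.append(msg)  # non-tool messages pass through unchanged
            merging = False
    return result
```

Pass-through messages are returned as-is, so a history with no tool content compares equal to the input, matching `test_strip_tool_content_passes_through_non_tool_messages`.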
@@ -104,102 +104,3 @@ def test_format_slack_message_ampersand_not_double_escaped() -> None:

    assert "&" in formatted
    assert "&amp;" not in formatted


# -- Table rendering tests --


def test_table_renders_as_vertical_cards() -> None:
    message = (
        "| Feature | Status | Owner |\n"
        "|---------|--------|-------|\n"
        "| Auth | Done | Alice |\n"
        "| Search | In Progress | Bob |\n"
    )

    formatted = format_slack_message(message)

    assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
    assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
    # Cards separated by blank line
    assert "Owner: Alice\n\n*Search*" in formatted
    # No raw pipe-and-dash table syntax
    assert "---|" not in formatted


def test_table_single_column() -> None:
    message = "| Name |\n|------|\n| Alice |\n| Bob |\n"

    formatted = format_slack_message(message)

    assert "*Alice*" in formatted
    assert "*Bob*" in formatted


def test_table_embedded_in_text() -> None:
    message = (
        "Here are the results:\n\n"
        "| Item | Count |\n"
        "|------|-------|\n"
        "| Apples | 5 |\n"
        "\n"
        "That's all."
    )

    formatted = format_slack_message(message)

    assert "Here are the results:" in formatted
    assert "*Apples*\n • Count: 5" in formatted
    assert "That's all." in formatted


def test_table_with_formatted_cells() -> None:
    message = (
        "| Name | Link |\n"
        "|------|------|\n"
        "| **Alice** | [profile](https://example.com) |\n"
    )

    formatted = format_slack_message(message)

    # Bold cell should not double-wrap: *Alice* not **Alice**
    assert "*Alice*" in formatted
    assert "**Alice**" not in formatted
    assert "<https://example.com|profile>" in formatted


def test_table_with_alignment_specifiers() -> None:
    message = (
        "| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
    )

    formatted = format_slack_message(message)

    assert "*a*\n • Center: b\n • Right: c" in formatted


def test_two_tables_in_same_message_use_independent_headers() -> None:
    message = (
        "| A | B |\n"
        "|---|---|\n"
        "| 1 | 2 |\n"
        "\n"
        "| X | Y | Z |\n"
        "|---|---|---|\n"
        "| p | q | r |\n"
    )

    formatted = format_slack_message(message)

    assert "*1*\n • B: 2" in formatted
    assert "*p*\n • Y: q\n • Z: r" in formatted


def test_table_empty_first_column_no_bare_asterisks() -> None:
    message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"

    formatted = format_slack_message(message)

    # Empty title should not produce "**" (bare asterisks)
    assert "**" not in formatted
    assert " • Status: Done" in formatted

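The renderer these table tests target is not part of this diff. A minimal sketch of the markdown-table-to-"vertical cards" conversion the assertions describe (a reconstruction under assumptions, not the actual `format_slack_message` implementation, and ignoring cell formatting like bold or links):

```python
def table_rows_to_cards(table_lines: list[str]) -> str:
    """Turn a markdown table (header row, separator row, data rows) into
    Slack-friendly 'cards': the first cell becomes a bold title and the
    remaining cells become labeled bullets."""

    def cells(line: str) -> list[str]:
        return [c.strip() for c in line.strip().strip("|").split("|")]

    header = cells(table_lines[0])
    cards = []
    for row in table_lines[2:]:  # skip the |---| separator row
        values = cells(row)
        lines = [f"*{values[0]}*"] if values[0] else []  # no bare ** for empty titles
        lines += [f" • {h}: {v}" for h, v in zip(header[1:], values[1:])]
        cards.append("\n".join(lines))
    return "\n\n".join(cards)  # blank line between cards
```

For example, the first test's table would render as `*Auth*\n • Status: Done\n • Owner: Alice` followed by a blank line and the `*Search*` card.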
@@ -87,8 +87,7 @@ def test_python_tool_available_when_health_check_passes(

    mock_client = MagicMock()
    mock_client.health.return_value = True
    mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
    mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
    mock_client_cls.return_value = mock_client

    db_session = MagicMock(spec=Session)
    assert PythonTool.is_available(db_session) is True
@@ -110,8 +109,7 @@ def test_python_tool_unavailable_when_health_check_fails(

    mock_client = MagicMock()
    mock_client.health.return_value = False
    mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
    mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
    mock_client_cls.return_value = mock_client

    db_session = MagicMock(spec=Session)
    assert PythonTool.is_available(db_session) is False

cli/.gitignore (vendored, 3 lines changed)
@@ -1,3 +0,0 @@
onyx-cli
cli
onyx.cli
cli/README.md (118 lines changed)
@@ -1,118 +0,0 @@
# Onyx CLI

A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.

## Installation

```shell
pip install onyx-cli
```

Or with uv:

```shell
uv pip install onyx-cli
```

## Setup

Run the interactive setup:

```shell
onyx-cli configure
```

This prompts for your Onyx server URL and API key, tests the connection, and saves config to `~/.config/onyx-cli/config.json`.

Environment variables override config file values:

| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Server base URL (default: `http://localhost:3000`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |

## Usage

### Interactive chat (default)

```shell
onyx-cli
```

### One-shot question

```shell
onyx-cli ask "What is our company's PTO policy?"
onyx-cli ask --agent-id 5 "Summarize this topic"
onyx-cli ask --json "Hello"
```

| Flag | Description |
|------|-------------|
| `--agent-id <int>` | Agent ID to use (overrides default) |
| `--json` | Output raw NDJSON events instead of plain text |

### List agents

```shell
onyx-cli agents
onyx-cli agents --json
```

## Commands

| Command | Description |
|---------|-------------|
| `chat` | Launch the interactive chat TUI (default) |
| `ask` | Ask a one-shot question (non-interactive) |
| `agents` | List available agents |
| `configure` | Configure server URL and API key |

## Slash Commands (in TUI)

| Command | Description |
|---------|-------------|
| `/help` | Show help message |
| `/new` | Start a new chat session |
| `/agent` | List and switch agents |
| `/attach <path>` | Attach a file to next message |
| `/sessions` | List recent chat sessions |
| `/clear` | Clear the chat display |
| `/configure` | Re-run connection setup |
| `/connectors` | Open connectors in browser |
| `/settings` | Open settings in browser |
| `/quit` | Exit Onyx CLI |

## Keyboard Shortcuts

| Key | Action |
|-----|--------|
| `Enter` | Send message |
| `Escape` | Cancel current generation |
| `Ctrl+O` | Toggle source citations |
| `Ctrl+D` | Quit (press twice) |
| `Scroll` / `Shift+Up/Down` | Scroll chat history |
| `Page Up` / `Page Down` | Scroll half page |

## Building from Source

Requires [Go 1.24+](https://go.dev/dl/).

```shell
cd cli
go build -o onyx-cli .
```

## Development

```shell
# Run tests
go test ./...

# Build
go build -o onyx-cli .

# Lint
staticcheck ./...
```
@@ -1,63 +0,0 @@
package cmd

import (
    "encoding/json"
    "fmt"
    "text/tabwriter"

    "github.com/onyx-dot-app/onyx/cli/internal/api"
    "github.com/onyx-dot-app/onyx/cli/internal/config"
    "github.com/spf13/cobra"
)

func newAgentsCmd() *cobra.Command {
    var agentsJSON bool

    cmd := &cobra.Command{
        Use:   "agents",
        Short: "List available agents",
        RunE: func(cmd *cobra.Command, args []string) error {
            cfg := config.Load()
            if !cfg.IsConfigured() {
                return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
            }

            client := api.NewClient(cfg)
            agents, err := client.ListAgents()
            if err != nil {
                return fmt.Errorf("failed to list agents: %w", err)
            }

            if agentsJSON {
                data, err := json.MarshalIndent(agents, "", "  ")
                if err != nil {
                    return fmt.Errorf("failed to marshal agents: %w", err)
                }
                fmt.Println(string(data))
                return nil
            }

            if len(agents) == 0 {
                fmt.Println("No agents available.")
                return nil
            }

            w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 4, 2, ' ', 0)
            _, _ = fmt.Fprintln(w, "ID\tNAME\tDESCRIPTION")
            for _, a := range agents {
                desc := a.Description
                if len(desc) > 60 {
                    desc = desc[:57] + "..."
                }
                _, _ = fmt.Fprintf(w, "%d\t%s\t%s\n", a.ID, a.Name, desc)
            }
            _ = w.Flush()

            return nil
        },
    }

    cmd.Flags().BoolVar(&agentsJSON, "json", false, "Output agents as JSON")

    return cmd
}
cli/cmd/ask.go (103 lines changed)
@@ -1,103 +0,0 @@
package cmd

import (
    "context"
    "encoding/json"
    "fmt"

    "github.com/onyx-dot-app/onyx/cli/internal/api"
    "github.com/onyx-dot-app/onyx/cli/internal/config"
    "github.com/onyx-dot-app/onyx/cli/internal/models"
    "github.com/spf13/cobra"
)

func newAskCmd() *cobra.Command {
    var (
        askAgentID int
        askJSON    bool
    )

    cmd := &cobra.Command{
        Use:   "ask [question]",
        Short: "Ask a one-shot question (non-interactive)",
        Args:  cobra.ExactArgs(1),
        RunE: func(cmd *cobra.Command, args []string) error {
            cfg := config.Load()
            if !cfg.IsConfigured() {
                return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
            }

            question := args[0]
            agentID := cfg.DefaultAgentID
            if cmd.Flags().Changed("agent-id") {
                agentID = askAgentID
            }

            client := api.NewClient(cfg)
            parentID := -1
            ch := client.SendMessageStream(
                context.Background(),
                question,
                nil,
                agentID,
                &parentID,
                nil,
            )

            var lastErr error
            gotStop := false
            for event := range ch {
                if askJSON {
                    wrapped := struct {
                        Type  string             `json:"type"`
                        Event models.StreamEvent `json:"event"`
                    }{
                        Type:  event.EventType(),
                        Event: event,
                    }
                    data, err := json.Marshal(wrapped)
                    if err != nil {
                        return fmt.Errorf("error marshaling event: %w", err)
                    }
                    fmt.Println(string(data))
                    if _, ok := event.(models.ErrorEvent); ok {
                        lastErr = fmt.Errorf("%s", event.(models.ErrorEvent).Error)
                    }
                    if _, ok := event.(models.StopEvent); ok {
                        gotStop = true
                    }
                    continue
                }

                switch e := event.(type) {
                case models.MessageDeltaEvent:
                    fmt.Print(e.Content)
                case models.ErrorEvent:
                    return fmt.Errorf("%s", e.Error)
                case models.StopEvent:
                    fmt.Println()
                    return nil
                }
            }

            if lastErr != nil {
                return lastErr
            }
            if !gotStop {
                if !askJSON {
                    fmt.Println()
                }
                return fmt.Errorf("stream ended unexpectedly")
            }
            if !askJSON {
                fmt.Println()
            }
            return nil
        },
    }

    cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
    cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
    // Suppress cobra's default error/usage on RunE errors
    return cmd
}
@@ -1,33 +0,0 @@
package cmd

import (
    tea "github.com/charmbracelet/bubbletea"
    "github.com/onyx-dot-app/onyx/cli/internal/config"
    "github.com/onyx-dot-app/onyx/cli/internal/onboarding"
    "github.com/onyx-dot-app/onyx/cli/internal/tui"
    "github.com/spf13/cobra"
)

func newChatCmd() *cobra.Command {
    return &cobra.Command{
        Use:   "chat",
        Short: "Launch the interactive chat TUI (default)",
        RunE: func(cmd *cobra.Command, args []string) error {
            cfg := config.Load()

            // First-run: onboarding
            if !config.ConfigExists() || !cfg.IsConfigured() {
                result := onboarding.Run(&cfg)
                if result == nil {
                    return nil
                }
                cfg = *result
            }

            m := tui.NewModel(cfg)
            p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
            _, err := p.Run()
            return err
        },
    }
}
@@ -1,19 +0,0 @@
package cmd

import (
    "github.com/onyx-dot-app/onyx/cli/internal/config"
    "github.com/onyx-dot-app/onyx/cli/internal/onboarding"
    "github.com/spf13/cobra"
)

func newConfigureCmd() *cobra.Command {
    return &cobra.Command{
        Use:   "configure",
        Short: "Configure server URL and API key",
        RunE: func(cmd *cobra.Command, args []string) error {
            cfg := config.Load()
            onboarding.Run(&cfg)
            return nil
        },
    }
}
@@ -1,40 +0,0 @@
// Package cmd implements Cobra CLI commands for the Onyx CLI.
package cmd

import "github.com/spf13/cobra"

// Version and Commit are set via ldflags at build time.
var (
    Version string
    Commit  string
)

func fullVersion() string {
    if Commit != "" && Commit != "none" && len(Commit) > 7 {
        return Version + " (" + Commit[:7] + ")"
    }
    return Version
}

// Execute creates and runs the root command.
func Execute() error {
    rootCmd := &cobra.Command{
        Use:     "onyx-cli",
        Short:   "Terminal UI for chatting with Onyx",
        Long:    "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
        Version: fullVersion(),
    }

    // Register subcommands
    chatCmd := newChatCmd()
    rootCmd.AddCommand(chatCmd)
    rootCmd.AddCommand(newAskCmd())
    rootCmd.AddCommand(newAgentsCmd())
    rootCmd.AddCommand(newConfigureCmd())
    rootCmd.AddCommand(newValidateConfigCmd())

    // Default command is chat
    rootCmd.RunE = chatCmd.RunE

    return rootCmd.Execute()
}
@@ -1,41 +0,0 @@
package cmd

import (
	"fmt"

	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/spf13/cobra"
)

func newValidateConfigCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "validate-config",
		Short: "Validate configuration and test server connection",
		RunE: func(cmd *cobra.Command, args []string) error {
			// Check config file
			if !config.ConfigExists() {
				return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
			}

			cfg := config.Load()

			// Check API key
			if !cfg.IsConfigured() {
				return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
			}

			_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
			_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server: %s\n", cfg.ServerURL)

			// Test connection
			client := api.NewClient(cfg)
			if err := client.TestConnection(); err != nil {
				return fmt.Errorf("connection failed: %w", err)
			}

			_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
			return nil
		},
	}
}
cli/go.mod
@@ -1,45 +0,0 @@
module github.com/onyx-dot-app/onyx/cli

go 1.26.0

require (
	github.com/charmbracelet/bubbles v0.20.0
	github.com/charmbracelet/bubbletea v1.3.4
	github.com/charmbracelet/glamour v0.8.0
	github.com/charmbracelet/lipgloss v1.1.0
	github.com/spf13/cobra v1.9.1
	golang.org/x/term v0.22.0
	golang.org/x/text v0.34.0
)

require (
	github.com/alecthomas/chroma/v2 v2.14.0 // indirect
	github.com/atotto/clipboard v0.1.4 // indirect
	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
	github.com/aymerick/douceur v0.2.0 // indirect
	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
	github.com/charmbracelet/x/ansi v0.8.0 // indirect
	github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
	github.com/charmbracelet/x/term v0.2.1 // indirect
	github.com/dlclark/regexp2 v1.11.0 // indirect
	github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
	github.com/gorilla/css v1.0.1 // indirect
	github.com/inconshreveable/mousetrap v1.1.0 // indirect
	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/mattn/go-localereader v0.0.1 // indirect
	github.com/mattn/go-runewidth v0.0.16 // indirect
	github.com/microcosm-cc/bluemonday v1.0.27 // indirect
	github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
	github.com/muesli/cancelreader v0.2.2 // indirect
	github.com/muesli/reflow v0.3.0 // indirect
	github.com/muesli/termenv v0.16.0 // indirect
	github.com/rivo/uniseg v0.4.7 // indirect
	github.com/spf13/pflag v1.0.6 // indirect
	github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
	github.com/yuin/goldmark v1.7.4 // indirect
	github.com/yuin/goldmark-emoji v1.0.3 // indirect
	golang.org/x/net v0.27.0 // indirect
	golang.org/x/sync v0.19.0 // indirect
	golang.org/x/sys v0.30.0 // indirect
)
cli/go.sum
@@ -1,94 +0,0 @@
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -1,279 +0,0 @@
// Package api provides the HTTP client for communicating with the Onyx server.
package api

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/models"
)

// Client is the Onyx API client.
type Client struct {
	baseURL        string
	apiKey         string
	httpClient     *http.Client // default 30s timeout for quick requests
	longHTTPClient *http.Client // 5min timeout for streaming/uploads
}

// NewClient creates a new API client from config.
func NewClient(cfg config.OnyxCliConfig) *Client {
	transport := http.DefaultTransport.(*http.Transport).Clone()
	return &Client{
		baseURL: strings.TrimRight(cfg.ServerURL, "/"),
		apiKey:  cfg.APIKey,
		httpClient: &http.Client{
			Timeout:   30 * time.Second,
			Transport: transport,
		},
		longHTTPClient: &http.Client{
			Timeout:   5 * time.Minute,
			Transport: transport,
		},
	}
}

// UpdateConfig replaces the client's config.
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
	c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
	c.apiKey = cfg.APIKey
}

func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) {
	req, err := http.NewRequestWithContext(context.Background(), method, c.baseURL+path, body)
	if err != nil {
		return nil, err
	}
	if c.apiKey != "" {
		bearer := "Bearer " + c.apiKey
		req.Header.Set("Authorization", bearer)
		req.Header.Set("X-Onyx-Authorization", bearer)
	}
	return req, nil
}

func (c *Client) doJSON(method, path string, reqBody any, result any) error {
	var body io.Reader
	if reqBody != nil {
		data, err := json.Marshal(reqBody)
		if err != nil {
			return err
		}
		body = bytes.NewReader(data)
	}

	req, err := c.newRequest(method, path, body)
	if err != nil {
		return err
	}
	if reqBody != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		respBody, _ := io.ReadAll(resp.Body)
		return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
	}

	if result != nil {
		return json.NewDecoder(resp.Body).Decode(result)
	}
	return nil
}

// TestConnection checks if the server is reachable and credentials are valid.
// Returns nil on success, or an error with a descriptive message on failure.
func (c *Client) TestConnection() error {
	// Step 1: Basic reachability
	req, err := c.newRequest("GET", "/", nil)
	if err != nil {
		return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("cannot connect to %s — is the server running?", c.baseURL)
	}
	_ = resp.Body.Close()

	serverHeader := strings.ToLower(resp.Header.Get("Server"))

	if resp.StatusCode == 403 {
		if strings.Contains(serverHeader, "awselb") || strings.Contains(serverHeader, "amazons3") {
			return fmt.Errorf("blocked by AWS load balancer (HTTP 403 on all requests).\n Your IP address may not be in the ALB's security group or WAF allowlist")
		}
		return fmt.Errorf("HTTP 403 on base URL — the server is blocking all traffic.\n This is likely a firewall, WAF, or IP allowlist restriction")
	}

	// Step 2: Authenticated check
	req2, err := c.newRequest("GET", "/api/me", nil)
	if err != nil {
		return fmt.Errorf("server reachable but API error: %w", err)
	}

	resp2, err := c.longHTTPClient.Do(req2)
	if err != nil {
		return fmt.Errorf("server reachable but API error: %w", err)
	}
	defer func() { _ = resp2.Body.Close() }()

	if resp2.StatusCode == 200 {
		return nil
	}

	bodyBytes, _ := io.ReadAll(io.LimitReader(resp2.Body, 300))
	body := string(bodyBytes)
	isHTML := strings.HasPrefix(strings.TrimSpace(body), "<")
	respServer := strings.ToLower(resp2.Header.Get("Server"))

	if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
		if isHTML || strings.Contains(respServer, "awselb") {
			return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
		}
		if resp2.StatusCode == 401 {
			return fmt.Errorf("invalid API key or token.\n %s", body)
		}
		return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
	}

	detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
	if body != "" {
		detail += fmt.Sprintf("\n Response: %s", body)
	}
	return fmt.Errorf("%s", detail)
}

// ListAgents returns visible agents.
func (c *Client) ListAgents() ([]models.AgentSummary, error) {
	var raw []models.AgentSummary
	if err := c.doJSON("GET", "/api/persona", nil, &raw); err != nil {
		return nil, err
	}
	var result []models.AgentSummary
	for _, p := range raw {
		if p.IsVisible {
			result = append(result, p)
		}
	}
	return result, nil
}

// ListChatSessions returns recent chat sessions.
func (c *Client) ListChatSessions() ([]models.ChatSessionDetails, error) {
	var resp struct {
		Sessions []models.ChatSessionDetails `json:"sessions"`
	}
	if err := c.doJSON("GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
		return nil, err
	}
	return resp.Sessions, nil
}

// GetChatSession returns full details for a session.
func (c *Client) GetChatSession(sessionID string) (*models.ChatSessionDetailResponse, error) {
	var resp models.ChatSessionDetailResponse
	if err := c.doJSON("GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}

// RenameChatSession renames a session. If name is empty, the backend auto-generates one.
func (c *Client) RenameChatSession(sessionID string, name *string) (string, error) {
	payload := map[string]any{
		"chat_session_id": sessionID,
	}
	if name != nil {
		payload["name"] = *name
	}
	var resp struct {
		NewName string `json:"new_name"`
	}
	if err := c.doJSON("PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
		return "", err
	}
	return resp.NewName, nil
}

// UploadFile uploads a file and returns a file descriptor.
func (c *Client) UploadFile(filePath string) (*models.FileDescriptorPayload, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer func() { _ = file.Close() }()

	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)

	part, err := writer.CreateFormFile("files", filepath.Base(filePath))
	if err != nil {
		return nil, err
	}
	if _, err := io.Copy(part, file); err != nil {
		return nil, err
	}
	_ = writer.Close()

	req, err := c.newRequest("POST", "/api/user/projects/file/upload", &buf)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := c.longHTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(resp.Body)
		return nil, &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(body)}
	}

	var snapshot models.CategorizedFilesSnapshot
	if err := json.NewDecoder(resp.Body).Decode(&snapshot); err != nil {
		return nil, err
	}

	if len(snapshot.UserFiles) == 0 {
		return nil, &OnyxAPIError{StatusCode: 400, Detail: "File upload returned no files"}
	}

	uf := snapshot.UserFiles[0]
	return &models.FileDescriptorPayload{
		ID:   uf.FileID,
		Type: uf.ChatFileType,
		Name: filepath.Base(filePath),
	}, nil
}

// StopChatSession sends a stop signal for a streaming session (best-effort).
func (c *Client) StopChatSession(sessionID string) {
	req, err := c.newRequest("POST", "/api/chat/stop-chat-session/"+sessionID, nil)
	if err != nil {
		return
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return
	}
	_ = resp.Body.Close()
}
@@ -1,13 +0,0 @@
package api

import "fmt"

// OnyxAPIError is returned when an Onyx API call fails.
type OnyxAPIError struct {
	StatusCode int
	Detail     string
}

func (e *OnyxAPIError) Error() string {
	return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
}
@@ -1,136 +0,0 @@
package api

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/onyx-dot-app/onyx/cli/internal/models"
	"github.com/onyx-dot-app/onyx/cli/internal/parser"
)

// StreamEventMsg wraps a StreamEvent for Bubble Tea.
type StreamEventMsg struct {
	Event models.StreamEvent
}

// StreamDoneMsg signals the stream has ended.
type StreamDoneMsg struct {
	Err error
}

// SendMessageStream starts streaming a chat message response.
// It reads NDJSON lines, parses them, and sends events on the returned channel.
// The goroutine stops when ctx is cancelled or the stream ends.
func (c *Client) SendMessageStream(
	ctx context.Context,
	message string,
	chatSessionID *string,
	agentID int,
	parentMessageID *int,
	fileDescriptors []models.FileDescriptorPayload,
) <-chan models.StreamEvent {
	ch := make(chan models.StreamEvent, 64)

	go func() {
		defer close(ch)

		payload := models.SendMessagePayload{
			Message:          message,
			ParentMessageID:  parentMessageID,
			FileDescriptors:  fileDescriptors,
			Origin:           "api",
			IncludeCitations: true,
			Stream:           true,
		}
		if payload.FileDescriptors == nil {
			payload.FileDescriptors = []models.FileDescriptorPayload{}
		}

		if chatSessionID != nil {
			payload.ChatSessionID = chatSessionID
		} else {
			payload.ChatSessionInfo = &models.ChatSessionCreationInfo{AgentID: agentID}
		}

		body, err := json.Marshal(payload)
		if err != nil {
			ch <- models.ErrorEvent{Error: fmt.Sprintf("marshal error: %v", err), IsRetryable: false}
			return
		}

		req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", nil)
		if err != nil {
			ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
			return
		}

		req.Body = io.NopCloser(bytes.NewReader(body))
		req.ContentLength = int64(len(body))
		req.Header.Set("Content-Type", "application/json")
		if c.apiKey != "" {
			bearer := "Bearer " + c.apiKey
			req.Header.Set("Authorization", bearer)
			req.Header.Set("X-Onyx-Authorization", bearer)
		}

		resp, err := c.longHTTPClient.Do(req)
		if err != nil {
			if ctx.Err() != nil {
				return // cancelled
			}
			ch <- models.ErrorEvent{Error: fmt.Sprintf("connection error: %v", err), IsRetryable: true}
			return
		}
		defer func() { _ = resp.Body.Close() }()

		if resp.StatusCode != 200 {
			var respBody [4096]byte
			n, _ := resp.Body.Read(respBody[:])
			ch <- models.ErrorEvent{
				Error:       fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody[:n])),
				IsRetryable: resp.StatusCode >= 500,
			}
			return
		}

		scanner := bufio.NewScanner(resp.Body)
		scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024)
		for scanner.Scan() {
			if ctx.Err() != nil {
				return
			}
			event := parser.ParseStreamLine(scanner.Text())
			if event != nil {
				select {
				case ch <- event:
				case <-ctx.Done():
					return
				}
			}
		}
		if err := scanner.Err(); err != nil && ctx.Err() == nil {
			ch <- models.ErrorEvent{Error: fmt.Sprintf("stream read error: %v", err), IsRetryable: true}
		}
	}()

	return ch
}

// WaitForStreamEvent returns a tea.Cmd that reads one event from the channel.
// On channel close, it returns StreamDoneMsg.
func WaitForStreamEvent(ch <-chan models.StreamEvent) tea.Cmd {
	return func() tea.Msg {
		event, ok := <-ch
		if !ok {
			return StreamDoneMsg{}
		}
		return StreamEventMsg{Event: event}
	}
}
@@ -1,101 +0,0 @@
package config

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
)

const (
	EnvServerURL = "ONYX_SERVER_URL"
	EnvAPIKey    = "ONYX_API_KEY"
	EnvAgentID   = "ONYX_PERSONA_ID"
)

// OnyxCliConfig holds the CLI configuration.
type OnyxCliConfig struct {
	ServerURL      string `json:"server_url"`
	APIKey         string `json:"api_key"`
	DefaultAgentID int    `json:"default_persona_id"`
}

// DefaultConfig returns a config with default values.
func DefaultConfig() OnyxCliConfig {
	return OnyxCliConfig{
		ServerURL:      "https://cloud.onyx.app",
		APIKey:         "",
		DefaultAgentID: 0,
	}
}

// IsConfigured returns true if the config has an API key.
func (c OnyxCliConfig) IsConfigured() bool {
	return c.APIKey != ""
}

// configDir returns ~/.config/onyx-cli
func configDir() string {
	if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
		return filepath.Join(xdg, "onyx-cli")
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return filepath.Join(".", ".config", "onyx-cli")
	}
	return filepath.Join(home, ".config", "onyx-cli")
}

// ConfigFilePath returns the full path to the config file.
func ConfigFilePath() string {
	return filepath.Join(configDir(), "config.json")
}

// ConfigExists checks if the config file exists on disk.
func ConfigExists() bool {
	_, err := os.Stat(ConfigFilePath())
	return err == nil
}

// Load reads config from file and applies environment variable overrides.
func Load() OnyxCliConfig {
	cfg := DefaultConfig()

	data, err := os.ReadFile(ConfigFilePath())
	if err == nil {
		if jsonErr := json.Unmarshal(data, &cfg); jsonErr != nil {
			fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", ConfigFilePath(), jsonErr)
		}
	}

	// Environment overrides
	if v := os.Getenv(EnvServerURL); v != "" {
		cfg.ServerURL = v
	}
	if v := os.Getenv(EnvAPIKey); v != "" {
		cfg.APIKey = v
	}
	if v := os.Getenv(EnvAgentID); v != "" {
		if id, err := strconv.Atoi(v); err == nil {
			cfg.DefaultAgentID = id
		}
	}

	return cfg
}

// Save writes the config to disk, creating parent directories if needed.
func Save(cfg OnyxCliConfig) error {
	dir := configDir()
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return err
	}

	data, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return err
	}

	return os.WriteFile(ConfigFilePath(), data, 0o600)
}
@@ -1,215 +0,0 @@
package config

import (
	"encoding/json"
	"os"
	"path/filepath"
	"testing"
)

func clearEnvVars(t *testing.T) {
	t.Helper()
	for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAgentID} {
		t.Setenv(key, "")
		if err := os.Unsetenv(key); err != nil {
			t.Fatal(err)
		}
	}
}

func writeConfig(t *testing.T, dir string, data []byte) {
	t.Helper()
	onyxDir := filepath.Join(dir, "onyx-cli")
	if err := os.MkdirAll(onyxDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(onyxDir, "config.json"), data, 0o644); err != nil {
		t.Fatal(err)
	}
}

func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()
	if cfg.ServerURL != "https://cloud.onyx.app" {
		t.Errorf("expected default server URL, got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "" {
		t.Errorf("expected empty API key, got %s", cfg.APIKey)
	}
	if cfg.DefaultAgentID != 0 {
		t.Errorf("expected default agent ID 0, got %d", cfg.DefaultAgentID)
	}
}

func TestIsConfigured(t *testing.T) {
	cfg := DefaultConfig()
	if cfg.IsConfigured() {
		t.Error("empty config should not be configured")
	}
	cfg.APIKey = "some-key"
	if !cfg.IsConfigured() {
		t.Error("config with API key should be configured")
	}
}

func TestLoadDefaults(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	cfg := Load()
	if cfg.ServerURL != "https://cloud.onyx.app" {
		t.Errorf("expected default URL, got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "" {
		t.Errorf("expected empty key, got %s", cfg.APIKey)
	}
}

func TestLoadFromFile(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	data, _ := json.Marshal(map[string]interface{}{
		"server_url":         "https://my-onyx.example.com",
		"api_key":            "test-key-123",
		"default_persona_id": 5,
	})
	writeConfig(t, dir, data)

	cfg := Load()
	if cfg.ServerURL != "https://my-onyx.example.com" {
		t.Errorf("got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "test-key-123" {
		t.Errorf("got %s", cfg.APIKey)
	}
	if cfg.DefaultAgentID != 5 {
		t.Errorf("got %d", cfg.DefaultAgentID)
	}
}

func TestLoadCorruptFile(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	writeConfig(t, dir, []byte("not valid json {{{"))

	cfg := Load()
	if cfg.ServerURL != "https://cloud.onyx.app" {
		t.Errorf("expected default URL on corrupt file, got %s", cfg.ServerURL)
	}
}

func TestEnvOverrideServerURL(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvServerURL, "https://env-override.com")

	cfg := Load()
	if cfg.ServerURL != "https://env-override.com" {
		t.Errorf("got %s", cfg.ServerURL)
	}
}

func TestEnvOverrideAPIKey(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAPIKey, "env-key")

	cfg := Load()
	if cfg.APIKey != "env-key" {
		t.Errorf("got %s", cfg.APIKey)
	}
}

func TestEnvOverrideAgentID(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAgentID, "42")

	cfg := Load()
	if cfg.DefaultAgentID != 42 {
		t.Errorf("got %d", cfg.DefaultAgentID)
	}
}

func TestEnvOverrideInvalidAgentID(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)
	t.Setenv(EnvAgentID, "not-a-number")

	cfg := Load()
	if cfg.DefaultAgentID != 0 {
		t.Errorf("got %d", cfg.DefaultAgentID)
	}
}

func TestEnvOverridesFileValues(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	data, _ := json.Marshal(map[string]interface{}{
		"server_url": "https://file-url.com",
		"api_key":    "file-key",
	})
	writeConfig(t, dir, data)

	t.Setenv(EnvServerURL, "https://env-url.com")

	cfg := Load()
	if cfg.ServerURL != "https://env-url.com" {
		t.Errorf("env should override file, got %s", cfg.ServerURL)
	}
	if cfg.APIKey != "file-key" {
		t.Errorf("file value should be kept, got %s", cfg.APIKey)
	}
}

func TestSaveAndReload(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", dir)

	cfg := OnyxCliConfig{
		ServerURL:      "https://saved.example.com",
		APIKey:         "saved-key",
		DefaultAgentID: 10,
	}
	if err := Save(cfg); err != nil {
		t.Fatal(err)
	}

	loaded := Load()
	if loaded.ServerURL != "https://saved.example.com" {
		t.Errorf("got %s", loaded.ServerURL)
	}
	if loaded.APIKey != "saved-key" {
		t.Errorf("got %s", loaded.APIKey)
	}
	if loaded.DefaultAgentID != 10 {
		t.Errorf("got %d", loaded.DefaultAgentID)
	}
}

func TestSaveCreatesParentDirs(t *testing.T) {
	clearEnvVars(t)
	dir := t.TempDir()
	nested := filepath.Join(dir, "deep", "nested")
	t.Setenv("XDG_CONFIG_HOME", nested)

	if err := Save(OnyxCliConfig{APIKey: "test"}); err != nil {
		t.Fatal(err)
	}

	if !ConfigExists() {
		t.Error("config file should exist after save")
	}
}
@@ -1,193 +0,0 @@
package models

// StreamEvent is the interface for all parsed stream events.
type StreamEvent interface {
	EventType() string
}

// Event type constants matching the Python StreamEventType enum.
const (
	EventSessionCreated       = "session_created"
	EventMessageIDInfo        = "message_id_info"
	EventStop                 = "stop"
	EventError                = "error"
	EventMessageStart         = "message_start"
	EventMessageDelta         = "message_delta"
	EventSearchStart          = "search_tool_start"
	EventSearchQueries        = "search_tool_queries_delta"
	EventSearchDocuments      = "search_tool_documents_delta"
	EventReasoningStart       = "reasoning_start"
	EventReasoningDelta       = "reasoning_delta"
	EventReasoningDone        = "reasoning_done"
	EventCitationInfo         = "citation_info"
	EventOpenURLStart         = "open_url_start"
	EventImageGenStart        = "image_generation_start"
	EventPythonToolStart      = "python_tool_start"
	EventCustomToolStart      = "custom_tool_start"
	EventFileReaderStart      = "file_reader_start"
	EventDeepResearchPlan     = "deep_research_plan_start"
	EventDeepResearchDelta    = "deep_research_plan_delta"
	EventResearchAgentStart   = "research_agent_start"
	EventIntermediateReport   = "intermediate_report_start"
	EventIntermediateReportDt = "intermediate_report_delta"
	EventUnknown              = "unknown"
)

// SessionCreatedEvent is emitted when a new chat session is created.
type SessionCreatedEvent struct {
	ChatSessionID string
}

func (e SessionCreatedEvent) EventType() string { return EventSessionCreated }

// MessageIDEvent carries the user and agent message IDs.
type MessageIDEvent struct {
	UserMessageID          *int
	ReservedAgentMessageID int
}

func (e MessageIDEvent) EventType() string { return EventMessageIDInfo }

// StopEvent signals the end of a stream.
type StopEvent struct {
	Placement  *Placement
	StopReason *string
}

func (e StopEvent) EventType() string { return EventStop }

// ErrorEvent signals an error.
type ErrorEvent struct {
	Placement   *Placement
	Error       string
	StackTrace  *string
	IsRetryable bool
}

func (e ErrorEvent) EventType() string { return EventError }

// MessageStartEvent signals the beginning of an agent message.
type MessageStartEvent struct {
	Placement *Placement
	Documents []SearchDoc
}

func (e MessageStartEvent) EventType() string { return EventMessageStart }

// MessageDeltaEvent carries a token of agent content.
type MessageDeltaEvent struct {
	Placement *Placement
	Content   string
}

func (e MessageDeltaEvent) EventType() string { return EventMessageDelta }

// SearchStartEvent signals the beginning of a search.
type SearchStartEvent struct {
	Placement        *Placement
	IsInternetSearch bool
}

func (e SearchStartEvent) EventType() string { return EventSearchStart }

// SearchQueriesEvent carries search queries.
type SearchQueriesEvent struct {
	Placement *Placement
	Queries   []string
}

func (e SearchQueriesEvent) EventType() string { return EventSearchQueries }

// SearchDocumentsEvent carries found documents.
type SearchDocumentsEvent struct {
	Placement *Placement
	Documents []SearchDoc
}

func (e SearchDocumentsEvent) EventType() string { return EventSearchDocuments }

// ReasoningStartEvent signals the beginning of a reasoning block.
type ReasoningStartEvent struct {
	Placement *Placement
}

func (e ReasoningStartEvent) EventType() string { return EventReasoningStart }

// ReasoningDeltaEvent carries reasoning text.
type ReasoningDeltaEvent struct {
	Placement *Placement
	Reasoning string
}

func (e ReasoningDeltaEvent) EventType() string { return EventReasoningDelta }

// ReasoningDoneEvent signals the end of reasoning.
type ReasoningDoneEvent struct {
	Placement *Placement
}

func (e ReasoningDoneEvent) EventType() string { return EventReasoningDone }

// CitationEvent carries citation info.
type CitationEvent struct {
	Placement      *Placement
	CitationNumber int
	DocumentID     string
}

func (e CitationEvent) EventType() string { return EventCitationInfo }

// ToolStartEvent signals the start of a tool usage.
type ToolStartEvent struct {
	Placement *Placement
	Type      string // The specific event type (e.g. "open_url_start")
	ToolName  string
}

func (e ToolStartEvent) EventType() string { return e.Type }

// DeepResearchPlanStartEvent signals the start of a deep research plan.
type DeepResearchPlanStartEvent struct {
	Placement *Placement
}

func (e DeepResearchPlanStartEvent) EventType() string { return EventDeepResearchPlan }

// DeepResearchPlanDeltaEvent carries deep research plan content.
type DeepResearchPlanDeltaEvent struct {
	Placement *Placement
	Content   string
}

func (e DeepResearchPlanDeltaEvent) EventType() string { return EventDeepResearchDelta }

// ResearchAgentStartEvent signals a research sub-task.
type ResearchAgentStartEvent struct {
	Placement    *Placement
	ResearchTask string
}

func (e ResearchAgentStartEvent) EventType() string { return EventResearchAgentStart }

// IntermediateReportStartEvent signals the start of an intermediate report.
type IntermediateReportStartEvent struct {
	Placement *Placement
}

func (e IntermediateReportStartEvent) EventType() string { return EventIntermediateReport }

// IntermediateReportDeltaEvent carries intermediate report content.
type IntermediateReportDeltaEvent struct {
	Placement *Placement
	Content   string
}

func (e IntermediateReportDeltaEvent) EventType() string { return EventIntermediateReportDt }

// UnknownEvent is a catch-all for unrecognized stream data.
type UnknownEvent struct {
	Placement *Placement
	RawData   map[string]any
}

func (e UnknownEvent) EventType() string { return EventUnknown }
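A consumer of these events typically dispatches on the concrete type behind the `StreamEvent` interface. The sketch below is self-contained for illustration, with trimmed-down copies of two event types; the real definitions are the ones in the `models` package above.

```go
package main

import "fmt"

// Trimmed copies of the event shapes above, for illustration only.
type StreamEvent interface{ EventType() string }

type MessageDeltaEvent struct{ Content string }

func (e MessageDeltaEvent) EventType() string { return "message_delta" }

type StopEvent struct{ StopReason *string }

func (e StopEvent) EventType() string { return "stop" }

// render concatenates message tokens and marks the end of the stream.
func render(events []StreamEvent) string {
	out := ""
	for _, ev := range events {
		switch e := ev.(type) {
		case MessageDeltaEvent:
			out += e.Content
		case StopEvent:
			out += "\n[done]"
		default:
			// Unknown event types are skipped, so new server-side
			// event kinds do not break older clients.
			_ = e
		}
	}
	return out
}

func main() {
	events := []StreamEvent{
		MessageDeltaEvent{Content: "Hello, "},
		MessageDeltaEvent{Content: "world"},
		StopEvent{},
	}
	fmt.Println(render(events))
}
```

The catch-all `UnknownEvent` above serves the same purpose as the `default` arm here: forward compatibility with event types the client does not yet know.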
@@ -1,112 +0,0 @@
// Package models defines API request/response types for the Onyx CLI.
package models

import "time"

// AgentSummary represents an agent from the API.
type AgentSummary struct {
	ID               int    `json:"id"`
	Name             string `json:"name"`
	Description      string `json:"description"`
	IsDefaultPersona bool   `json:"is_default_persona"`
	IsVisible        bool   `json:"is_visible"`
}

// ChatSessionSummary is a brief session listing.
type ChatSessionSummary struct {
	ID      string    `json:"id"`
	Name    *string   `json:"name"`
	AgentID *int      `json:"persona_id"`
	Created time.Time `json:"time_created"`
}

// ChatSessionDetails is a session with timestamps as strings.
type ChatSessionDetails struct {
	ID      string  `json:"id"`
	Name    *string `json:"name"`
	AgentID *int    `json:"persona_id"`
	Created string  `json:"time_created"`
	Updated string  `json:"time_updated"`
}

// ChatMessageDetail is a single message in a session.
type ChatMessageDetail struct {
	MessageID          int     `json:"message_id"`
	ParentMessage      *int    `json:"parent_message"`
	LatestChildMessage *int    `json:"latest_child_message"`
	Message            string  `json:"message"`
	MessageType        string  `json:"message_type"`
	TimeSent           string  `json:"time_sent"`
	Error              *string `json:"error"`
}

// ChatSessionDetailResponse is the full session detail from the API.
type ChatSessionDetailResponse struct {
	ChatSessionID string              `json:"chat_session_id"`
	Description   *string             `json:"description"`
	AgentID       *int                `json:"persona_id"`
	AgentName     *string             `json:"persona_name"`
	Messages      []ChatMessageDetail `json:"messages"`
}

// ChatFileType represents a file type for uploads.
type ChatFileType string

const (
	ChatFileImage     ChatFileType = "image"
	ChatFileDoc       ChatFileType = "document"
	ChatFilePlainText ChatFileType = "plain_text"
	ChatFileCSV       ChatFileType = "csv"
)

// FileDescriptorPayload is a file descriptor for send-message requests.
type FileDescriptorPayload struct {
	ID   string       `json:"id"`
	Type ChatFileType `json:"type"`
	Name string       `json:"name,omitempty"`
}

// UserFileSnapshot represents an uploaded file.
type UserFileSnapshot struct {
	ID           string       `json:"id"`
	Name         string       `json:"name"`
	FileID       string       `json:"file_id"`
	ChatFileType ChatFileType `json:"chat_file_type"`
}

// CategorizedFilesSnapshot is the response from file upload.
type CategorizedFilesSnapshot struct {
	UserFiles []UserFileSnapshot `json:"user_files"`
}

// ChatSessionCreationInfo is included when creating a new session inline.
type ChatSessionCreationInfo struct {
	AgentID int `json:"persona_id"`
}

// SendMessagePayload is the request body for POST /api/chat/send-chat-message.
type SendMessagePayload struct {
	Message          string                   `json:"message"`
	ChatSessionID    *string                  `json:"chat_session_id,omitempty"`
	ChatSessionInfo  *ChatSessionCreationInfo `json:"chat_session_info,omitempty"`
	ParentMessageID  *int                     `json:"parent_message_id"`
	FileDescriptors  []FileDescriptorPayload  `json:"file_descriptors"`
	Origin           string                   `json:"origin"`
	IncludeCitations bool                     `json:"include_citations"`
	Stream           bool                     `json:"stream"`
}

// SearchDoc represents a document found during search.
type SearchDoc struct {
	DocumentID         string  `json:"document_id"`
	SemanticIdentifier string  `json:"semantic_identifier"`
	Link               *string `json:"link"`
	SourceType         string  `json:"source_type"`
}

// Placement indicates where a stream event belongs in the conversation.
type Placement struct {
	TurnIndex    int  `json:"turn_index"`
	TabIndex     int  `json:"tab_index"`
	SubTurnIndex *int `json:"sub_turn_index"`
}
@@ -1,169 +0,0 @@
// Package onboarding handles the first-run setup flow for Onyx CLI.
package onboarding

import (
	"bufio"
	"fmt"
	"os"
	"strings"

	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/tui"
	"github.com/onyx-dot-app/onyx/cli/internal/util"
	"golang.org/x/term"
)

// Aliases for shared styles.
var (
	boldStyle   = util.BoldStyle
	dimStyle    = util.DimStyle
	greenStyle  = util.GreenStyle
	redStyle    = util.RedStyle
	yellowStyle = util.YellowStyle
)

func getTermSize() (int, int) {
	w, h, err := term.GetSize(int(os.Stdout.Fd()))
	if err != nil {
		return 80, 24
	}
	return w, h
}

// Run executes the interactive onboarding flow.
// Returns the validated config, or nil if the user cancels.
func Run(existing *config.OnyxCliConfig) *config.OnyxCliConfig {
	cfg := config.DefaultConfig()
	if existing != nil {
		cfg = *existing
	}

	w, h := getTermSize()
	fmt.Print(tui.RenderSplashOnboarding(w, h))

	fmt.Println()
	fmt.Println(" Welcome to " + boldStyle.Render("Onyx CLI") + ".")
	fmt.Println()

	reader := bufio.NewReader(os.Stdin)

	// Server URL
	serverURL := prompt(reader, " Onyx server URL", cfg.ServerURL)
	if serverURL == "" {
		return nil
	}
	if !strings.HasPrefix(serverURL, "http://") && !strings.HasPrefix(serverURL, "https://") {
		fmt.Println(" " + redStyle.Render("Server URL must start with http:// or https://"))
		return nil
	}

	// API Key
	fmt.Println()
	fmt.Println(" " + dimStyle.Render("Need an API key? Press Enter to open the admin panel in your browser,"))
	fmt.Println(" " + dimStyle.Render("or paste your key below."))
	fmt.Println()

	apiKey := promptSecret(" API key", cfg.APIKey)

	if apiKey == "" {
		// Open browser to API key page
		url := strings.TrimRight(serverURL, "/") + "/app/settings/accounts-access"
		fmt.Printf("\n Opening %s ...\n", url)
		util.OpenBrowser(url)
		fmt.Println(" " + dimStyle.Render("Copy your API key, then paste it here."))
		fmt.Println()

		apiKey = promptSecret(" API key", "")
		if apiKey == "" {
			fmt.Println("\n " + redStyle.Render("No API key provided. Exiting."))
			return nil
		}
	}

	// Test connection
	cfg = config.OnyxCliConfig{
		ServerURL:      serverURL,
		APIKey:         apiKey,
		DefaultAgentID: cfg.DefaultAgentID,
	}

	fmt.Println("\n " + yellowStyle.Render("Testing connection..."))

	client := api.NewClient(cfg)
	if err := client.TestConnection(); err != nil {
		fmt.Println(" " + redStyle.Render("Connection failed.") + " " + err.Error())
		fmt.Println()
		fmt.Println(" " + dimStyle.Render("Run ") + boldStyle.Render("onyx-cli configure") + dimStyle.Render(" to try again."))
		return nil
	}

	if err := config.Save(cfg); err != nil {
		fmt.Println(" " + redStyle.Render("Could not save config: "+err.Error()))
		return nil
	}
	fmt.Println(" " + greenStyle.Render("Connected and authenticated."))
	fmt.Println()
	printQuickStart()
	return &cfg
}

func promptSecret(label, defaultVal string) string {
	if defaultVal != "" {
		fmt.Printf("%s %s: ", label, dimStyle.Render("[hidden]"))
	} else {
		fmt.Printf("%s: ", label)
	}

	password, err := term.ReadPassword(int(os.Stdin.Fd()))
	fmt.Println() // ReadPassword doesn't echo a newline
	if err != nil {
		return defaultVal
	}
	line := strings.TrimSpace(string(password))
	if line == "" {
		return defaultVal
	}
	return line
}

func prompt(reader *bufio.Reader, label, defaultVal string) string {
	if defaultVal != "" {
		fmt.Printf("%s %s: ", label, dimStyle.Render("["+defaultVal+"]"))
	} else {
		fmt.Printf("%s: ", label)
	}

	line, err := reader.ReadString('\n')
	// ReadString may return partial data along with an error (e.g. EOF without newline)
	line = strings.TrimSpace(line)
	if line != "" {
		return line
	}
	if err != nil {
		return defaultVal
	}
	return defaultVal
}

func printQuickStart() {
	fmt.Println(" " + boldStyle.Render("Quick start"))
	fmt.Println()
	fmt.Println(" Just type to chat with your Onyx agent.")
	fmt.Println()

	rows := [][2]string{
		{"/help", "Show all commands"},
		{"/attach", "Attach a file"},
		{"/agent", "Switch agent"},
		{"/new", "New conversation"},
		{"/sessions", "Browse previous chats"},
		{"Esc", "Cancel generation"},
		{"Ctrl+D", "Quit"},
	}
	for _, r := range rows {
		fmt.Printf(" %-12s %s\n", boldStyle.Render(r[0]), dimStyle.Render(r[1]))
	}
	fmt.Println()
}
@@ -1,248 +0,0 @@
// Package parser handles NDJSON stream parsing for Onyx chat responses.
package parser

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/onyx-dot-app/onyx/cli/internal/models"
	"golang.org/x/text/cases"
	"golang.org/x/text/language"
)

// ParseStreamLine parses a single NDJSON line into a typed StreamEvent.
// Returns nil for empty lines; malformed JSON yields an ErrorEvent.
func ParseStreamLine(line string) models.StreamEvent {
	line = strings.TrimSpace(line)
	if line == "" {
		return nil
	}

	var data map[string]any
	if err := json.Unmarshal([]byte(line), &data); err != nil {
		return models.ErrorEvent{Error: fmt.Sprintf("malformed stream data: %v", err), IsRetryable: false}
	}

	// Case 1: CreateChatSessionID
	if _, ok := data["chat_session_id"]; ok {
		if _, hasPlacement := data["placement"]; !hasPlacement {
			sid, _ := data["chat_session_id"].(string)
			return models.SessionCreatedEvent{ChatSessionID: sid}
		}
	}

	// Case 2: MessageResponseIDInfo
	if _, ok := data["reserved_assistant_message_id"]; ok {
		reservedID := jsonInt(data["reserved_assistant_message_id"])
		var userMsgID *int
		if v, ok := data["user_message_id"]; ok && v != nil {
			id := jsonInt(v)
			userMsgID = &id
		}
		return models.MessageIDEvent{
			UserMessageID:          userMsgID,
			ReservedAgentMessageID: reservedID,
		}
	}

	// Case 3: StreamingError (top-level error without placement)
	if _, ok := data["error"]; ok {
		if _, hasPlacement := data["placement"]; !hasPlacement {
			errStr, _ := data["error"].(string)
			var stackTrace *string
			if st, ok := data["stack_trace"].(string); ok {
				stackTrace = &st
			}
			isRetryable := true
			if v, ok := data["is_retryable"].(bool); ok {
				isRetryable = v
			}
			return models.ErrorEvent{
				Error:       errStr,
				StackTrace:  stackTrace,
				IsRetryable: isRetryable,
			}
		}
	}

	// Case 4: Packet with placement + obj
	if rawPlacement, ok := data["placement"]; ok {
		if rawObj, ok := data["obj"]; ok {
			placement := parsePlacement(rawPlacement)
			obj, _ := rawObj.(map[string]any)
			if obj == nil {
				return models.UnknownEvent{Placement: placement, RawData: data}
			}
			return parsePacketObj(obj, placement)
		}
	}

	// Fallback
	return models.UnknownEvent{RawData: data}
}

func parsePlacement(raw interface{}) *models.Placement {
	m, ok := raw.(map[string]any)
	if !ok {
		return nil
	}
	p := &models.Placement{
		TurnIndex: jsonInt(m["turn_index"]),
		TabIndex:  jsonInt(m["tab_index"]),
	}
	if v, ok := m["sub_turn_index"]; ok && v != nil {
		st := jsonInt(v)
		p.SubTurnIndex = &st
	}
	return p
}

func parsePacketObj(obj map[string]any, placement *models.Placement) models.StreamEvent {
	objType, _ := obj["type"].(string)

	switch objType {
	case "stop":
		var reason *string
		if r, ok := obj["stop_reason"].(string); ok {
			reason = &r
		}
		return models.StopEvent{Placement: placement, StopReason: reason}

	case "error":
		errMsg := "Unknown error"
		if e, ok := obj["exception"]; ok {
			errMsg = toString(e)
		}
		return models.ErrorEvent{Placement: placement, Error: errMsg, IsRetryable: true}

	case "message_start":
		var docs []models.SearchDoc
		if rawDocs, ok := obj["final_documents"].([]any); ok {
			docs = parseSearchDocs(rawDocs)
		}
		return models.MessageStartEvent{Placement: placement, Documents: docs}

	case "message_delta":
		content, _ := obj["content"].(string)
		return models.MessageDeltaEvent{Placement: placement, Content: content}

	case "search_tool_start":
		isInternet, _ := obj["is_internet_search"].(bool)
		return models.SearchStartEvent{Placement: placement, IsInternetSearch: isInternet}

	case "search_tool_queries_delta":
		var queries []string
		if raw, ok := obj["queries"].([]any); ok {
			for _, q := range raw {
				if s, ok := q.(string); ok {
					queries = append(queries, s)
				}
			}
		}
		return models.SearchQueriesEvent{Placement: placement, Queries: queries}

	case "search_tool_documents_delta":
		var docs []models.SearchDoc
		if rawDocs, ok := obj["documents"].([]any); ok {
			docs = parseSearchDocs(rawDocs)
		}
		return models.SearchDocumentsEvent{Placement: placement, Documents: docs}

	case "reasoning_start":
		return models.ReasoningStartEvent{Placement: placement}

	case "reasoning_delta":
		reasoning, _ := obj["reasoning"].(string)
		return models.ReasoningDeltaEvent{Placement: placement, Reasoning: reasoning}

	case "reasoning_done":
		return models.ReasoningDoneEvent{Placement: placement}

	case "citation_info":
		return models.CitationEvent{
			Placement:      placement,
			CitationNumber: jsonInt(obj["citation_number"]),
			DocumentID:     jsonString(obj["document_id"]),
		}

	case "open_url_start", "image_generation_start", "python_tool_start", "file_reader_start":
		toolName := strings.ReplaceAll(strings.TrimSuffix(objType, "_start"), "_", " ")
		toolName = cases.Title(language.English).String(toolName)
		return models.ToolStartEvent{Placement: placement, Type: objType, ToolName: toolName}

	case "custom_tool_start":
		toolName := jsonString(obj["tool_name"])
		if toolName == "" {
			toolName = "Custom Tool"
		}
		return models.ToolStartEvent{Placement: placement, Type: models.EventCustomToolStart, ToolName: toolName}

	case "deep_research_plan_start":
		return models.DeepResearchPlanStartEvent{Placement: placement}

	case "deep_research_plan_delta":
		content, _ := obj["content"].(string)
		return models.DeepResearchPlanDeltaEvent{Placement: placement, Content: content}

	case "research_agent_start":
		task, _ := obj["research_task"].(string)
		return models.ResearchAgentStartEvent{Placement: placement, ResearchTask: task}

	case "intermediate_report_start":
		return models.IntermediateReportStartEvent{Placement: placement}

	case "intermediate_report_delta":
		content, _ := obj["content"].(string)
		return models.IntermediateReportDeltaEvent{Placement: placement, Content: content}

	default:
		return models.UnknownEvent{Placement: placement, RawData: obj}
	}
}

func parseSearchDocs(raw []any) []models.SearchDoc {
	var docs []models.SearchDoc
	for _, item := range raw {
		m, ok := item.(map[string]any)
		if !ok {
			continue
		}
		doc := models.SearchDoc{
			DocumentID:         jsonString(m["document_id"]),
			SemanticIdentifier: jsonString(m["semantic_identifier"]),
			SourceType:         jsonString(m["source_type"]),
		}
		if link, ok := m["link"].(string); ok {
			doc.Link = &link
		}
		docs = append(docs, doc)
	}
	return docs
}

func jsonInt(v any) int {
	switch n := v.(type) {
	case float64:
		return int(n)
	case int:
		return n
	default:
		return 0
	}
}

func jsonString(v any) string {
	s, _ := v.(string)
	return s
}

func toString(v any) string {
	switch s := v.(type) {
	case string:
		return s
	default:
		b, _ := json.Marshal(v)
		return string(b)
	}
}
@@ -1,419 +0,0 @@
package parser

import (
	"encoding/json"
	"testing"

	"github.com/onyx-dot-app/onyx/cli/internal/models"
)

func TestEmptyLineReturnsNil(t *testing.T) {
	for _, line := range []string{"", " ", "\n"} {
		if ParseStreamLine(line) != nil {
			t.Errorf("expected nil for %q", line)
		}
	}
}

func TestInvalidJSONReturnsErrorEvent(t *testing.T) {
	for _, line := range []string{"not json", "{broken"} {
		event := ParseStreamLine(line)
		if event == nil {
			t.Errorf("expected ErrorEvent for %q, got nil", line)
			continue
		}
		if _, ok := event.(models.ErrorEvent); !ok {
			t.Errorf("expected ErrorEvent for %q, got %T", line, event)
		}
	}
}

func TestSessionCreated(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"chat_session_id": "550e8400-e29b-41d4-a716-446655440000",
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.SessionCreatedEvent)
	if !ok {
		t.Fatalf("expected SessionCreatedEvent, got %T", event)
	}
	if e.ChatSessionID != "550e8400-e29b-41d4-a716-446655440000" {
		t.Errorf("got %s", e.ChatSessionID)
	}
}

func TestMessageIDInfo(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"user_message_id":               1,
		"reserved_assistant_message_id": 2,
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageIDEvent)
	if !ok {
		t.Fatalf("expected MessageIDEvent, got %T", event)
	}
	if e.UserMessageID == nil || *e.UserMessageID != 1 {
		t.Errorf("expected user_message_id=1")
	}
	if e.ReservedAgentMessageID != 2 {
		t.Errorf("got %d", e.ReservedAgentMessageID)
	}
}

func TestMessageIDInfoNullUserID(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"user_message_id":               nil,
		"reserved_assistant_message_id": 5,
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.MessageIDEvent)
	if !ok {
		t.Fatalf("expected MessageIDEvent, got %T", event)
	}
	if e.UserMessageID != nil {
		t.Error("expected nil user_message_id")
	}
	if e.ReservedAgentMessageID != 5 {
		t.Errorf("got %d", e.ReservedAgentMessageID)
	}
}

func TestTopLevelError(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"error":        "Rate limit exceeded",
		"stack_trace":  "...",
		"is_retryable": true,
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.ErrorEvent)
	if !ok {
		t.Fatalf("expected ErrorEvent, got %T", event)
	}
	if e.Error != "Rate limit exceeded" {
		t.Errorf("got %s", e.Error)
	}
	if e.StackTrace == nil || *e.StackTrace != "..." {
		t.Error("expected stack_trace")
	}
	if !e.IsRetryable {
		t.Error("expected retryable")
	}
}

func TestTopLevelErrorMinimal(t *testing.T) {
	line := mustJSON(map[string]interface{}{
		"error": "Something broke",
	})
	event := ParseStreamLine(line)
	e, ok := event.(models.ErrorEvent)
	if !ok {
		t.Fatalf("expected ErrorEvent, got %T", event)
	}
	if e.Error != "Something broke" {
		t.Errorf("got %s", e.Error)
	}
	if !e.IsRetryable {
		t.Error("expected default retryable=true")
	}
}

func makePacket(obj map[string]interface{}, turnIndex, tabIndex int) string {
	return mustJSON(map[string]interface{}{
		"placement": map[string]interface{}{"turn_index": turnIndex, "tab_index": tabIndex},
		"obj":       obj,
	})
}

func TestStopPacket(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "stop", "stop_reason": "completed"}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.StopEvent)
	if !ok {
		t.Fatalf("expected StopEvent, got %T", event)
	}
	if e.StopReason == nil || *e.StopReason != "completed" {
		t.Error("expected stop_reason=completed")
	}
	if e.Placement == nil || e.Placement.TurnIndex != 0 {
		t.Error("expected placement")
	}
}

func TestStopPacketNoReason(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "stop"}, 0, 0)
	event := ParseStreamLine(line)
	e, ok := event.(models.StopEvent)
	if !ok {
		t.Fatalf("expected StopEvent, got %T", event)
	}
	if e.StopReason != nil {
		t.Error("expected nil stop_reason")
	}
}

func TestMessageStart(t *testing.T) {
	line := makePacket(map[string]interface{}{"type": "message_start"}, 0, 0)
	event := ParseStreamLine(line)
	_, ok := event.(models.MessageStartEvent)
	if !ok {
		t.Fatalf("expected MessageStartEvent, got %T", event)
	}
}

func TestMessageStartWithDocuments(t *testing.T) {
	line := makePacket(map[string]interface{}{
		"type": "message_start",
		"final_documents": []interface{}{
||||
map[string]interface{}{"document_id": "doc1", "semantic_identifier": "Doc 1"},
|
||||
},
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageStartEvent, got %T", event)
|
||||
}
|
||||
if len(e.Documents) != 1 || e.Documents[0].DocumentID != "doc1" {
|
||||
t.Error("expected 1 document with id doc1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "message_delta", "content": "Hello"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "Hello" {
|
||||
t.Errorf("got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeltaEmpty(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "message_delta", "content": ""}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "" {
|
||||
t.Errorf("expected empty, got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchToolStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "search_tool_start", "is_internet_search": true,
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SearchStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SearchStartEvent, got %T", event)
|
||||
}
|
||||
if !e.IsInternetSearch {
|
||||
t.Error("expected internet search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchToolQueries(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "search_tool_queries_delta",
|
||||
"queries": []interface{}{"query 1", "query 2"},
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SearchQueriesEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SearchQueriesEvent, got %T", event)
|
||||
}
|
||||
if len(e.Queries) != 2 || e.Queries[0] != "query 1" {
|
||||
t.Error("unexpected queries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchToolDocuments(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "search_tool_documents_delta",
|
||||
"documents": []interface{}{
|
||||
map[string]interface{}{"document_id": "d1", "semantic_identifier": "First Doc", "link": "http://example.com"},
|
||||
map[string]interface{}{"document_id": "d2", "semantic_identifier": "Second Doc"},
|
||||
},
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SearchDocumentsEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SearchDocumentsEvent, got %T", event)
|
||||
}
|
||||
if len(e.Documents) != 2 {
|
||||
t.Errorf("expected 2 docs, got %d", len(e.Documents))
|
||||
}
|
||||
if e.Documents[0].Link == nil || *e.Documents[0].Link != "http://example.com" {
|
||||
t.Error("expected link on first doc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "reasoning_start"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.ReasoningStartEvent); !ok {
|
||||
t.Fatalf("expected ReasoningStartEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "reasoning_delta", "reasoning": "Let me think...",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ReasoningDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ReasoningDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Reasoning != "Let me think..." {
|
||||
t.Errorf("got %s", e.Reasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningDone(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "reasoning_done"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.ReasoningDoneEvent); !ok {
|
||||
t.Fatalf("expected ReasoningDoneEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationInfo(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "citation_info", "citation_number": 1, "document_id": "doc_abc",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.CitationEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected CitationEvent, got %T", event)
|
||||
}
|
||||
if e.CitationNumber != 1 || e.DocumentID != "doc_abc" {
|
||||
t.Errorf("got %d, %s", e.CitationNumber, e.DocumentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenURLStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "open_url_start"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ToolStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolStartEvent, got %T", event)
|
||||
}
|
||||
if e.Type != "open_url_start" {
|
||||
t.Errorf("got type %s", e.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPythonToolStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "python_tool_start", "code": "print('hi')",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ToolStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolStartEvent, got %T", event)
|
||||
}
|
||||
if e.ToolName != "Python Tool" {
|
||||
t.Errorf("got %s", e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomToolStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "custom_tool_start", "tool_name": "MyTool",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ToolStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolStartEvent, got %T", event)
|
||||
}
|
||||
if e.ToolName != "MyTool" {
|
||||
t.Errorf("got %s", e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepResearchPlanDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "deep_research_plan_delta", "content": "Step 1: ...",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.DeepResearchPlanDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected DeepResearchPlanDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "Step 1: ..." {
|
||||
t.Errorf("got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResearchAgentStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "research_agent_start", "research_task": "Find info about X",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ResearchAgentStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ResearchAgentStartEvent, got %T", event)
|
||||
}
|
||||
if e.ResearchTask != "Find info about X" {
|
||||
t.Errorf("got %s", e.ResearchTask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntermediateReportDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "intermediate_report_delta", "content": "Report text",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.IntermediateReportDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected IntermediateReportDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "Report text" {
|
||||
t.Errorf("got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownPacketType(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "section_end"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.UnknownEvent); !ok {
|
||||
t.Fatalf("expected UnknownEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownTopLevel(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{"some_unknown_field": "value"})
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.UnknownEvent); !ok {
|
||||
t.Fatalf("expected UnknownEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlacementPreserved(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "message_delta", "content": "x",
|
||||
}, 3, 1)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Placement == nil {
|
||||
t.Fatal("expected placement")
|
||||
}
|
||||
if e.Placement.TurnIndex != 3 || e.Placement.TabIndex != 1 {
|
||||
t.Errorf("got turn=%d tab=%d", e.Placement.TurnIndex, e.Placement.TabIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func mustJSON(v interface{}) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
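The tests above exercise a two-layer NDJSON stream format: top-level events (session IDs, message IDs, errors) and placement-wrapped packets carrying a `"type"` discriminator inside an `"obj"` payload. A minimal, self-contained sketch of a classifier for that wire shape (illustrative only; the names here are not the CLI's actual `ParseStreamLine` implementation):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// packet mirrors the wrapped wire shape the tests build with makePacket:
// a "placement" header plus an "obj" payload with a "type" discriminator.
type packet struct {
	Placement *struct {
		TurnIndex int `json:"turn_index"`
		TabIndex  int `json:"tab_index"`
	} `json:"placement"`
	Obj map[string]interface{} `json:"obj"`
}

// classifyLine returns a coarse event kind for one NDJSON stream line:
// "error" for top-level errors, the packet's "type" for wrapped packets,
// "unknown" for anything else, and "invalid" for non-JSON input.
func classifyLine(line string) string {
	var top map[string]interface{}
	if err := json.Unmarshal([]byte(line), &top); err != nil {
		return "invalid"
	}
	if _, ok := top["error"]; ok {
		return "error"
	}
	var p packet
	if err := json.Unmarshal([]byte(line), &p); err == nil && p.Obj != nil {
		if t, ok := p.Obj["type"].(string); ok {
			return t
		}
	}
	return "unknown"
}

func main() {
	fmt.Println(classifyLine(`{"error":"Rate limit exceeded"}`))
	fmt.Println(classifyLine(`{"placement":{"turn_index":0,"tab_index":0},"obj":{"type":"message_delta","content":"Hello"}}`))
	fmt.Println(classifyLine(`{"some_unknown_field":"value"}`))
}
```

Unmarshalling the same line twice (once as a generic map, once as a typed packet) keeps the top-level/packet distinction explicit, which is the same split the test suite checks.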
@@ -1,627 +0,0 @@
// Package tui implements the Bubble Tea TUI for Onyx CLI.
package tui

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/lipgloss"
	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/models"
)

// Model is the root Bubble Tea model.
type Model struct {
	config config.OnyxCliConfig
	client *api.Client

	viewport *viewport
	input    inputModel
	status   statusBar

	width  int
	height int

	// Chat state
	chatSessionID   *string
	agentID         int
	agentName       string
	agents          []models.AgentSummary
	parentMessageID *int
	isStreaming     bool
	streamCancel    context.CancelFunc
	streamCh        <-chan models.StreamEvent
	citations       map[int]string
	attachedFiles   []models.FileDescriptorPayload
	needsRename     bool
	agentStarted    bool

	// Quit state
	quitPending    bool
	splashShown    bool
	initInputReady bool // true once terminal init responses have passed
}

// NewModel creates a new TUI model.
func NewModel(cfg config.OnyxCliConfig) Model {
	client := api.NewClient(cfg)
	parentID := -1

	return Model{
		config:          cfg,
		client:          client,
		viewport:        newViewport(80),
		input:           newInputModel(),
		status:          newStatusBar(),
		agentID:         cfg.DefaultAgentID,
		agentName:       "Default",
		parentMessageID: &parentID,
		citations:       make(map[int]string),
	}
}

// Init initializes the model.
func (m Model) Init() tea.Cmd {
	return loadAgentsCmd(m.client)
}

// Update handles messages.
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	// Filter out terminal query responses (OSC 11 background color, cursor
	// position reports, etc.) that arrive as key events with raw escape content.
	// These arrive split across multiple key events, so we use a brief window
	// after startup to swallow them all.
	if keyMsg, ok := msg.(tea.KeyMsg); ok && !m.initInputReady {
		// During init, drop ALL key events — they're terminal query responses
		_ = keyMsg
		return m, nil
	}

	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		m.width = msg.Width
		m.height = msg.Height
		m.viewport.setWidth(msg.Width)
		m.status.setWidth(msg.Width)
		m.input.textInput.Width = msg.Width - 4
		if !m.splashShown {
			m.splashShown = true
			// bottomHeight = sep + input + sep + status = 4 (approx)
			viewportHeight := msg.Height - 4
			if viewportHeight < 1 {
				viewportHeight = msg.Height
			}
			m.viewport.addSplash(viewportHeight)
			// Delay input focus to let terminal query responses flush
			return m, tea.Tick(100*time.Millisecond, func(time.Time) tea.Msg {
				return inputReadyMsg{}
			})
		}
		return m, nil

	case tea.MouseMsg:
		switch msg.Button {
		case tea.MouseButtonWheelUp:
			m.viewport.scrollUp(3)
			return m, nil
		case tea.MouseButtonWheelDown:
			m.viewport.scrollDown(3)
			return m, nil
		}

	case tea.KeyMsg:
		return m.handleKey(msg)

	case submitMsg:
		return m.handleSubmit(msg.text)

	case fileDropMsg:
		return m.handleFileDrop(msg.path)

	case InitDoneMsg:
		return m.handleInitDone(msg)

	case api.StreamEventMsg:
		return m.handleStreamEvent(msg)

	case api.StreamDoneMsg:
		return m.handleStreamDone(msg)

	case AgentsLoadedMsg:
		return m.handleAgentsLoaded(msg)

	case SessionsLoadedMsg:
		return m.handleSessionsLoaded(msg)

	case SessionResumedMsg:
		return m.handleSessionResumed(msg)

	case FileUploadedMsg:
		return m.handleFileUploaded(msg)

	case inputReadyMsg:
		m.initInputReady = true
		m.input.textInput.Focus()
		m.input.textInput.SetValue("")
		return m, m.input.textInput.Cursor.BlinkCmd()

	case resetQuitMsg:
		m.quitPending = false
		return m, nil
	}

	// Only forward messages to the text input after it's been focused
	if m.splashShown {
		var cmd tea.Cmd
		m.input, cmd = m.input.update(msg)
		return m, cmd
	}
	return m, nil
}

// viewportHeight returns the number of visible chat rows, accounting for the
// dynamic bottom area (separator, menu, file badges, input, status bar).
func (m Model) viewportHeight() int {
	menuHeight := 0
	if m.input.menuVisible {
		menuHeight = len(m.input.menuItems)
	}
	fileHeight := 0
	if len(m.input.attachedFiles) > 0 {
		fileHeight = 1
	}
	h := m.height - (1 + menuHeight + fileHeight + 1 + 1 + 1)
	if h < 1 {
		return 1
	}
	return h
}

// View renders the UI.
func (m Model) View() string {
	if m.width == 0 || m.height == 0 {
		return ""
	}

	separator := lipgloss.NewStyle().Foreground(separatorColor).Render(
		strings.Repeat("─", m.width),
	)

	menuView := m.input.viewMenu(m.width)
	viewportHeight := m.viewportHeight()

	var parts []string
	parts = append(parts, m.viewport.view(viewportHeight))
	parts = append(parts, separator)
	if menuView != "" {
		parts = append(parts, menuView)
	}
	parts = append(parts, m.input.viewInput())
	parts = append(parts, separator)
	parts = append(parts, m.status.view())

	return strings.Join(parts, "\n")
}

// handleKey processes keyboard input.
func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
	switch msg.Type {
	case tea.KeyEscape:
		// Close menu, cancel streaming, or dismiss picker
		if m.input.menuVisible {
			m.input.menuVisible = false
			return m, nil
		}
		if m.isStreaming {
			return m.cancelStream()
		}
		if m.viewport.pickerActive {
			m.viewport.pickerActive = false
			return m, nil
		}
		return m, nil

	case tea.KeyCtrlD:
		// If streaming, cancel first; require a fresh Ctrl+D pair to quit
		if m.isStreaming {
			return m.cancelStream()
		}
		if m.quitPending {
			return m, tea.Quit
		}
		m.quitPending = true
		m.viewport.addInfo("Press Ctrl+D again to quit.")
		return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
			return resetQuitMsg{}
		})

	case tea.KeyCtrlO:
		m.viewport.showSources = !m.viewport.showSources
		return m, nil

	case tea.KeyEnter:
		// If picker is active, handle selection
		if m.viewport.pickerActive && len(m.viewport.pickerItems) > 0 {
			item := m.viewport.pickerItems[m.viewport.pickerIndex]
			m.viewport.pickerActive = false
			switch m.viewport.pickerType {
			case pickerSession:
				return cmdResume(m, item.id)
			case pickerAgent:
				return cmdSelectAgent(m, item.id)
			}
		}

	case tea.KeyUp:
		if m.viewport.pickerActive {
			if m.viewport.pickerIndex > 0 {
				m.viewport.pickerIndex--
			}
			return m, nil
		}

	case tea.KeyDown:
		if m.viewport.pickerActive {
			if m.viewport.pickerIndex < len(m.viewport.pickerItems)-1 {
				m.viewport.pickerIndex++
			}
			return m, nil
		}

	case tea.KeyPgUp:
		m.viewport.scrollUp(m.viewportHeight() / 2)
		return m, nil

	case tea.KeyPgDown:
		m.viewport.scrollDown(m.viewportHeight() / 2)
		return m, nil

	case tea.KeyShiftUp:
		m.viewport.scrollUp(3)
		return m, nil

	case tea.KeyShiftDown:
		m.viewport.scrollDown(3)
		return m, nil
	}

	// Pass to input
	var cmd tea.Cmd
	m.input, cmd = m.input.update(msg)
	return m, cmd
}

func (m Model) handleSubmit(text string) (tea.Model, tea.Cmd) {
	if strings.HasPrefix(text, "/") {
		return handleSlashCommand(m, text)
	}
	return m.sendMessage(text)
}

func (m Model) handleFileDrop(path string) (tea.Model, tea.Cmd) {
	return cmdAttach(m, path)
}

func (m Model) cancelStream() (Model, tea.Cmd) {
	if m.streamCancel != nil {
		m.streamCancel()
	}
	if m.chatSessionID != nil {
		sid := *m.chatSessionID
		go m.client.StopChatSession(sid)
	}
	m, cmd := m.finishStream(nil)
	m.viewport.addInfo("Generation stopped.")
	return m, cmd
}

func (m Model) sendMessage(message string) (Model, tea.Cmd) {
	if m.isStreaming {
		return m, nil
	}

	m.viewport.addUserMessage(message)
	m.viewport.startAgent()

	// Prepare file descriptors
	fileDescs := make([]models.FileDescriptorPayload, len(m.attachedFiles))
	copy(fileDescs, m.attachedFiles)
	m.attachedFiles = nil
	m.input.clearFiles()

	m.isStreaming = true
	m.agentStarted = false
	m.citations = make(map[int]string)
	m.status.setStreaming(true)

	ctx, cancel := context.WithCancel(context.Background())
	m.streamCancel = cancel

	ch := m.client.SendMessageStream(
		ctx,
		message,
		m.chatSessionID,
		m.agentID,
		m.parentMessageID,
		fileDescs,
	)
	m.streamCh = ch

	return m, api.WaitForStreamEvent(ch)
}

func (m Model) handleStreamEvent(msg api.StreamEventMsg) (tea.Model, tea.Cmd) {
	// Ignore stale events after cancellation
	if !m.isStreaming {
		return m, nil
	}
	switch e := msg.Event.(type) {
	case models.SessionCreatedEvent:
		m.chatSessionID = &e.ChatSessionID
		m.needsRename = true
		m.status.setSession(e.ChatSessionID)

	case models.MessageIDEvent:
		m.parentMessageID = &e.ReservedAgentMessageID

	case models.MessageStartEvent:
		m.agentStarted = true

	case models.MessageDeltaEvent:
		m.agentStarted = true
		m.viewport.appendToken(e.Content)

	case models.SearchStartEvent:
		if e.IsInternetSearch {
			m.viewport.addInfo("Web search…")
		} else {
			m.viewport.addInfo("Searching…")
		}

	case models.SearchQueriesEvent:
		if len(e.Queries) > 0 {
			queries := e.Queries
			if len(queries) > 3 {
				queries = queries[:3]
			}
			parts := make([]string, len(queries))
			for i, q := range queries {
				parts[i] = "\"" + q + "\""
			}
			m.viewport.addInfo("Searching: " + strings.Join(parts, ", "))
		}

	case models.SearchDocumentsEvent:
		count := len(e.Documents)
		suffix := "s"
		if count == 1 {
			suffix = ""
		}
		m.viewport.addInfo("Found " + strconv.Itoa(count) + " document" + suffix)

	case models.ReasoningStartEvent:
		m.viewport.addInfo("Thinking…")

	case models.ReasoningDeltaEvent:
		// We don't display reasoning text, just the indicator

	case models.ReasoningDoneEvent:
		// No-op

	case models.CitationEvent:
		m.citations[e.CitationNumber] = e.DocumentID

	case models.ToolStartEvent:
		m.viewport.addInfo("Using " + e.ToolName + "…")

	case models.ResearchAgentStartEvent:
		m.viewport.addInfo("Researching: " + e.ResearchTask)

	case models.DeepResearchPlanDeltaEvent:
		m.viewport.appendToken(e.Content)

	case models.IntermediateReportDeltaEvent:
		m.viewport.appendToken(e.Content)

	case models.StopEvent:
		return m.finishStream(nil)

	case models.ErrorEvent:
		m.viewport.addError(e.Error)
		return m.finishStream(nil)
	}

	return m, api.WaitForStreamEvent(m.streamCh)
}

func (m Model) handleStreamDone(msg api.StreamDoneMsg) (tea.Model, tea.Cmd) {
	// Ignore if already cancelled
	if !m.isStreaming {
		return m, nil
	}
	return m.finishStream(msg.Err)
}

func (m Model) finishStream(err error) (Model, tea.Cmd) {
	m.viewport.finishAgent()
	if m.agentStarted && len(m.citations) > 0 {
		m.viewport.addCitations(m.citations)
	}
	m.isStreaming = false
	m.agentStarted = false
	m.status.setStreaming(false)
	if m.streamCancel != nil {
		m.streamCancel()
	}
	m.streamCancel = nil
	m.streamCh = nil

	// Auto-rename new sessions
	if m.needsRename && m.chatSessionID != nil {
		m.needsRename = false
		sessionID := *m.chatSessionID
		client := m.client
		go func() {
			_, _ = client.RenameChatSession(sessionID, nil)
		}()
	}

	return m, nil
}

func (m Model) handleInitDone(msg InitDoneMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addWarning("Could not load agents. Using default.")
	} else {
		m.agents = msg.Agents
		for _, p := range m.agents {
			if p.ID == m.agentID {
				m.agentName = p.Name
				break
			}
		}
	}
	m.status.setServer(m.config.ServerURL)
	m.status.setAgent(m.agentName)
	return m, nil
}

func (m Model) handleAgentsLoaded(msg AgentsLoadedMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addError("Could not load agents: " + msg.Err.Error())
		return m, nil
	}
	m.agents = msg.Agents
	if len(m.agents) == 0 {
		m.viewport.addInfo("No agents available.")
		return m, nil
	}

	m.viewport.addInfo("Select an agent (Enter to select, Esc to cancel):")

	var items []pickerItem
	for _, p := range m.agents {
		label := fmt.Sprintf("%d: %s", p.ID, p.Name)
		if p.ID == m.agentID {
			label += " *"
		}
		desc := p.Description
		if len(desc) > 50 {
			desc = desc[:50] + "..."
		}
		if desc != "" {
			label += " - " + desc
		}
		items = append(items, pickerItem{
			id:    strconv.Itoa(p.ID),
			label: label,
		})
	}
	m.viewport.showPicker(pickerAgent, items)
	return m, nil
}

func (m Model) handleSessionsLoaded(msg SessionsLoadedMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addError("Could not load sessions: " + msg.Err.Error())
		return m, nil
	}
	if len(msg.Sessions) == 0 {
		m.viewport.addInfo("No previous sessions found.")
		return m, nil
	}

	m.viewport.addInfo("Select a session to resume (Enter to select, Esc to cancel):")

	var items []pickerItem
	for i, s := range msg.Sessions {
		if i >= 15 {
			break
		}
		name := "Untitled"
		if s.Name != nil && *s.Name != "" {
			name = *s.Name
		}
		sid := s.ID
		if len(sid) > 8 {
			sid = sid[:8]
		}
		items = append(items, pickerItem{
			id:    s.ID,
			label: sid + " " + name + " (" + s.Created + ")",
		})
	}
	m.viewport.showPicker(pickerSession, items)
	return m, nil
}

func (m Model) handleSessionResumed(msg SessionResumedMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addError("Could not load session: " + msg.Err.Error())
		return m, nil
	}

	// Cancel any in-progress stream before replacing the session
	if m.isStreaming {
		m, _ = m.cancelStream()
	}

	detail := msg.Detail
	m.chatSessionID = &detail.ChatSessionID
	m.viewport.clearDisplay()
	m.status.setSession(detail.ChatSessionID)

	if detail.AgentName != nil {
		m.agentName = *detail.AgentName
		m.status.setAgent(*detail.AgentName)
	}
	if detail.AgentID != nil {
		m.agentID = *detail.AgentID
	}

	// Replay messages
	for _, chatMsg := range detail.Messages {
		switch chatMsg.MessageType {
		case "user":
			m.viewport.addUserMessage(chatMsg.Message)
		case "assistant":
			m.viewport.startAgent()
			m.viewport.appendToken(chatMsg.Message)
			m.viewport.finishAgent()
		}
	}

	// Set parent to last message
	if len(detail.Messages) > 0 {
		lastID := detail.Messages[len(detail.Messages)-1].MessageID
		m.parentMessageID = &lastID
	}

	desc := "Untitled"
	if detail.Description != nil && *detail.Description != "" {
		desc = *detail.Description
	}
	m.viewport.addInfo("Resumed session: " + desc)
	return m, nil
}

func (m Model) handleFileUploaded(msg FileUploadedMsg) (tea.Model, tea.Cmd) {
	if msg.Err != nil {
		m.viewport.addError("Upload failed: " + msg.Err.Error())
		return m, nil
	}
	m.attachedFiles = append(m.attachedFiles, *msg.Descriptor)
	m.input.addFile(msg.FileName)
	m.viewport.addInfo("Attached: " + msg.FileName)
	return m, nil
}

type inputReadyMsg struct{}
type resetQuitMsg struct{}
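The model above follows Bubble Tea's Elm-style loop: `Update` takes the model by value and returns a new model plus an optional command, a function that performs side effects and yields the next message. A minimal dependency-free sketch of that pattern (the `msg`/`cmd`/`update` names here are illustrative, not the bubbletea API):

```go
package main

import "fmt"

// msg is anything the runtime can feed back into update.
type msg interface{}

// cmd performs a side effect and returns the resulting message.
type cmd func() msg

type model struct{ count int }

type incMsg struct{}
type doneMsg struct{}

// update is pure: it never mutates shared state, only returns the next
// model value and (optionally) a command to run.
func update(m model, message msg) (model, cmd) {
	switch message.(type) {
	case incMsg:
		m.count++
		// Schedule a follow-up message, the way handleStreamEvent
		// re-arms itself with api.WaitForStreamEvent above.
		return m, func() msg { return doneMsg{} }
	case doneMsg:
		return m, nil
	}
	return m, nil
}

func main() {
	m := model{}
	var c cmd
	m, c = update(m, incMsg{})
	// Drain commands until the model settles.
	for c != nil {
		m, c = update(m, c())
	}
	fmt.Println(m.count) // prints 1
}
```

Because `update` receives the model by value, partial updates like `m.count++` are local until returned, which is why the real `Update` above can early-return `m, nil` at any point without corrupting state.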
@@ -1,205 +0,0 @@
package tui

import (
	"fmt"
	"strconv"
	"strings"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/models"
	"github.com/onyx-dot-app/onyx/cli/internal/util"
)

// handleSlashCommand dispatches slash commands and returns updated model + cmd.
func handleSlashCommand(m Model, text string) (Model, tea.Cmd) {
	parts := strings.SplitN(text, " ", 2)
	command := strings.ToLower(parts[0])
	arg := ""
	if len(parts) > 1 {
		arg = parts[1]
	}

	switch command {
	case "/help":
		m.viewport.addInfo(helpText)
		return m, nil

	case "/new":
		return cmdNew(m)

	case "/agent":
		if arg != "" {
			return cmdSelectAgent(m, arg)
		}
		return cmdShowAgents(m)

	case "/attach":
		return cmdAttach(m, arg)

	case "/sessions", "/resume":
		if strings.TrimSpace(arg) != "" {
			return cmdResume(m, arg)
		}
		return cmdSessions(m)

	case "/configure":
		m.viewport.addInfo("Run 'onyx-cli configure' to change connection settings.")
		return m, nil

	case "/clear":
		return cmdNew(m)

	case "/connectors":
		url := m.config.ServerURL + "/admin/indexing/status"
		if util.OpenBrowser(url) {
			m.viewport.addInfo("Opened " + url + " in browser")
		} else {
			m.viewport.addWarning("Failed to open browser. Visit: " + url)
		}
		return m, nil

	case "/settings":
		url := m.config.ServerURL + "/app/settings/general"
		if util.OpenBrowser(url) {
			m.viewport.addInfo("Opened " + url + " in browser")
		} else {
			m.viewport.addWarning("Failed to open browser. Visit: " + url)
		}
		return m, nil

	case "/quit":
		return m, tea.Quit

	default:
		m.viewport.addWarning(fmt.Sprintf("Unknown command: %s. Type /help for available commands.", command))
		return m, nil
	}
}

func cmdNew(m Model) (Model, tea.Cmd) {
	if m.isStreaming {
		m, _ = m.cancelStream()
	}
	m.chatSessionID = nil
	parentID := -1
	m.parentMessageID = &parentID
	m.needsRename = false
	m.citations = nil
	m.viewport.clearAll()
	// Re-add splash as a scrollable entry
	viewportHeight := m.viewportHeight()
	if viewportHeight < 1 {
		viewportHeight = m.height
	}
	m.viewport.addSplash(viewportHeight)
	m.status.setSession("")
	return m, nil
}

func cmdShowAgents(m Model) (Model, tea.Cmd) {
	m.viewport.addInfo("Loading agents...")
	client := m.client
	return m, func() tea.Msg {
		agents, err := client.ListAgents()
		return AgentsLoadedMsg{Agents: agents, Err: err}
	}
}

func cmdSelectAgent(m Model, idStr string) (Model, tea.Cmd) {
	pid, err := strconv.Atoi(strings.TrimSpace(idStr))
	if err != nil {
		m.viewport.addWarning("Invalid agent ID. Use a number.")
		return m, nil
	}

	var target *models.AgentSummary
	for i := range m.agents {
		if m.agents[i].ID == pid {
			target = &m.agents[i]
			break
		}
	}

	if target == nil {
		m.viewport.addWarning(fmt.Sprintf("Agent %d not found. Use /agent to see available agents.", pid))
		return m, nil
	}

	m.agentID = target.ID
	m.agentName = target.Name
	m.status.setAgent(target.Name)
	m.viewport.addInfo("Switched to agent: " + target.Name)

	// Save preference
	m.config.DefaultAgentID = target.ID
	_ = config.Save(m.config)

	return m, nil
}

func cmdAttach(m Model, pathStr string) (Model, tea.Cmd) {
	if pathStr == "" {
		m.viewport.addWarning("Usage: /attach <file_path>")
		return m, nil
	}

	m.viewport.addInfo("Uploading " + pathStr + "...")

	client := m.client
	return m, func() tea.Msg {
		fd, err := client.UploadFile(pathStr)
		if err != nil {
			return FileUploadedMsg{Err: err, FileName: pathStr}
		}
		return FileUploadedMsg{Descriptor: fd, FileName: pathStr}
	}
}

func cmdSessions(m Model) (Model, tea.Cmd) {
	m.viewport.addInfo("Loading sessions...")
	client := m.client
	return m, func() tea.Msg {
		sessions, err := client.ListChatSessions()
		return SessionsLoadedMsg{Sessions: sessions, Err: err}
	}
}

func cmdResume(m Model, sessionIDStr string) (Model, tea.Cmd) {
	client := m.client
	return m, func() tea.Msg {
		// Try to find session by prefix match
		sessions, err := client.ListChatSessions()
		if err != nil {
			return SessionResumedMsg{Err: err}
		}

		var targetID string
		for _, s := range sessions {
			if strings.HasPrefix(s.ID, sessionIDStr) {
|
||||
targetID = s.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetID == "" {
|
||||
// Try as full UUID
|
||||
targetID = sessionIDStr
|
||||
}
|
||||
|
||||
detail, err := client.GetChatSession(targetID)
|
||||
if err != nil {
|
||||
return SessionResumedMsg{Err: fmt.Errorf("session not found: %s", sessionIDStr)}
|
||||
}
|
||||
return SessionResumedMsg{Detail: detail}
|
||||
}
|
||||
}
|
||||
|
||||
// loadAgentsCmd returns a tea.Cmd that loads agents from the API.
|
||||
func loadAgentsCmd(client *api.Client) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
agents, err := client.ListAgents()
|
||||
return InitDoneMsg{Agents: agents, Err: err}
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
package tui

const helpText = `Onyx CLI Commands

  /help           Show this help message
  /new            Start a new chat session
  /agent          List and switch agents
  /attach <path>  Attach a file to next message
  /sessions       Browse and resume previous sessions
  /clear          Clear the chat display
  /configure      Re-run connection setup
  /connectors     Open connectors page in browser
  /settings       Open Onyx settings in browser
  /quit           Exit Onyx CLI

Keyboard Shortcuts

  Enter           Send message
  Escape          Cancel current generation
  Ctrl+O          Toggle source citations
  Ctrl+D          Quit (press twice)
  Scroll Up/Down  Mouse wheel or Shift+Up/Down
  Page Up/Down    Scroll half page
`
@@ -1,242 +0,0 @@
package tui

import (
	"os"
	"path/filepath"
	"strings"

	"github.com/charmbracelet/bubbles/textinput"
	tea "github.com/charmbracelet/bubbletea"
)

// slashCommand defines a slash command with its description.
type slashCommand struct {
	command     string
	description string
}

var slashCommands = []slashCommand{
	{"/help", "Show help message"},
	{"/new", "Start a new chat session"},
	{"/agent", "List and switch agents"},
	{"/attach", "Attach a file to next message"},
	{"/sessions", "Browse and resume previous sessions"},
	{"/clear", "Clear the chat display"},
	{"/configure", "Re-run connection setup"},
	{"/connectors", "Open connectors in browser"},
	{"/settings", "Open settings in browser"},
	{"/quit", "Exit Onyx CLI"},
}

// Commands that take arguments (filled in with trailing space on Tab/Enter).
var argCommands = map[string]bool{
	"/attach": true,
}

// inputModel manages the text input and slash command menu.
type inputModel struct {
	textInput     textinput.Model
	menuVisible   bool
	menuItems     []slashCommand
	menuIndex     int
	attachedFiles []string
}

func newInputModel() inputModel {
	ti := textinput.New()
	ti.Prompt = "" // We render our own prompt in viewInput()
	ti.Placeholder = "Send a message…"
	ti.CharLimit = 10000
	// Don't focus here — focus after first WindowSizeMsg to avoid
	// capturing terminal init escape sequences as input.

	return inputModel{
		textInput: ti,
	}
}

func (m inputModel) update(msg tea.Msg) (inputModel, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		return m.handleKey(msg)
	}

	var cmd tea.Cmd
	m.textInput, cmd = m.textInput.Update(msg)
	m = m.updateMenu()
	return m, cmd
}

func (m inputModel) handleKey(msg tea.KeyMsg) (inputModel, tea.Cmd) {
	switch msg.Type {
	case tea.KeyUp:
		if m.menuVisible && m.menuIndex > 0 {
			m.menuIndex--
			return m, nil
		}
	case tea.KeyDown:
		if m.menuVisible && m.menuIndex < len(m.menuItems)-1 {
			m.menuIndex++
			return m, nil
		}
	case tea.KeyTab:
		if m.menuVisible && len(m.menuItems) > 0 {
			cmd := m.menuItems[m.menuIndex].command
			if argCommands[cmd] {
				m.textInput.SetValue(cmd + " ")
				m.textInput.SetCursor(len(cmd) + 1)
			} else {
				m.textInput.SetValue(cmd)
				m.textInput.SetCursor(len(cmd))
			}
			m.menuVisible = false
			return m, nil
		}
	case tea.KeyEnter:
		if m.menuVisible && len(m.menuItems) > 0 {
			cmd := m.menuItems[m.menuIndex].command
			if argCommands[cmd] {
				m.textInput.SetValue(cmd + " ")
				m.textInput.SetCursor(len(cmd) + 1)
				m.menuVisible = false
				return m, nil
			}
			// Execute immediately
			m.textInput.SetValue("")
			m.menuVisible = false
			return m, func() tea.Msg { return submitMsg{text: cmd} }
		}

		text := strings.TrimSpace(m.textInput.Value())
		if text == "" {
			return m, nil
		}

		// Check for file path (drag-and-drop)
		if dropped := detectFileDrop(text); dropped != "" {
			m.textInput.SetValue("")
			return m, func() tea.Msg { return fileDropMsg{path: dropped} }
		}

		m.textInput.SetValue("")
		m.menuVisible = false
		return m, func() tea.Msg { return submitMsg{text: text} }

	case tea.KeyEscape:
		if m.menuVisible {
			m.menuVisible = false
			return m, nil
		}
	}

	var cmd tea.Cmd
	m.textInput, cmd = m.textInput.Update(msg)
	m = m.updateMenu()
	return m, cmd
}

func (m inputModel) updateMenu() inputModel {
	val := strings.TrimSpace(m.textInput.Value())
	if strings.HasPrefix(val, "/") && !strings.Contains(val, " ") {
		needle := strings.ToLower(val)
		var filtered []slashCommand
		for _, sc := range slashCommands {
			if strings.HasPrefix(sc.command, needle) {
				filtered = append(filtered, sc)
			}
		}
		if len(filtered) > 0 {
			m.menuVisible = true
			m.menuItems = filtered
			if m.menuIndex >= len(filtered) {
				m.menuIndex = 0
			}
		} else {
			m.menuVisible = false
		}
	} else {
		m.menuVisible = false
	}
	return m
}

func (m *inputModel) addFile(name string) {
	m.attachedFiles = append(m.attachedFiles, name)
}

func (m *inputModel) clearFiles() {
	m.attachedFiles = nil
}

// submitMsg is sent when user submits text.
type submitMsg struct {
	text string
}

// fileDropMsg is sent when a file path is detected.
type fileDropMsg struct {
	path string
}

// detectFileDrop checks if the text looks like a file path.
func detectFileDrop(text string) string {
	cleaned := strings.Trim(text, "'\"")
	if cleaned == "" {
		return ""
	}
	// Only treat as a file drop if it looks explicitly path-like
	if !strings.HasPrefix(cleaned, "/") && !strings.HasPrefix(cleaned, "~") &&
		!strings.HasPrefix(cleaned, "./") && !strings.HasPrefix(cleaned, "../") {
		return ""
	}
	// Expand ~ to home dir
	if strings.HasPrefix(cleaned, "~") {
		home, err := os.UserHomeDir()
		if err == nil {
			cleaned = filepath.Join(home, cleaned[1:])
		}
	}
	abs, err := filepath.Abs(cleaned)
	if err != nil {
		return ""
	}
	info, err := os.Stat(abs)
	if err != nil {
		return ""
	}
	if info.IsDir() {
		return ""
	}
	return abs
}

// viewMenu renders the slash command menu.
func (m inputModel) viewMenu(width int) string {
	if !m.menuVisible || len(m.menuItems) == 0 {
		return ""
	}

	var lines []string
	for i, item := range m.menuItems {
		prefix := "  "
		if i == m.menuIndex {
			prefix = "> "
		}
		line := prefix + item.command + "  " + statusMsgStyle.Render(item.description)
		lines = append(lines, line)
	}
	return strings.Join(lines, "\n")
}

// viewInput renders the input line with prompt and optional file badges.
func (m inputModel) viewInput() string {
	var parts []string

	if len(m.attachedFiles) > 0 {
		badges := strings.Join(m.attachedFiles, "] [")
		parts = append(parts, statusMsgStyle.Render("Attached: ["+badges+"]"))
	}

	parts = append(parts, inputPrompt+m.textInput.View())
	return strings.Join(parts, "\n")
}
@@ -1,36 +0,0 @@
package tui

import (
	"github.com/onyx-dot-app/onyx/cli/internal/models"
)

// InitDoneMsg signals that async initialization is complete.
type InitDoneMsg struct {
	Agents []models.AgentSummary
	Err    error
}

// SessionsLoadedMsg carries loaded chat sessions.
type SessionsLoadedMsg struct {
	Sessions []models.ChatSessionDetails
	Err      error
}

// SessionResumedMsg carries a loaded session detail.
type SessionResumedMsg struct {
	Detail *models.ChatSessionDetailResponse
	Err    error
}

// FileUploadedMsg carries an uploaded file descriptor.
type FileUploadedMsg struct {
	Descriptor *models.FileDescriptorPayload
	FileName   string
	Err        error
}

// AgentsLoadedMsg carries freshly fetched agents from the API.
type AgentsLoadedMsg struct {
	Agents []models.AgentSummary
	Err    error
}
@@ -1,79 +0,0 @@
package tui

import (
	"strings"

	"github.com/charmbracelet/lipgloss"
)

const onyxLogo = ` ██████╗ ███╗   ██╗██╗   ██╗██╗  ██╗
██╔═══██╗████╗  ██║╚██╗ ██╔╝╚██╗██╔╝
██║   ██║██╔██╗ ██║ ╚████╔╝  ╚███╔╝
██║   ██║██║╚██╗██║  ╚██╔╝   ██╔██╗
╚██████╔╝██║ ╚████║   ██║   ██╔╝ ██╗
 ╚═════╝ ╚═╝  ╚═══╝   ╚═╝   ╚═╝  ╚═╝`

const tagline = "Your terminal interface for Onyx"
const splashHint = "Type a message to begin · /help for commands"

// renderSplash renders the splash screen centered for the given dimensions.
func renderSplash(width, height int) string {
	// Render the logo as a single block (don't center individual lines)
	logo := splashStyle.Render(onyxLogo)

	// Center tagline and hint relative to the logo block width
	logoWidth := lipgloss.Width(logo)
	tag := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
		taglineStyle.Render(tagline),
	)
	hint := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
		hintStyle.Render(splashHint),
	)

	block := lipgloss.JoinVertical(lipgloss.Left, logo, "", tag, hint)

	return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, block)
}

// RenderSplashOnboarding renders splash for the terminal onboarding screen.
func RenderSplashOnboarding(width, height int) string {
	// Render the logo as a styled block, then center it as a unit
	styledLogo := splashStyle.Render(onyxLogo)
	logoWidth := lipgloss.Width(styledLogo)
	logoLines := strings.Split(styledLogo, "\n")

	logoHeight := len(logoLines)
	contentHeight := logoHeight + 2 // logo + blank + tagline
	topPad := (height - contentHeight) / 2
	if topPad < 1 {
		topPad = 1
	}

	// Center the entire logo block horizontally
	blockPad := (width - logoWidth) / 2
	if blockPad < 0 {
		blockPad = 0
	}

	var b strings.Builder
	for i := 0; i < topPad; i++ {
		b.WriteByte('\n')
	}

	for _, line := range logoLines {
		b.WriteString(strings.Repeat(" ", blockPad))
		b.WriteString(line)
		b.WriteByte('\n')
	}

	b.WriteByte('\n')
	tagPad := (width - len(tagline)) / 2
	if tagPad < 0 {
		tagPad = 0
	}
	b.WriteString(strings.Repeat(" ", tagPad))
	b.WriteString(taglineStyle.Render(tagline))
	b.WriteByte('\n')

	return b.String()
}
@@ -1,60 +0,0 @@
package tui

import (
	"strings"

	"github.com/charmbracelet/lipgloss"
)

// statusBar manages the footer status display.
type statusBar struct {
	agentName string
	serverURL string
	sessionID string
	streaming bool
	width     int
}

func newStatusBar() statusBar {
	return statusBar{
		agentName: "Default",
	}
}

func (s *statusBar) setAgent(name string) { s.agentName = name }
func (s *statusBar) setServer(url string) { s.serverURL = url }
func (s *statusBar) setSession(id string) {
	if len(id) > 8 {
		id = id[:8]
	}
	s.sessionID = id
}
func (s *statusBar) setStreaming(v bool) { s.streaming = v }
func (s *statusBar) setWidth(w int)      { s.width = w }

func (s statusBar) view() string {
	var leftParts []string
	if s.serverURL != "" {
		leftParts = append(leftParts, s.serverURL)
	}
	name := s.agentName
	if name == "" {
		name = "Default"
	}
	leftParts = append(leftParts, name)
	left := statusBarStyle.Render(strings.Join(leftParts, " · "))

	right := "Ctrl+D to quit"
	if s.streaming {
		right = "Esc to cancel"
	}
	rightRendered := statusBarStyle.Render(right)

	// Fill space between left and right
	gap := s.width - lipgloss.Width(left) - lipgloss.Width(rightRendered)
	if gap < 1 {
		gap = 1
	}

	return left + strings.Repeat(" ", gap) + rightRendered
}
@@ -1,29 +0,0 @@
package tui

import "github.com/charmbracelet/lipgloss"

var (
	// Colors
	accentColor    = lipgloss.Color("#6c8ebf")
	dimColor       = lipgloss.Color("#555577")
	errorColor     = lipgloss.Color("#ff5555")
	splashColor    = lipgloss.Color("#7C6AEF")
	separatorColor = lipgloss.Color("#333355")
	citationColor  = lipgloss.Color("#666688")

	// Styles
	userPrefixStyle = lipgloss.NewStyle().Foreground(dimColor)
	agentDot        = lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("◉")
	infoStyle       = lipgloss.NewStyle().Foreground(lipgloss.Color("#b0b0cc"))
	dimInfoStyle    = lipgloss.NewStyle().Foreground(dimColor)
	statusMsgStyle  = dimInfoStyle // used for slash menu descriptions, file badges
	errorStyle      = lipgloss.NewStyle().Foreground(errorColor).Bold(true)
	warnStyle       = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
	citationStyle   = lipgloss.NewStyle().Foreground(citationColor)
	statusBarStyle  = lipgloss.NewStyle().Foreground(dimColor)
	inputPrompt     = lipgloss.NewStyle().Foreground(accentColor).Render("❯ ")

	splashStyle  = lipgloss.NewStyle().Foreground(splashColor).Bold(true)
	taglineStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#A0A0A0"))
	hintStyle    = lipgloss.NewStyle().Foreground(dimColor)
)
@@ -1,447 +0,0 @@
package tui

import (
	"fmt"
	"sort"
	"strings"

	"github.com/charmbracelet/glamour"
	"github.com/charmbracelet/glamour/styles"

	"github.com/charmbracelet/lipgloss"
)

// entryKind is the type of chat entry.
type entryKind int

const (
	entryUser entryKind = iota
	entryAgent
	entryInfo
	entryError
	entryCitation
)

// chatEntry is a single rendered entry in the chat history.
type chatEntry struct {
	kind      entryKind
	content   string   // raw content (for agent: the markdown source)
	rendered  string   // pre-rendered output
	citations []string // citation lines (for citation entries)
}

// pickerKind distinguishes what the picker is selecting.
type pickerKind int

const (
	pickerSession pickerKind = iota
	pickerAgent
)

// pickerItem is a selectable item in the picker.
type pickerItem struct {
	id    string
	label string
}

// viewport manages the chat display.
type viewport struct {
	entries      []chatEntry
	width        int
	streaming    bool
	streamBuf    string
	showSources  bool
	renderer     *glamour.TermRenderer
	pickerItems  []pickerItem
	pickerActive bool
	pickerIndex  int
	pickerType   pickerKind
	scrollOffset int // lines scrolled up from bottom (0 = pinned to bottom)
	lastHeight   int // viewport height from last render
}

// newMarkdownRenderer creates a Glamour renderer with zero left margin.
func newMarkdownRenderer(width int) *glamour.TermRenderer {
	style := styles.DarkStyleConfig
	zero := uint(0)
	style.Document.Margin = &zero
	r, _ := glamour.NewTermRenderer(
		glamour.WithStyles(style),
		glamour.WithWordWrap(width-4),
	)
	return r
}

func newViewport(width int) *viewport {
	return &viewport{
		width:    width,
		renderer: newMarkdownRenderer(width),
	}
}

func (v *viewport) addSplash(height int) {
	splash := renderSplash(v.width, height)
	v.entries = append(v.entries, chatEntry{
		kind:     entryInfo,
		rendered: splash,
	})
}

func (v *viewport) setWidth(w int) {
	v.width = w
	v.renderer = newMarkdownRenderer(w)
}

func (v *viewport) addUserMessage(msg string) {
	rendered := "\n" + userPrefixStyle.Render("❯ ") + msg
	v.entries = append(v.entries, chatEntry{
		kind:     entryUser,
		content:  msg,
		rendered: rendered,
	})
}

func (v *viewport) startAgent() {
	v.streaming = true
	v.streamBuf = ""
	// Add a blank-line spacer entry before the agent message
	v.entries = append(v.entries, chatEntry{kind: entryInfo, rendered: ""})
}

func (v *viewport) appendToken(token string) {
	v.streamBuf += token
}

func (v *viewport) finishAgent() {
	if v.streamBuf == "" {
		v.streaming = false
		// Remove the blank spacer entry added by startAgent()
		if len(v.entries) > 0 && v.entries[len(v.entries)-1].kind == entryInfo && v.entries[len(v.entries)-1].rendered == "" {
			v.entries = v.entries[:len(v.entries)-1]
		}
		return
	}

	// Render markdown with Glamour (zero left margin style)
	rendered := v.renderMarkdown(v.streamBuf)
	rendered = strings.TrimLeft(rendered, "\n")
	rendered = strings.TrimRight(rendered, "\n")
	lines := strings.Split(rendered, "\n")
	// Prefix first line with dot, indent continuation lines
	if len(lines) > 0 {
		lines[0] = agentDot + " " + lines[0]
		for i := 1; i < len(lines); i++ {
			lines[i] = "  " + lines[i]
		}
	}
	rendered = strings.Join(lines, "\n")

	v.entries = append(v.entries, chatEntry{
		kind:     entryAgent,
		content:  v.streamBuf,
		rendered: rendered,
	})
	v.streaming = false
	v.streamBuf = ""
}

func (v *viewport) renderMarkdown(md string) string {
	if v.renderer == nil {
		return md
	}
	out, err := v.renderer.Render(md)
	if err != nil {
		return md
	}
	return out
}

func (v *viewport) addInfo(msg string) {
	rendered := infoStyle.Render("● " + msg)
	v.entries = append(v.entries, chatEntry{
		kind:     entryInfo,
		content:  msg,
		rendered: rendered,
	})
}

func (v *viewport) addWarning(msg string) {
	rendered := warnStyle.Render("● " + msg)
	v.entries = append(v.entries, chatEntry{
		kind:     entryError,
		content:  msg,
		rendered: rendered,
	})
}

func (v *viewport) addError(msg string) {
	rendered := errorStyle.Render("● Error: ") + msg
	v.entries = append(v.entries, chatEntry{
		kind:     entryError,
		content:  msg,
		rendered: rendered,
	})
}

func (v *viewport) addCitations(citations map[int]string) {
	if len(citations) == 0 {
		return
	}
	keys := make([]int, 0, len(citations))
	for k := range citations {
		keys = append(keys, k)
	}
	sort.Ints(keys)
	var parts []string
	for _, num := range keys {
		parts = append(parts, fmt.Sprintf("[%d] %s", num, citations[num]))
	}
	text := fmt.Sprintf("Sources (%d): %s", len(citations), strings.Join(parts, " "))
	var citLines []string
	citLines = append(citLines, text)

	v.entries = append(v.entries, chatEntry{
		kind:      entryCitation,
		content:   text,
		rendered:  citationStyle.Render("● " + text),
		citations: citLines,
	})
}

func (v *viewport) showPicker(kind pickerKind, items []pickerItem) {
	v.pickerItems = items
	v.pickerType = kind
	v.pickerActive = true
	v.pickerIndex = 0
}

func (v *viewport) maxScroll() int {
	ms := v.totalLines() - v.lastHeight
	if ms < 0 {
		return 0
	}
	return ms
}

func (v *viewport) scrollUp(n int) {
	v.scrollOffset += n
	if ms := v.maxScroll(); v.scrollOffset > ms {
		v.scrollOffset = ms
	}
}

func (v *viewport) scrollDown(n int) {
	v.scrollOffset -= n
	if v.scrollOffset < 0 {
		v.scrollOffset = 0
	}
}

func (v *viewport) clearAll() {
	v.entries = nil
	v.streaming = false
	v.streamBuf = ""
	v.pickerItems = nil
	v.pickerActive = false
	v.scrollOffset = 0
}

func (v *viewport) clearDisplay() {
	v.entries = nil
	v.scrollOffset = 0
	v.streaming = false
	v.streamBuf = ""
}

// pickerTitle returns a title for the current picker kind.
func (v *viewport) pickerTitle() string {
	switch v.pickerType {
	case pickerAgent:
		return "Select Agent"
	case pickerSession:
		return "Resume Session"
	default:
		return "Select"
	}
}

// renderPicker renders the picker as a bordered overlay.
func (v *viewport) renderPicker(width, height int) string {
	title := v.pickerTitle()

	// Determine picker dimensions
	maxItems := len(v.pickerItems)
	panelWidth := width - 4
	if panelWidth < 30 {
		panelWidth = 30
	}
	if panelWidth > 70 {
		panelWidth = 70
	}
	innerWidth := panelWidth - 4 // border + padding

	// Visible window of items (scroll if too many)
	maxVisible := height - 6 // room for border, title, hint
	if maxVisible < 3 {
		maxVisible = 3
	}
	if maxVisible > maxItems {
		maxVisible = maxItems
	}

	// Calculate scroll window around current index
	startIdx := 0
	if v.pickerIndex >= maxVisible {
		startIdx = v.pickerIndex - maxVisible + 1
	}
	endIdx := startIdx + maxVisible
	if endIdx > maxItems {
		endIdx = maxItems
		startIdx = endIdx - maxVisible
		if startIdx < 0 {
			startIdx = 0
		}
	}

	var itemLines []string
	for i := startIdx; i < endIdx; i++ {
		item := v.pickerItems[i]
		label := item.label
		labelRunes := []rune(label)
		if len(labelRunes) > innerWidth-4 {
			label = string(labelRunes[:innerWidth-7]) + "..."
		}
		if i == v.pickerIndex {
			line := lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("> " + label)
			itemLines = append(itemLines, line)
		} else {
			itemLines = append(itemLines, "  "+label)
		}
	}

	hint := lipgloss.NewStyle().Foreground(dimColor).Render("↑↓ navigate • enter select • esc cancel")

	body := strings.Join(itemLines, "\n") + "\n\n" + hint

	panel := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(accentColor).
		Padding(1, 2).
		Width(panelWidth).
		Render(body)

	titleRendered := lipgloss.NewStyle().
		Foreground(accentColor).
		Bold(true).
		Render(" " + title + " ")

	// Build top border manually to avoid ANSI-corrupted rune slicing.
	// panelWidth+2 accounts for the left and right border characters.
	borderColor := lipgloss.NewStyle().Foreground(accentColor)
	titleWidth := lipgloss.Width(titleRendered)
	rightDashes := panelWidth + 2 - 3 - titleWidth // total - "╭─" - "╮" - title
	if rightDashes < 0 {
		rightDashes = 0
	}
	topBorder := borderColor.Render("╭─") + titleRendered +
		borderColor.Render(strings.Repeat("─", rightDashes)+"╮")

	panelLines := strings.Split(panel, "\n")
	if len(panelLines) > 0 {
		panelLines[0] = topBorder
	}
	panel = strings.Join(panelLines, "\n")

	// Center the panel in the viewport
	return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, panel)
}

// totalLines computes the total number of rendered content lines.
func (v *viewport) totalLines() int {
	var lines []string
	for _, e := range v.entries {
		if e.kind == entryCitation && !v.showSources {
			continue
		}
		lines = append(lines, e.rendered)
	}
	if v.streaming && v.streamBuf != "" {
		bufLines := strings.Split(v.streamBuf, "\n")
		if len(bufLines) > 0 {
			bufLines[0] = agentDot + " " + bufLines[0]
			for i := 1; i < len(bufLines); i++ {
				bufLines[i] = "  " + bufLines[i]
			}
		}
		lines = append(lines, strings.Join(bufLines, "\n"))
	} else if v.streaming {
		lines = append(lines, agentDot+" ")
	}
	content := strings.Join(lines, "\n")
	return len(strings.Split(content, "\n"))
}

// view renders the full viewport content.
func (v *viewport) view(height int) string {
	// If picker is active, render it as an overlay
	if v.pickerActive && len(v.pickerItems) > 0 {
		return v.renderPicker(v.width, height)
	}

	var lines []string

	for _, e := range v.entries {
		if e.kind == entryCitation && !v.showSources {
			continue
		}
		lines = append(lines, e.rendered)
	}

	// Streaming buffer (plain text, not markdown)
	if v.streaming && v.streamBuf != "" {
		bufLines := strings.Split(v.streamBuf, "\n")
		if len(bufLines) > 0 {
			bufLines[0] = agentDot + " " + bufLines[0]
			for i := 1; i < len(bufLines); i++ {
				bufLines[i] = "  " + bufLines[i]
			}
		}
		lines = append(lines, strings.Join(bufLines, "\n"))
	} else if v.streaming {
		lines = append(lines, agentDot+" ")
	}

	content := strings.Join(lines, "\n")
	contentLines := strings.Split(content, "\n")
	total := len(contentLines)

	v.lastHeight = height
	maxScroll := total - height
	if maxScroll < 0 {
		maxScroll = 0
	}
	scrollOffset := v.scrollOffset
	if scrollOffset > maxScroll {
		scrollOffset = maxScroll
	}

	if total <= height {
		// Content fits — pad with empty lines at top to push content down
		padding := make([]string, height-total)
		for i := range padding {
			padding[i] = ""
		}
		contentLines = append(padding, contentLines...)
	} else {
		// Show a window: end is (total - scrollOffset), start is (end - height)
		end := total - scrollOffset
		start := end - height
		if start < 0 {
			start = 0
		}
		contentLines = contentLines[start:end]
	}

	return strings.Join(contentLines, "\n")
}
@@ -1,264 +0,0 @@
package tui

import (
	"regexp"
	"strings"
	"testing"
)

// stripANSI removes ANSI escape sequences for test comparisons.
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)

func stripANSI(s string) string {
	return ansiRegex.ReplaceAllString(s, "")
}

func TestAddUserMessage(t *testing.T) {
	v := newViewport(80)
	v.addUserMessage("hello world")

	if len(v.entries) != 1 {
		t.Fatalf("expected 1 entry, got %d", len(v.entries))
	}
	e := v.entries[0]
	if e.kind != entryUser {
		t.Errorf("expected entryUser, got %d", e.kind)
	}
	if e.content != "hello world" {
		t.Errorf("expected content 'hello world', got %q", e.content)
	}
	plain := stripANSI(e.rendered)
	if !strings.Contains(plain, "❯") {
		t.Errorf("expected rendered to contain ❯, got %q", plain)
	}
	if !strings.Contains(plain, "hello world") {
		t.Errorf("expected rendered to contain message text, got %q", plain)
	}
}

func TestStartAndFinishAgent(t *testing.T) {
	v := newViewport(80)
	v.startAgent()

	if !v.streaming {
		t.Error("expected streaming to be true after startAgent")
	}
	if len(v.entries) != 1 {
		t.Fatalf("expected 1 spacer entry, got %d", len(v.entries))
	}
	if v.entries[0].rendered != "" {
		t.Errorf("expected empty spacer, got %q", v.entries[0].rendered)
	}

	v.appendToken("Hello ")
	v.appendToken("world")

	if v.streamBuf != "Hello world" {
		t.Errorf("expected streamBuf 'Hello world', got %q", v.streamBuf)
	}

	v.finishAgent()

	if v.streaming {
		t.Error("expected streaming to be false after finishAgent")
	}
	if v.streamBuf != "" {
		t.Errorf("expected empty streamBuf after finish, got %q", v.streamBuf)
	}
	if len(v.entries) != 2 {
		t.Fatalf("expected 2 entries (spacer + agent), got %d", len(v.entries))
	}

	e := v.entries[1]
	if e.kind != entryAgent {
		t.Errorf("expected entryAgent, got %d", e.kind)
	}
	if e.content != "Hello world" {
		t.Errorf("expected content 'Hello world', got %q", e.content)
	}
	plain := stripANSI(e.rendered)
	if !strings.Contains(plain, "Hello world") {
		t.Errorf("expected rendered to contain message text, got %q", plain)
	}
}

func TestFinishAgentNoPadding(t *testing.T) {
	v := newViewport(80)
	v.startAgent()
	v.appendToken("Test message")
	v.finishAgent()

	e := v.entries[1]
	// First line should not start with plain spaces (ANSI codes are OK)
	plain := stripANSI(e.rendered)
	lines := strings.Split(plain, "\n")
	if strings.HasPrefix(lines[0], " ") {
		t.Errorf("first line should not start with spaces, got %q", lines[0])
	}
}

func TestFinishAgentMultiline(t *testing.T) {
	v := newViewport(80)
	v.startAgent()
	v.appendToken("Line one\n\nLine three")
	v.finishAgent()

	e := v.entries[1]
	plain := stripANSI(e.rendered)
	// Glamour may merge or reformat lines; just check content is present
	if !strings.Contains(plain, "Line one") {
		t.Errorf("expected 'Line one' in rendered, got %q", plain)
	}
	if !strings.Contains(plain, "Line three") {
		t.Errorf("expected 'Line three' in rendered, got %q", plain)
	}
}

func TestFinishAgentEmpty(t *testing.T) {
	v := newViewport(80)
	v.startAgent()
	v.finishAgent()

	if v.streaming {
		t.Error("expected streaming to be false")
	}
	if len(v.entries) != 0 {
		t.Errorf("expected 0 entries (spacer removed), got %d", len(v.entries))
	}
}

func TestAddInfo(t *testing.T) {
	v := newViewport(80)
	v.addInfo("test info")

	if len(v.entries) != 1 {
		t.Fatalf("expected 1 entry, got %d", len(v.entries))
	}
	e := v.entries[0]
	if e.kind != entryInfo {
		t.Errorf("expected entryInfo, got %d", e.kind)
	}
	plain := stripANSI(e.rendered)
	if strings.HasPrefix(plain, " ") {
		t.Errorf("info should not have leading spaces, got %q", plain)
	}
}

func TestAddError(t *testing.T) {
	v := newViewport(80)
	v.addError("something broke")

	if len(v.entries) != 1 {
		t.Fatalf("expected 1 entry, got %d", len(v.entries))
	}
	e := v.entries[0]
	if e.kind != entryError {
		t.Errorf("expected entryError, got %d", e.kind)
	}
	plain := stripANSI(e.rendered)
	if !strings.Contains(plain, "something broke") {
		t.Errorf("expected error message in rendered, got %q", plain)
	}
}

func TestAddCitations(t *testing.T) {
	v := newViewport(80)
	v.addCitations(map[int]string{1: "doc-a", 2: "doc-b"})

	if len(v.entries) != 1 {
		t.Fatalf("expected 1 entry, got %d", len(v.entries))
	}
	e := v.entries[0]
	if e.kind != entryCitation {
		t.Errorf("expected entryCitation, got %d", e.kind)
	}
	plain := stripANSI(e.rendered)
	if !strings.Contains(plain, "Sources (2)") {
		t.Errorf("expected sources count in rendered, got %q", plain)
	}
	if strings.HasPrefix(plain, " ") {
		t.Errorf("citation should not have leading spaces, got %q", plain)
	}
}

func TestAddCitationsEmpty(t *testing.T) {
	v := newViewport(80)
	v.addCitations(map[int]string{})

	if len(v.entries) != 0 {
		t.Errorf("expected no entries for empty citations, got %d", len(v.entries))
	}
}

func TestCitationVisibility(t *testing.T) {
	v := newViewport(80)
	v.addInfo("hello")
	v.addCitations(map[int]string{1: "doc"})

	v.showSources = false
	view := v.view(20)
	plain := stripANSI(view)
	if strings.Contains(plain, "Sources") {
		t.Error("expected citations hidden when showSources=false")
	}

	v.showSources = true
	view = v.view(20)
	plain = stripANSI(view)
	if !strings.Contains(plain, "Sources") {
		t.Error("expected citations visible when showSources=true")
	}
}

func TestClearAll(t *testing.T) {
	v := newViewport(80)
	v.addUserMessage("test")
	v.startAgent()
	v.appendToken("response")

	v.clearAll()

	if len(v.entries) != 0 {
		t.Errorf("expected no entries after clearAll, got %d", len(v.entries))
	}
	if v.streaming {
		t.Error("expected streaming=false after clearAll")
	}
	if v.streamBuf != "" {
		t.Errorf("expected empty streamBuf after clearAll, got %q", v.streamBuf)
	}
}

func TestClearDisplay(t *testing.T) {
	v := newViewport(80)
	v.addUserMessage("test")
	v.clearDisplay()

	if len(v.entries) != 0 {
		t.Errorf("expected no entries after clearDisplay, got %d", len(v.entries))
	}
}

func TestViewPadsShortContent(t *testing.T) {
	v := newViewport(80)
	v.addInfo("hello")

	view := v.view(10)
	lines := strings.Split(view, "\n")
	if len(lines) != 10 {
		t.Errorf("expected 10 lines (padded), got %d", len(lines))
	}
}

func TestViewTruncatesTallContent(t *testing.T) {
	v := newViewport(80)
	for i := 0; i < 20; i++ {
		v.addInfo("line")
	}

	view := v.view(5)
	lines := strings.Split(view, "\n")
	if len(lines) != 5 {
		t.Errorf("expected 5 lines (truncated), got %d", len(lines))
	}
}
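The `stripANSI` helper in the test file above matches only SGR (styling) sequences, not every ANSI escape. A standalone sketch showing the same pattern in isolation:

```go
package main

import (
	"fmt"
	"regexp"
)

// Same pattern as the test helper: matches SGR sequences such as
// \x1b[1;32m (bold green) and \x1b[0m (reset), but not cursor movement
// or other non-"m"-terminated escapes.
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)

func stripANSI(s string) string {
	return ansiRegex.ReplaceAllString(s, "")
}

func main() {
	colored := "\x1b[1;32mhello\x1b[0m world"
	fmt.Println(stripANSI(colored)) // prints "hello world"
}
```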
@@ -1,29 +0,0 @@
// Package util provides shared utility functions.
package util

import (
	"os/exec"
	"runtime"
)

// OpenBrowser opens the given URL in the user's default browser.
// Returns true if the browser was launched successfully.
func OpenBrowser(url string) bool {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		cmd = exec.Command("open", url)
	case "linux":
		cmd = exec.Command("xdg-open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	}
	if cmd != nil {
		if err := cmd.Start(); err == nil {
			// Reap the child process to avoid zombies.
			go func() { _ = cmd.Wait() }()
			return true
		}
	}
	return false
}
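The per-OS dispatch in `OpenBrowser` can be factored so the command selection is testable without actually launching anything. A sketch with a hypothetical `browserCommand` helper using the same mapping as above:

```go
package main

import "fmt"

// browserCommand returns the launcher argv for a given GOOS value,
// mirroring the switch in OpenBrowser. Unknown platforms return nil.
func browserCommand(goos, url string) []string {
	switch goos {
	case "darwin":
		return []string{"open", url}
	case "linux":
		return []string{"xdg-open", url}
	case "windows":
		return []string{"rundll32", "url.dll,FileProtocolHandler", url}
	}
	return nil
}

func main() {
	fmt.Println(browserCommand("linux", "https://cloud.onyx.app"))
	fmt.Println(browserCommand("plan9", "https://cloud.onyx.app")) // unsupported platform
}
```

Splitting the pure mapping from the `exec.Command`/`Start` side effect lets tests cover all three platforms regardless of the host OS.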
@@ -1,13 +0,0 @@
// Package util provides shared utilities for the Onyx CLI.
package util

import "github.com/charmbracelet/lipgloss"

// Shared text styles used across the CLI.
var (
	BoldStyle   = lipgloss.NewStyle().Bold(true)
	DimStyle    = lipgloss.NewStyle().Foreground(lipgloss.Color("#555577"))
	GreenStyle  = lipgloss.NewStyle().Foreground(lipgloss.Color("#00cc66")).Bold(true)
	RedStyle    = lipgloss.NewStyle().Foreground(lipgloss.Color("#ff5555")).Bold(true)
	YellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
)
cli/main.go
@@ -1,23 +0,0 @@
package main

import (
	"fmt"
	"os"

	"github.com/onyx-dot-app/onyx/cli/cmd"
)

var (
	version = "dev"
	commit  = "none"
)

func main() {
	cmd.Version = version
	cmd.Commit = commit

	if err := cmd.Execute(); err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}
}
@@ -138,6 +138,7 @@ services:
      - indexing_model_server
    restart: unless-stopped
    environment:
      - USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
      - MULTI_TENANT=true
      - LOG_LEVEL=DEBUG

@@ -52,6 +52,7 @@ services:
      - indexing_model_server
    restart: unless-stopped
    environment:
      - USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
      - AUTH_TYPE=${AUTH_TYPE:-oidc}
      - POSTGRES_HOST=relational_db
      - VESPA_HOST=index

@@ -65,6 +65,7 @@ services:
      - indexing_model_server
    restart: unless-stopped
    environment:
      - USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
      - AUTH_TYPE=${AUTH_TYPE:-oidc}
      - POSTGRES_HOST=relational_db
      - VESPA_HOST=index

@@ -70,6 +70,7 @@ services:
      - indexing_model_server
    restart: unless-stopped
    environment:
      - USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
      - AUTH_TYPE=${AUTH_TYPE:-oidc}
      - POSTGRES_HOST=relational_db
      - VESPA_HOST=index