Compare commits

..

4 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Dane Urban | fe7c02a3a9 | Some changes | 2026-03-05 15:53:40 -08:00 |
| Dane Urban | ac9f5a5f1d | . | 2026-03-04 15:50:14 -08:00 |
| Dane Urban | 5f6b348864 | Package changes | 2026-03-04 15:49:00 -08:00 |
| Dane Urban | 47bb69c3ea | Doesn't work for docx | 2026-03-04 15:45:16 -08:00 |
427 changed files with 3251 additions and 14353 deletions

View File

@@ -1,161 +0,0 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---
# Onyx CLI — Agent Tool
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. `onyx-cli` provides non-interactive commands to query the Onyx knowledge base and list available agents.
## Prerequisites
### 1. Check if installed
```bash
which onyx-cli
```
### 2. Install (if needed)
**Primary — pip:**
```bash
pip install onyx-cli
```
**From source (Go):**
```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```
### 3. Check if configured
```bash
onyx-cli validate-config
```
This checks that the config file exists and an API key is present, then tests the server connection via `/api/me`. Exit code 0 on success; non-zero with a descriptive error on failure.
If unconfigured, you have two options:
**Option A — Interactive setup (requires user input):**
```bash
onyx-cli configure
```
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
**Option B — Environment variables (non-interactive, preferred for agents):**
```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```
Environment variables override the config file. If these are set, no config file is needed.
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
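The precedence rules above (environment variables override the config file; `ONYX_SERVER_URL` falls back to the cloud default) can be sketched as follows. `resolve_config` is an illustrative helper, not part of the CLI:

```python
import os

DEFAULT_SERVER_URL = "https://cloud.onyx.app"

def resolve_config() -> dict[str, str]:
    """Resolve onyx-cli settings from environment variables.

    Mirrors the documented behavior: ONYX_API_KEY is required,
    ONYX_SERVER_URL defaults to the Onyx cloud URL.
    """
    api_key = os.environ.get("ONYX_API_KEY")
    if not api_key:
        raise RuntimeError(
            "onyx-cli is not configured: run `onyx-cli configure` "
            "or set ONYX_SERVER_URL and ONYX_API_KEY"
        )
    return {
        "server_url": os.environ.get("ONYX_SERVER_URL", DEFAULT_SERVER_URL),
        "api_key": api_key,
    }
```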
## Commands
### Validate configuration
```bash
onyx-cli validate-config
```
Checks that the config file exists and an API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
### List available agents
```bash
onyx-cli agents
```
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
```bash
onyx-cli agents --json
```
Use agent IDs with `ask --agent-id` to query a specific agent.
### Basic query (plain text output)
```bash
onyx-cli ask "What is our company's PTO policy?"
```
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
### JSON output (structured events)
```bash
onyx-cli ask --json "What authentication methods do we support?"
```
Outputs parsed stream events as newline-delimited JSON (one object per line). Key events include message deltas, stop, errors, search start, and citations.
| Event Type | Description |
|------------|-------------|
| `MessageDeltaEvent` | Content token — concatenate all `content` fields for the full answer |
| `StopEvent` | Stream complete |
| `ErrorEvent` | Error with `error` message field |
| `SearchStartEvent` | Onyx started searching documents |
| `CitationEvent` | Source citation with `citation_number` and `document_id` |
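A sketch of consuming this NDJSON stream, under the assumption that each event object carries a `type` discriminator plus the fields listed above — the exact JSON schema is not shown here, so the field names are illustrative:

```python
import json

def collect_answer(ndjson_lines: list[str]) -> str:
    """Concatenate MessageDeltaEvent content fields into the full answer.

    Assumes each line is a JSON object with a "type" discriminator and,
    for deltas, a "content" field (field names assumed, not confirmed
    against the CLI's actual output schema).
    """
    parts: list[str] = []
    for line in ndjson_lines:
        if not line.strip():
            continue
        event = json.loads(line)
        if event.get("type") == "MessageDeltaEvent":
            parts.append(event.get("content", ""))
        elif event.get("type") == "ErrorEvent":
            raise RuntimeError(event.get("error", "unknown error"))
        elif event.get("type") == "StopEvent":
            break
    return "".join(parts)
```

In practice the lines would come from `onyx-cli ask --json ...` stdout, e.g. via `subprocess.Popen` line iteration.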
### Specify an agent
```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```
Uses a specific Onyx agent/persona instead of the default.
### All flags
| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## When to Use
Use `onyx-cli ask` when:
- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
Do NOT use when:
- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data
## Examples
```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"
# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"
# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

View File

@@ -335,6 +335,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi

View File

@@ -268,11 +268,10 @@ jobs:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
@@ -280,7 +279,6 @@ jobs:
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
@@ -592,108 +590,6 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-tests-lite:
needs: [build-web-image, build-backend-image]
name: Playwright Tests (lite)
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-playwright-tests-lite"
- "extras=ecr-cache"
timeout-minutes: 30
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-npm-
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
MOCK_LLM_RESPONSE=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
# Needed for pulling external images; otherwise we hit the "Unauthenticated users" rate limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers (lite)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
id: start_docker
- name: Run Playwright tests (lite)
working-directory: ./web
run: npx playwright test --project lite
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-test-results-lite-${{ github.run_id }}
path: ./web/output/playwright/
retention-days: 30
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-lite-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
@@ -790,7 +686,7 @@ jobs:
# NOTE: GitHub-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests, playwright-tests-lite]
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status

View File

@@ -119,11 +119,10 @@ repos:
]
- repo: https://github.com/golangci/golangci-lint
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
language_version: "1.26.0"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.

.vscode/launch.json vendored
View File

@@ -40,7 +40,19 @@
}
},
{
"name": "Celery",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
@@ -241,6 +253,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -485,6 +526,21 @@
"group": "3"
}
},
{
"name": "Clear and Restart OpenSearch Container",
// Generic debugger type, required arg but has no bearing on bash.
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",

View File

@@ -86,6 +86,37 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
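A simplified sketch of the mode switch, using the worker names listed above. This is illustrative only — the real spawning logic lives in the supervisord configs, Helm templates, and `dev_run_background_jobs.py`:

```python
import os

# Worker names taken from the standard-mode list documented above.
STANDARD_WORKERS = [
    "light", "docprocessing", "docfetching", "heavy",
    "kg_processing", "monitoring", "user_file_processing",
]

def workers_to_spawn() -> list[str]:
    """Return the worker set implied by USE_LIGHTWEIGHT_BACKGROUND_WORKER.

    Lightweight mode (the default) consolidates everything into a single
    "background" worker; standard mode runs the specialized workers.
    """
    lightweight = (
        os.environ.get("USE_LIGHTWEIGHT_BACKGROUND_WORKER", "true").lower()
        == "true"
    )
    return ["background"] if lightweight else list(STANDARD_WORKERS)
```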
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability

View File

@@ -0,0 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -472,9 +471,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name,
time_last_modified_by_user=func.now(),
is_up_to_date=DISABLE_VECTOR_DB,
name=user_group.name, time_last_modified_by_user=func.now()
)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -777,7 +774,8 @@ def update_user_group(
cc_pair_ids=user_group_update.cc_pair_ids,
)
if cc_pairs_updated and not DISABLE_VECTOR_DB:
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -152,9 +153,12 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -223,15 +223,6 @@ def get_active_scim_token(
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
# Derive the IdP domain from the first synced user as a heuristic.
idp_domain: str | None = None
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
if mappings:
user = dal.get_user(mappings[0].user_id)
if user and "@" in user.email:
idp_domain = user.email.rsplit("@", 1)[1]
return ScimTokenResponse(
id=token.id,
name=token.name,
@@ -239,7 +230,6 @@ def get_active_scim_token(
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
idp_domain=idp_domain,
)

View File

@@ -365,7 +365,6 @@ class ScimTokenResponse(BaseModel):
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
idp_domain: str | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):

View File

@@ -5,8 +5,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -22,7 +20,6 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -156,8 +153,3 @@ def delete_user_group(
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
if DISABLE_VECTOR_DB:
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)

View File

@@ -0,0 +1,142 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received for consolidated background worker.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
# Initialize Vespa httpx pool (needed for light worker tasks)
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Fewer startup checks in the multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -39,13 +39,9 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector.
"""Result of extracting document IDs and hierarchy nodes from a connector."""
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
"""
raw_id_to_parent: dict[str, str | None]
doc_ids: set[str]
hierarchy_nodes: list[HierarchyNode]
@@ -97,37 +93,30 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
return None
class BatchResult(BaseModel):
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> BatchResult:
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID dict so that failed-to-retrieve documents are not accidentally pruned.
ID set so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: dict[str, str | None] = {}
ids: set[str] = set()
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
if item.raw_node_id not in ids:
ids[item.raw_node_id] = None
ids.add(item.raw_node_id)
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids[failed_id] = None
ids.add(failed_id)
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
ids[item.id] = parent_raw
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
ids.add(item.id)
return ids, hierarchy_nodes
def extract_ids_from_runnable_connector(
@@ -143,7 +132,7 @@ def extract_ids_from_runnable_connector(
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_raw_id_to_parent: dict[str, str | None] = {}
all_connector_doc_ids: set[str] = set()
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -188,20 +177,15 @@ def extract_ids_from_runnable_connector(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
for k, v in batch_ids.items():
if v is not None or k not in all_raw_id_to_parent:
all_raw_id_to_parent[k] = v
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
raw_id_to_parent=all_raw_id_to_parent,
doc_ids=all_connector_doc_ids,
hierarchy_nodes=all_hierarchy_nodes,
)

View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4

View File

@@ -29,7 +29,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -48,8 +47,6 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
@@ -60,8 +57,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -118,38 +113,6 @@ class PruneCallback(IndexingCallbackBase):
super().progress(tag, amount)
def _resolve_and_update_document_parents(
db_session: Session,
redis_client: Redis,
source: DocumentSource,
raw_id_to_parent: dict[str, str | None],
) -> None:
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
each document and bulk-update the DB. Mirrors the resolution logic in
run_docfetching.py."""
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent_id in raw_id_to_parent.items():
if raw_parent_id is None:
continue
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
resolved[doc_id] = node_id if found else source_node_id
if not resolved:
return
update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
task_logger.info(
f"Pruning: resolved and updated parent hierarchy for "
f"{len(resolved)} documents (source={source.value})"
)
"""Jobs / utils for kicking off pruning tasks."""
@@ -572,22 +535,22 @@ def connector_pruning_generator_task(
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
all_connector_doc_ids = extraction_result.doc_ids
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
ensure_source_node_exists(redis_client, db_session, source)
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(
redis_client, db_session, cc_pair.connector.source
)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=source,
source=cc_pair.connector.source,
commit=True,
is_connector_public=is_connector_public,
)
@@ -598,7 +561,7 @@ def connector_pruning_generator_task(
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=source,
source=cc_pair.connector.source,
entries=cache_entries,
)
@@ -607,26 +570,6 @@ def connector_pruning_generator_task(
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
ensure_source_node_exists(redis_client, db_session, source)
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
# and bulk-update documents, mirroring the docfetching resolution
_resolve_and_update_document_parents(
db_session=db_session,
redis_client=redis_client,
source=source,
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
@@ -638,9 +581,7 @@ def connector_pruning_generator_task(
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
"Pruning set collected: "
@@ -0,0 +1,10 @@
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
"onyx.background.celery.apps.background",
"celery_app",
)
@@ -36,6 +36,7 @@ from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
@@ -83,6 +84,28 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
)
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
if model_provider not in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}:
return False
return any(
(
msg.message_type == MessageType.ASSISTANT
and msg.tool_calls
and len(msg.tool_calls) > 0
)
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
for msg in simple_chat_history
)
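The helper above gates only on provider name and message history. A self-contained sketch of the same check, using stand-in string constants in place of the real onyx enums (`LlmProviderNames`, `MessageType`):

```python
from dataclasses import dataclass, field


@dataclass
class Msg:  # stand-in for ChatMessageSimple
    message_type: str
    tool_calls: list = field(default_factory=list)


def should_keep_tools(provider: str, history: list) -> bool:
    # Bedrock's Converse API rejects a request whose history contains
    # toolUse/toolResult blocks unless a toolConfig is also supplied.
    if provider not in {"bedrock", "bedrock_converse"}:
        return False
    return any(
        (m.message_type == "assistant" and m.tool_calls)
        or m.message_type == "tool_call_response"
        for m in history
    )


keep = should_keep_tools("bedrock", [Msg("assistant", [{"name": "search"}])])
skip = should_keep_tools("openai", [Msg("assistant", [{"name": "search"}])])
```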
def _try_fallback_tool_extraction(
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
@@ -663,7 +686,12 @@ def run_llm_loop(
elif out_of_cycles or ran_image_gen:
# Last cycle, no tools allowed, just answer!
tool_choice = ToolChoiceOptions.NONE
final_tools = []
# Bedrock requires tool config in requests that include toolUse/toolResult history.
final_tools = (
tools
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
else []
)
else:
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
@@ -495,7 +495,14 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
# Individual worker concurrency settings
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
)
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kubernetes deployments)
CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)
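These settings all use the `os.environ.get(name) or default` idiom rather than `os.environ.get(name, default)`. The difference matters when the variable is set but empty, which `int()` would otherwise choke on:

```python
import os


def concurrency_from_env(name: str, default: int) -> int:
    # `or default` also covers the empty-string case; passing "" to
    # int() via os.environ.get(name, default) would raise ValueError.
    return int(os.environ.get(name) or default)


os.environ["DEMO_CONCURRENCY"] = ""
value = concurrency_from_env("DEMO_CONCURRENCY", 20)
# → 20
```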
@@ -84,6 +84,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
@@ -943,9 +943,6 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
page
),
)
)
@@ -995,7 +992,6 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=page_id,
)
)
@@ -781,5 +781,4 @@ def build_slim_document(
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
)
@@ -902,11 +902,6 @@ class JiraConnector(
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
current_offset += 1
@@ -385,7 +385,6 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
id: str
external_access: ExternalAccess | None = None
parent_hierarchy_raw_node_id: str | None = None
class HierarchyNode(BaseModel):
@@ -772,7 +772,6 @@ def _convert_driveitem_to_slim_document(
drive_name: str,
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -788,15 +787,11 @@ def _convert_driveitem_to_slim_document(
return SlimDocument(
id=driveitem.id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
def _convert_sitepage_to_slim_document(
site_page: dict[str, Any],
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -813,7 +808,6 @@ def _convert_sitepage_to_slim_document(
return SlimDocument(
id=id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
@@ -1600,22 +1594,12 @@ class SharepointConnector(
)
)
parent_hierarchy_url: str | None = None
if drive_web_url:
parent_hierarchy_url = self._get_parent_hierarchy_url(
site_url, drive_web_url, drive_name, driveitem
)
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem,
drive_name,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
driveitem, drive_name, ctx, self.graph_client
)
)
except Exception as e:
@@ -1635,10 +1619,7 @@ class SharepointConnector(
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
site_page, ctx, self.graph_client
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:
@@ -565,7 +565,6 @@ def _get_all_doc_ids(
channel_id=channel_id, thread_ts=message["ts"]
),
external_access=external_access,
parent_hierarchy_raw_node_id=channel_id,
)
)
@@ -13,7 +13,6 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
@@ -247,7 +246,6 @@ def insert_document_set(
description=document_set_creation_request.description,
user_id=user_id,
is_public=document_set_creation_request.is_public,
is_up_to_date=DISABLE_VECTOR_DB,
time_last_modified_by_user=func.now(),
)
db_session.add(new_document_set_row)
@@ -338,8 +336,7 @@ def update_document_set(
)
document_set_row.description = document_set_update_request.description
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False
document_set_row.is_up_to_date = False
document_set_row.is_public = document_set_update_request.is_public
document_set_row.time_last_modified_by_user = func.now()
versioned_private_doc_set_fn = fetch_versioned_implementation(
@@ -1,7 +1,5 @@
"""CRUD operations for HierarchyNode."""
from collections import defaultdict
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -527,53 +525,6 @@ def get_document_parent_hierarchy_node_ids(
return {doc_id: parent_id for doc_id, parent_id in results}
def update_document_parent_hierarchy_nodes(
db_session: Session,
doc_parent_map: dict[str, int | None],
commit: bool = True,
) -> int:
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
Only updates rows whose current value differs from the desired value to
avoid unnecessary writes.
Args:
db_session: SQLAlchemy session
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
commit: Whether to commit the transaction
Returns:
Number of documents actually updated
"""
if not doc_parent_map:
return 0
doc_ids = list(doc_parent_map.keys())
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
by_parent: dict[int | None, list[str]] = defaultdict(list)
for doc_id, desired_parent_id in doc_parent_map.items():
current = existing.get(doc_id)
if current == desired_parent_id or doc_id not in existing:
continue
by_parent[desired_parent_id].append(doc_id)
updated = 0
for desired_parent_id, ids in by_parent.items():
db_session.query(Document).filter(Document.id.in_(ids)).update(
{Document.parent_hierarchy_node_id: desired_parent_id},
synchronize_session=False,
)
updated += len(ids)
if commit:
db_session.commit()
elif updated:
db_session.flush()
return updated
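The write-minimization strategy of the removed helper — skip rows that are already correct or unknown, then batch the rest by target value so each group becomes one bulk UPDATE — can be sketched with plain dicts:

```python
from collections import defaultdict


def plan_parent_updates(existing, desired):
    # Group document ids by their desired parent id, skipping docs that
    # are already correct or that do not exist, so each group can be
    # applied as a single bulk UPDATE statement.
    by_parent = defaultdict(list)
    for doc_id, parent in desired.items():
        if doc_id not in existing or existing[doc_id] == parent:
            continue
        by_parent[parent].append(doc_id)
    return dict(by_parent)


plan = plan_parent_updates(
    existing={"d1": 1, "d2": 2, "d3": None},
    desired={"d1": 1, "d2": 3, "d3": 3, "d4": 5},
)
# → {3: ["d2", "d3"]}  (d1 unchanged, d4 unknown)
```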
def update_hierarchy_node_permissions(
db_session: Session,
raw_node_id: str,
@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
latest_settings = result.scalars().first()
if not latest_settings:
raise RuntimeError("No search settings specified; DB is not in a valid state.")
raise RuntimeError("No search settings specified, DB is not in a valid state")
return latest_settings
@@ -32,6 +32,9 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type
@@ -26,10 +26,11 @@ def get_default_document_index(
To be used for retrieval only. Indexing should be done through both indices
until Vespa is deprecated.
Pre-existing docstring for this function, although secondary indices are not
currently supported:
Primary index is the index that is used for querying/updating etc. Secondary
index is for when both the currently used index and the upcoming index both
need to be updated. Updates are applied to both indices.
WARNING: In that case, get_all_document_indices should be used.
need to be updated, updates are applied to both indices.
"""
if DISABLE_VECTOR_DB:
return DisabledDocumentIndex(
@@ -50,26 +51,11 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
@@ -100,7 +86,8 @@ def get_all_document_indices(
Used for indexing only. Until Vespa is deprecated we will index into both
document indices. Retrieval is done through only one index however.
Large chunks are not currently supported so we hardcode appropriate values.
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
NOTE: Make sure the Vespa index object is returned first. In the rare event
that there is some conflict between indexing and the migration task, it is
@@ -136,36 +123,13 @@ def get_all_document_indices(
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
@@ -271,9 +271,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
secondary_embedding_dim: int | None,
secondary_embedding_precision: EmbeddingPrecision | None,
# NOTE: We do not support large chunks right now.
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
multitenant: bool = False,
@@ -289,25 +286,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
f"Expected {MULTI_TENANT}, got {multitenant}."
)
tenant_id = get_current_tenant_id()
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
self._real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
self._secondary_real_index: OpenSearchDocumentIndex | None = None
if self.secondary_index_name:
if secondary_embedding_dim is None or secondary_embedding_precision is None:
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
self._secondary_real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=self.secondary_index_name,
embedding_dim=secondary_embedding_dim,
embedding_precision=secondary_embedding_precision,
)
@staticmethod
def register_multitenant_indices(
@@ -323,38 +307,19 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
secondary_index_embedding_dim: int | None, # noqa: ARG002
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
) -> None:
self._real_index.verify_and_create_index_if_necessary(
# Only handle primary index for now, ignore secondary.
return self._real_index.verify_and_create_index_if_necessary(
primary_embedding_dim, primary_embedding_precision
)
if self.secondary_index_name:
if (
secondary_index_embedding_dim is None
or secondary_index_embedding_precision is None
):
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.verify_and_create_index_if_necessary(
secondary_index_embedding_dim, secondary_index_embedding_precision
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
# Convert IndexBatchParams to IndexingMetadata.
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -386,20 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
tenant_id: str, # noqa: ARG002
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
total_chunks_deleted += self._secondary_real_index.delete(
doc_id, chunk_count
)
return total_chunks_deleted
return self._real_index.delete(doc_id, chunk_count)
def update_single(
self,
@@ -410,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> None:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -445,11 +392,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
try:
self._real_index.update([update_request])
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.update([update_request])
except NotFoundError:
logger.exception(
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "
@@ -465,12 +465,6 @@ class VespaIndex(DocumentIndex):
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
index_batch_params.doc_id_to_new_chunk_cnt
):
@@ -665,10 +659,6 @@ class VespaIndex(DocumentIndex):
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
@@ -689,6 +679,13 @@ class VespaIndex(DocumentIndex):
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
@@ -708,20 +705,7 @@ class VespaIndex(DocumentIndex):
persona_ids=persona_ids,
)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
vespa_document_index.update([update_request])
vespa_document_index.update([update_request])
def delete_single(
self,
@@ -730,11 +714,6 @@ class VespaIndex(DocumentIndex):
tenant_id: str,
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
tenant_state = TenantState(
tenant_id=get_current_tenant_id(),
multitenant=MULTI_TENANT,
@@ -747,25 +726,13 @@ class VespaIndex(DocumentIndex):
raise ValueError(
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
total_chunks_deleted = 0
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
total_chunks_deleted += vespa_document_index.delete(
document_id=doc_id, chunk_count=chunk_count
)
return total_chunks_deleted
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
def id_based_retrieval(
self,

View File

@@ -92,98 +92,6 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
return [prompt.model_dump(exclude_none=True)]
def _normalize_content(raw: Any) -> str:
"""Normalize a message content field to a plain string.
Content can be a string, None, or a list of content-block dicts
(e.g. [{"type": "text", "text": "..."}]).
"""
if raw is None:
return ""
if isinstance(raw, str):
return raw
if isinstance(raw, list):
return "\n".join(
block.get("text", "") if isinstance(block, dict) else str(block)
for block in raw
)
return str(raw)
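`_normalize_content` above covers the three shapes a message content field can take (string, None, or a list of content-block dicts). A standalone copy makes the behavior easy to exercise:

```python
def normalize_content(raw):
    # string passes through, None becomes "", and content-block lists
    # are joined on their "text" fields; anything else is stringified.
    if raw is None:
        return ""
    if isinstance(raw, str):
        return raw
    if isinstance(raw, list):
        return "\n".join(
            block.get("text", "") if isinstance(block, dict) else str(block)
            for block in raw
        )
    return str(raw)


joined = normalize_content([{"type": "text", "text": "a"}, {"type": "text", "text": "b"}])
# → "a\nb"
```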
def _strip_tool_content_from_messages(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Convert tool-related messages to plain text.
Bedrock's Converse API requires toolConfig when messages contain
toolUse/toolResult content blocks. When no tools are provided for the
current request, we must convert any tool-related history into plain text
to avoid the "toolConfig field must be defined" error.
This is the same approach used by _OllamaHistoryMessageFormatter.
"""
result: list[dict[str, Any]] = []
for msg in messages:
role = msg.get("role")
tool_calls = msg.get("tool_calls")
if role == "assistant" and tool_calls:
# Convert structured tool calls to text representation
tool_call_lines = []
for tc in tool_calls:
func = tc.get("function", {})
name = func.get("name", "unknown")
args = func.get("arguments", "{}")
tc_id = tc.get("id", "")
tool_call_lines.append(
f"[Tool Call] name={name} id={tc_id} args={args}"
)
existing_content = _normalize_content(msg.get("content"))
parts = (
[existing_content] + tool_call_lines
if existing_content
else tool_call_lines
)
new_msg = {
"role": "assistant",
"content": "\n".join(parts),
}
result.append(new_msg)
elif role == "tool":
# Convert tool response to user message with text content
tool_call_id = msg.get("tool_call_id", "")
content = _normalize_content(msg.get("content"))
tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}"
# Merge into previous user message if it is also a converted
# tool result to avoid consecutive user messages (Bedrock requires
# strict user/assistant alternation).
if (
result
and result[-1]["role"] == "user"
and "[Tool Result]" in result[-1].get("content", "")
):
result[-1]["content"] += "\n\n" + tool_result_text
else:
result.append({"role": "user", "content": tool_result_text})
else:
result.append(msg)
return result
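The flattening above converts structured tool calls and tool results into plain text and merges consecutive converted tool results, since Bedrock requires strict user/assistant alternation. A simplified standalone sketch of the same transformation:

```python
def strip_tool_messages(messages):
    # Simplified mirror of the removed helper: tool calls/results become
    # plain text so Bedrock accepts the history without a toolConfig.
    result = []
    for msg in messages:
        if msg["role"] == "assistant" and msg.get("tool_calls"):
            lines = [
                f"[Tool Call] name={tc['function']['name']} id={tc['id']}"
                for tc in msg["tool_calls"]
            ]
            result.append({"role": "assistant", "content": "\n".join(lines)})
        elif msg["role"] == "tool":
            text = f"[Tool Result] id={msg.get('tool_call_id', '')}\n{msg.get('content', '')}"
            # Merge into the previous converted tool result to avoid
            # consecutive user messages.
            if (
                result
                and result[-1]["role"] == "user"
                and "[Tool Result]" in result[-1]["content"]
            ):
                result[-1]["content"] += "\n\n" + text
            else:
                result.append({"role": "user", "content": text})
        else:
            result.append(msg)
    return result


history = [
    {"role": "assistant", "tool_calls": [{"id": "t1", "function": {"name": "search"}}]},
    {"role": "tool", "tool_call_id": "t1", "content": "3 results"},
    {"role": "tool", "tool_call_id": "t2", "content": "ok"},
]
flattened = strip_tool_messages(history)
# two messages: one assistant text line, one merged user message
```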
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
"""Check if any messages contain tool-related content blocks."""
for msg in messages:
if msg.get("role") == "tool":
return True
if msg.get("role") == "assistant" and msg.get("tool_calls"):
return True
return False
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -496,30 +404,13 @@ class LitellmLLM(LLM):
else nullcontext()
)
with env_ctx:
messages = _prompt_to_dicts(prompt)
# Bedrock's Converse API requires toolConfig when messages
# contain toolUse/toolResult content blocks. When no tools are
# provided for this request but the history contains tool
# content from previous turns, strip it to plain text.
is_bedrock = self._model_provider in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}
if (
is_bedrock
and not tools
and _messages_contain_tool_content(messages)
):
messages = _strip_tool_content_from_messages(messages)
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=messages,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
@@ -146,11 +146,6 @@ class SlackRenderer(HTMLRenderer):
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def __init__(self) -> None:
super().__init__()
self._table_headers: list[str] = []
self._current_row_cells: list[str] = []
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
@@ -223,48 +218,5 @@ class SlackRenderer(HTMLRenderer):
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
def table_cell(
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
) -> str:
if head:
self._table_headers.append(text.strip())
else:
self._current_row_cells.append(text.strip())
return ""
def table_head(self, text: str) -> str: # noqa: ARG002
self._current_row_cells = []
return ""
def table_row(self, text: str) -> str: # noqa: ARG002
cells = self._current_row_cells
self._current_row_cells = []
# First column becomes the bold title, remaining columns are bulleted fields
lines: list[str] = []
if cells:
title = cells[0]
if title:
# Avoid double-wrapping if cell already contains bold markup
if title.startswith("*") and title.endswith("*") and len(title) > 1:
lines.append(title)
else:
lines.append(f"*{title}*")
for i, cell in enumerate(cells[1:], start=1):
if i < len(self._table_headers):
lines.append(f"{self._table_headers[i]}: {cell}")
else:
lines.append(f"{cell}")
return "\n".join(lines) + "\n\n"
def table_body(self, text: str) -> str:
return text
def table(self, text: str) -> str:
self._table_headers = []
self._current_row_cells = []
return text + "\n"
def paragraph(self, text: str) -> str:
return f"{text}\n\n"
@@ -7424,9 +7424,9 @@
}
},
"node_modules/hono": {
"version": "4.12.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
"version": "4.11.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"
@@ -11,7 +11,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import delete_document_set as db_delete_document_set
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import insert_document_set
@@ -143,10 +142,7 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
if DISABLE_VECTOR_DB:
db_session.refresh(document_set)
db_delete_document_set(document_set, db_session)
else:
if not DISABLE_VECTOR_DB:
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -10,8 +11,6 @@ from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.models import User
from onyx.db.search_settings import get_all_search_settings
from onyx.db.search_settings import get_current_db_embedding_provider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.indexing.models import EmbeddingModelDetail
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
@@ -60,7 +59,7 @@ def test_embedding_configuration(
except Exception as e:
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.get("", response_model=list[EmbeddingModelDetail])
@@ -94,9 +93,8 @@ def delete_embedding_provider(
embedding_provider is not None
and provider_type == embedding_provider.provider_type
):
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"You can't delete a currently active model",
raise HTTPException(
status_code=400, detail="You can't delete a currently active model"
)
remove_embedding_provider(db_session, provider_type=provider_type)
@@ -11,6 +11,7 @@ from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import ValidationError
from sqlalchemy.orm import Session
@@ -37,8 +38,6 @@ from onyx.db.llm import upsert_llm_provider
from onyx.db.llm import validate_persona_ids_exist
from onyx.db.models import User
from onyx.db.persona import user_can_access_persona
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
@@ -187,7 +186,7 @@ def _validate_llm_provider_change(
Only enforced in MULTI_TENANT mode.
Raises:
OnyxError: If api_base or custom_config changed without changing API key
HTTPException: If api_base or custom_config changed without changing API key
"""
if not MULTI_TENANT or api_key_changed:
return
@@ -201,9 +200,9 @@ def _validate_llm_provider_change(
)
if api_base_changed or custom_config_changed:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"API base and/or custom config cannot be changed without changing the API key",
raise HTTPException(
status_code=400,
detail="API base and/or custom config cannot be changed without changing the API key",
)
@@ -223,7 +222,7 @@ def fetch_llm_provider_options(
for well_known_llm in well_known_llms:
if well_known_llm.name == provider_name:
return well_known_llm
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found")
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
@admin_router.post("/test")
@@ -282,7 +281,7 @@ def test_llm_configuration(
error_msg = test_llm(llm)
if error_msg:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.post("/test/default")
@@ -293,11 +292,11 @@ def test_default_provider(
llm = get_default_llm()
except ValueError:
logger.exception("Failed to fetch default LLM Provider")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup")
raise HTTPException(status_code=400, detail="No LLM Provider setup")
error = test_llm(llm)
if error:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error))
raise HTTPException(status_code=400, detail=str(error))
@admin_router.get("/provider")
@@ -363,31 +362,35 @@ def put_llm_provider(
# Check name constraints
# TODO: Once port from name to id is complete, unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Renaming providers is not currently supported",
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"Provider with name={llm_provider_upsert_request.name} already exists",
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists",
raise HTTPException(
status_code=400,
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
)
elif not existing_provider and not is_creation:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist",
raise HTTPException(
status_code=400,
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -412,9 +415,9 @@ def put_llm_provider(
db_session, persona_ids
)
if missing_personas:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
raise HTTPException(
status_code=400,
detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}",
)
# Remove duplicates while preserving order
seen: set[int] = set()
@@ -470,7 +473,7 @@ def put_llm_provider(
return result
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
@admin_router.delete("/provider/{provider_id}")
@@ -480,19 +483,19 @@ def delete_llm_provider(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if not force:
model = fetch_default_llm_model(db_session)
if model and model.llm_provider_id == provider_id:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Cannot delete the default LLM provider",
)
try:
if not force:
model = fetch_default_llm_model(db_session)
if model and model.llm_provider_id == provider_id:
raise HTTPException(
status_code=400,
detail="Cannot delete the default LLM provider",
)
remove_llm_provider(db_session, provider_id)
except ValueError as e:
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/default")
@@ -532,9 +535,9 @@ def get_auto_config(
"""
config = fetch_llm_recommendations_from_github()
if not config:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
"Failed to fetch configuration from GitHub",
raise HTTPException(
status_code=502,
detail="Failed to fetch configuration from GitHub",
)
return config.model_dump()
@@ -691,13 +694,13 @@ def list_llm_providers_for_persona(
persona = fetch_persona_with_groups(db_session, persona_id)
if not persona:
raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found")
raise HTTPException(status_code=404, detail="Persona not found")
# Verify user has access to this persona
if not user_can_access_persona(db_session, persona_id, user, get_editable=False):
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You don't have access to this assistant",
raise HTTPException(
status_code=403,
detail="You don't have access to this assistant",
)
is_admin = user.role == UserRole.ADMIN
@@ -851,9 +854,9 @@ def get_bedrock_available_models(
try:
bedrock = session.client("bedrock")
except Exception as e:
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
raise HTTPException(
status_code=400,
detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.",
)
# Build model info dict from foundation models (modelId -> metadata)
@@ -972,14 +975,14 @@ def get_bedrock_available_models(
return results
except (ClientError, NoCredentialsError, BotoCoreError) as e:
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
f"Failed to connect to AWS Bedrock: {e}",
raise HTTPException(
status_code=400,
detail=f"Failed to connect to AWS Bedrock: {e}",
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Unexpected error fetching Bedrock models: {e}",
raise HTTPException(
status_code=500,
detail=f"Unexpected error fetching Bedrock models: {e}",
)
@@ -991,9 +994,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]:
response.raise_for_status()
response_json = response.json()
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch Ollama models: {e}",
raise HTTPException(
status_code=400,
detail=f"Failed to fetch Ollama models: {e}",
)
models = response_json.get("models", [])
@@ -1010,9 +1013,9 @@ def get_ollama_available_models(
cleaned_api_base = request.api_base.strip().rstrip("/")
if not cleaned_api_base:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"API base URL is required to fetch Ollama models.",
raise HTTPException(
status_code=400,
detail="API base URL is required to fetch Ollama models.",
)
# NOTE: most people run Ollama locally, so we don't disallow internal URLs
@@ -1021,9 +1024,9 @@ def get_ollama_available_models(
# with the same response format
model_names = _get_ollama_available_model_names(cleaned_api_base)
if not model_names:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Ollama server",
raise HTTPException(
status_code=400,
detail="No models found from your Ollama server",
)
all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
@@ -1125,9 +1128,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
response.raise_for_status()
return response.json()
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch OpenRouter models: {e}",
raise HTTPException(
status_code=400,
detail=f"Failed to fetch OpenRouter models: {e}",
)
@@ -1148,9 +1151,9 @@ def get_openrouter_available_models(
data = response_json.get("data", [])
if not isinstance(data, list) or len(data) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your OpenRouter endpoint",
raise HTTPException(
status_code=400,
detail="No models found from your OpenRouter endpoint",
)
results: list[OpenRouterFinalModelResponse] = []
@@ -1185,9 +1188,8 @@ def get_openrouter_available_models(
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from OpenRouter",
raise HTTPException(
status_code=400, detail="No compatible models found from OpenRouter"
)
sorted_results = sorted(results, key=lambda m: m.name.lower())

View File

@@ -6,11 +6,8 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
@@ -18,25 +15,20 @@ from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import delete_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.factory import get_default_document_index
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import update_unstructured_api_key
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT
router = APIRouter(prefix="/search-settings")
@@ -49,99 +41,110 @@ def set_new_search_settings(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), # noqa: ARG001
) -> IdReturn:
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
Creates a new SearchSettings row and cancels the previous secondary indexing
if any exists.
"""
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Disallow contextual RAG for cloud deployments.
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
# Validate cloud provider exists or create new LiteLLM provider.
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=search_settings_new.provider_type
)
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
)
validate_contextual_rag_model(
provider_name=search_settings_new.contextual_rag_llm_provider,
model_name=search_settings_new.contextual_rag_llm_name,
db_session=db_session,
# TODO(andrei): Re-enable.
# NOTE Enable integration external dependency tests in test_search_settings.py
# when this is reenabled. They are currently skipped
logger.error("Setting new search settings is temporarily disabled.")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Setting new search settings is temporarily disabled.",
)
# if search_settings_new.index_name:
# logger.warning("Index name was specified by request, this is not suggested")
search_settings = get_current_search_settings(db_session)
# # Disallow contextual RAG for cloud deployments
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Contextual RAG disabled in Onyx Cloud",
# )
if search_settings_new.index_name is None:
# We define index name here.
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
search_settings_new.model_name == search_settings.model_name
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.model_dump()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(
**search_settings_new.model_dump()
)
# # Validate cloud provider exists or create new LiteLLM provider
# if search_settings_new.provider_type is not None:
# cloud_provider = get_embedding_provider_from_provider_type(
# db_session, provider_type=search_settings_new.provider_type
# )
secondary_search_settings = get_secondary_search_settings(db_session)
# if cloud_provider is None:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
# )
if secondary_search_settings:
# Cancel any background indexing jobs.
expire_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# validate_contextual_rag_model(
# provider_name=search_settings_new.contextual_rag_llm_provider,
# model_name=search_settings_new.contextual_rag_llm_name,
# db_session=db_session,
# )
# Mark previous model as a past model directly.
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
# search_settings = get_current_search_settings(db_session)
new_search_settings = create_search_settings(
search_settings=new_search_settings_request, db_session=db_session
)
# if search_settings_new.index_name is None:
# # We define index name here
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
# if (
# search_settings_new.model_name == search_settings.model_name
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
# ):
# index_name += ALT_INDEX_SUFFIX
# search_values = search_settings_new.model_dump()
# search_values["index_name"] = index_name
# new_search_settings_request = SavedSearchSettings(**search_values)
# else:
# new_search_settings_request = SavedSearchSettings(
# **search_settings_new.model_dump()
# )
# Ensure the document indices have the new index immediately.
document_indices = get_all_document_indices(search_settings, new_search_settings)
for document_index in document_indices:
document_index.ensure_indices_exist(
primary_embedding_dim=search_settings.final_embedding_dim,
primary_embedding_precision=search_settings.embedding_precision,
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
secondary_index_embedding_precision=new_search_settings.embedding_precision,
)
# secondary_search_settings = get_secondary_search_settings(db_session)
# Pause index attempts for the currently in-use index to preserve resources.
if DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
# if secondary_search_settings:
# # Cancel any background indexing jobs
# expire_index_attempts(
# search_settings_id=secondary_search_settings.id, db_session=db_session
# )
db_session.commit()
return IdReturn(id=new_search_settings.id)
# # Mark previous model as a past model directly
# update_search_settings_status(
# search_settings=secondary_search_settings,
# new_status=IndexModelStatus.PAST,
# db_session=db_session,
# )
# new_search_settings = create_search_settings(
# search_settings=new_search_settings_request, db_session=db_session
# )
# # Ensure Vespa has the new index immediately
# get_multipass_config(search_settings)
# get_multipass_config(new_search_settings)
# document_index = get_default_document_index(
# search_settings, new_search_settings, db_session
# )
# document_index.ensure_indices_exist(
# primary_embedding_dim=search_settings.final_embedding_dim,
# primary_embedding_precision=search_settings.embedding_precision,
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
# )
# # Pause index attempts for the currently in use index to preserve resources
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# expire_index_attempts(
# search_settings_id=search_settings.id, db_session=db_session
# )
# for cc_pair in get_connector_credential_pairs(db_session):
# resync_cc_pair(
# cc_pair=cc_pair,
# search_settings_id=new_search_settings.id,
# db_session=db_session,
# )
# db_session.commit()
# return IdReturn(id=new_search_settings.id)
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
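Although set_new_search_settings now short-circuits with a 501, its commented-out body still documents how index names were derived: `danswer_chunk_` plus the cleaned model name, with ALT_INDEX_SUFFIX appended when re-indexing under the same model. A sketch of that derivation — the suffix value and the behavior of clean_model_name are assumptions here, not taken from the diff:

```python
ALT_INDEX_SUFFIX = "__danswer_alt_index"  # assumed value; the real one lives in shared_configs


def clean_model_name(model_name: str) -> str:
    # Assumption: normalizes separators the way the real helper likely does.
    return model_name.replace("/", "_").replace("-", "_").replace(".", "_")


def derive_index_name(new_model: str, current_model: str, current_index: str) -> str:
    """Mirror the (now commented-out) naming logic in set_new_search_settings."""
    index_name = f"danswer_chunk_{clean_model_name(new_model)}"
    # Same model as the current index: alternate suffix avoids a name collision.
    if new_model == current_model and not current_index.endswith(ALT_INDEX_SUFFIX):
        index_name += ALT_INDEX_SUFFIX
    return index_name
```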

View File

@@ -60,11 +60,9 @@ class Settings(BaseModel):
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import json
import time
from collections.abc import Generator
@@ -86,19 +84,6 @@ class CodeInterpreterClient:
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
self._closed = False
def __enter__(self) -> CodeInterpreterClient:
return self
def __exit__(self, *args: object) -> None:
self.close()
def close(self) -> None:
if self._closed:
return
self.session.close()
self._closed = True
def _build_payload(
self,
@@ -118,13 +103,7 @@ class CodeInterpreterClient:
return payload
def health(self, use_cache: bool = False) -> bool:
"""Check if the Code Interpreter service is healthy
Args:
use_cache: When True, return a cached result if available and
within the TTL window. The cache is always populated
after a live request regardless of this flag.
"""
"""Check if the Code Interpreter service is healthy"""
if use_cache:
cached = _health_cache.get(self.base_url)
if cached is not None:
@@ -192,11 +171,8 @@ class CodeInterpreterClient:
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
try:
response.raise_for_status()
yield from self._parse_sse(response)
finally:
response.close()
response.raise_for_status()
yield from self._parse_sse(response)
def _parse_sse(
self, response: requests.Response

View File

@@ -111,8 +111,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
if not server.server_enabled:
return False
with CodeInterpreterClient() as client:
return client.health(use_cache=True)
client = CodeInterpreterClient()
return client.health(use_cache=True)
def tool_definition(self) -> dict:
return {
@@ -176,203 +176,196 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
)
)
# Create Code Interpreter client — context manager ensures
# session.close() is called on every exit path.
with CodeInterpreterClient() as client:
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
logger.info(f"Staged file for Python execution: {file_name}")
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
# Create Code Interpreter client
client = CodeInterpreterClient()
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
logger.debug(f"Executing code: {code}")
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=(
event.data if event.stream == "stdout" else ""
),
stderr=(
event.data if event.stream == "stderr" else ""
),
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
logger.info(f"Staged file for Python execution: {file_name}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
try:
logger.debug(f"Executing code: {code}")
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file "
f"{workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated "
f"file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged "
f"file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=(None if result_event.exit_code == 0 else truncated_stderr),
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
# Emit error delta
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file {workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
obj=PythonToolDelta(file_ids=generated_file_ids),
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)
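The PythonTool hunk above is mostly a re-indent from a with-block back to a bare client, but the streaming loop's shape is easy to lose in the interleaved diff: accumulate stdout/stderr deltas, remember the final result event, and fail fast on error events. A condensed sketch with stub event types — the real ones live alongside the code interpreter client:

```python
from dataclasses import dataclass


# Stubs mirroring StreamOutputEvent / StreamResultEvent / StreamErrorEvent.
@dataclass
class StreamOutputEvent:
    stream: str  # "stdout" or "stderr"
    data: str


@dataclass
class StreamResultEvent:
    exit_code: int
    timed_out: bool


@dataclass
class StreamErrorEvent:
    message: str


def consume_stream(events) -> tuple:
    """Mirror the accumulation loop in PythonTool (simplified)."""
    stdout_parts: list[str] = []
    stderr_parts: list[str] = []
    result = None
    for event in events:
        if isinstance(event, StreamOutputEvent):
            # Route each delta to the matching buffer.
            (stdout_parts if event.stream == "stdout" else stderr_parts).append(event.data)
        elif isinstance(event, StreamResultEvent):
            result = event
        elif isinstance(event, StreamErrorEvent):
            raise RuntimeError(f"Code interpreter error: {event.message}")
    if result is None:
        raise RuntimeError("Code interpreter stream ended without a result event")
    return "".join(stdout_parts), "".join(stderr_parts), result
```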

View File

@@ -596,7 +596,7 @@ mypy-extensions==1.0.0
# typing-inspect
nest-asyncio==1.6.0
# via onyx
nltk==3.9.3
nltk==3.9.1
# via unstructured
numpy==2.4.1
# via

View File

@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.3
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via

View File

@@ -16,6 +16,10 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
def run_jobs() -> None:
    # Whether to use lightweight mode; defaults to True. Set to False to run separate background workers.
use_lightweight = True
# command setup
cmd_worker_primary = [
"celery",
"-A",
@@ -70,48 +74,6 @@ def run_jobs() -> None:
"--queues=connector_doc_fetching",
]
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,user_file_delete",
]
cmd_beat = [
"celery",
"-A",
@@ -120,31 +82,144 @@ def run_jobs() -> None:
"--loglevel=INFO",
]
all_workers = [
("PRIMARY", cmd_worker_primary),
("LIGHT", cmd_worker_light),
("DOCPROCESSING", cmd_worker_docprocessing),
("DOCFETCHING", cmd_worker_docfetching),
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
("BEAT", cmd_beat),
]
# Prepare background worker commands based on mode
if use_lightweight:
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
cmd_worker_background = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration",
]
background_workers = [("BACKGROUND", cmd_worker_background)]
else:
print("Starting workers in STANDARD mode (separate background workers)")
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete",
]
background_workers = [
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
]
processes = []
for name, cmd in all_workers:
# spawn processes
worker_primary_process = subprocess.Popen(
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_light_process = subprocess.Popen(
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_docprocessing_process = subprocess.Popen(
cmd_worker_docprocessing,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
# Spawn background worker processes based on mode
background_processes = []
for name, cmd in background_workers:
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
processes.append((name, process))
background_processes.append((name, process))
threads = []
for name, process in processes:
# monitor threads
worker_primary_thread = threading.Thread(
target=monitor_process, args=("PRIMARY", worker_primary_process)
)
worker_light_thread = threading.Thread(
target=monitor_process, args=("LIGHT", worker_light_process)
)
worker_docprocessing_thread = threading.Thread(
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
# Create monitor threads for background workers
background_threads = []
for name, process in background_processes:
thread = threading.Thread(target=monitor_process, args=(name, process))
threads.append(thread)
background_threads.append(thread)
# Start all threads
worker_primary_thread.start()
worker_light_thread.start()
worker_docprocessing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()
for thread in background_threads:
thread.start()
for thread in threads:
# Wait for all threads
worker_primary_thread.join()
worker_light_thread.join()
worker_docprocessing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()
for thread in background_threads:
thread.join()

View File

@@ -1,20 +1,10 @@
#!/bin/bash
set -e
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"
stop_and_remove_containers() {
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true
}
cleanup() {
echo "Error occurred. Cleaning up..."
stop_and_remove_containers
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -22,26 +12,16 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR
# Usage of the script with optional volume arguments
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
# [minio_volume] [--keep-opensearch-data]
KEEP_OPENSEARCH_DATA=false
POSITIONAL_ARGS=()
for arg in "$@"; do
if [[ "$arg" == "--keep-opensearch-data" ]]; then
KEEP_OPENSEARCH_DATA=true
else
POSITIONAL_ARGS+=("$arg")
fi
done
VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
VESPA_VOLUME=${1:-""} # Default is empty if not provided
POSTGRES_VOLUME=${2:-""} # Default is empty if not provided
REDIS_VOLUME=${3:-""} # Default is empty if not provided
MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
stop_and_remove_containers
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -59,29 +39,6 @@ else
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
fi
# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
# .vscode/.env so existing dev setups that stored it there aren't silently
# broken.
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
set -a
# shellcheck source=/dev/null
source "$VSCODE_ENV"
set +a
fi
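The `set -a` / `set +a` bracket above auto-exports every variable assigned while the env file is sourced, so its entries become visible to child processes. A standalone demonstration (the file and variable value here are illustrative):

```shell
# Variables assigned between `set -a` and `set +a` are exported,
# so sourcing a KEY=value env file makes its entries visible to
# commands launched afterwards.
ENV_FILE="$(mktemp)"
echo 'OPENSEARCH_ADMIN_PASSWORD=Example123!' > "$ENV_FILE"

set -a
# shellcheck source=/dev/null
. "$ENV_FILE"
set +a

# A child shell inherits the exported variable.
CHILD_SEES="$(sh -c 'printf %s "$OPENSEARCH_ADMIN_PASSWORD"')"
rm -f "$ENV_FILE"
```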
# Start the OpenSearch container using the same service from docker-compose that
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
# restarts, else the volume is deleted so the container starts fresh.
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
echo "Deleting opensearch-data volume..."
docker volume rm onyx_opensearch-data 2>/dev/null || true
fi
echo "Starting OpenSearch container..."
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch
# Start the Redis container with optional volume
echo "Starting Redis container..."
if [[ -n "$REDIS_VOLUME" ]]; then
@@ -103,6 +60,7 @@ echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
cd "$PARENT_DIR"
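The flag-vs-positional parsing added above can be exercised standalone. A POSIX sh rendering of the same loop (the real script uses a bash array for `POSITIONAL_ARGS`; the simulated arguments are illustrative):

```shell
# Separate the --keep-opensearch-data flag from positional volume args.
set -- vespa_vol --keep-opensearch-data postgres_vol  # simulated "$@"

KEEP_OPENSEARCH_DATA=false
POSITIONAL=
for arg in "$@"; do
    if [ "$arg" = "--keep-opensearch-data" ]; then
        KEEP_OPENSEARCH_DATA=true
    else
        POSITIONAL="$POSITIONAL $arg"
    fi
done

# Reset "$@" to only the positional args so ${N:-} defaults work.
set -- $POSITIONAL
VESPA_VOLUME=${1:-""}
POSTGRES_VOLUME=${2:-""}
```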

View File

@@ -0,0 +1,10 @@
#!/bin/bash
# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
source "$(dirname "$0")/../../.vscode/.env"
cd "$(dirname "$0")/../../deployment/docker_compose"
# Start OpenSearch.
echo "Forcefully starting fresh OpenSearch container..."
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch

View File

@@ -1,5 +1,23 @@
#!/bin/sh
# Entrypoint script for supervisord
# Entrypoint script for supervisord that sets environment variables
# for controlling which celery workers to start
# Default to lightweight mode if not set
if [ -z "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" ]; then
export USE_LIGHTWEIGHT_BACKGROUND_WORKER="true"
fi
# Set the complementary variable for supervisord
# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax
if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then
export USE_SEPARATE_BACKGROUND_WORKERS="false"
else
export USE_SEPARATE_BACKGROUND_WORKERS="true"
fi
echo "Worker mode configuration:"
echo " USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER"
echo " USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS"
# Launch supervisord with environment variables available
exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf

View File

@@ -39,6 +39,7 @@ autorestart=true
startsecs=10
stopasgroup=true
# Standard mode: Light worker for fast operations
# NOTE: only allowing configuration here and not in the other celery workers,
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
# user group syncing, deletion, etc.)
@@ -53,7 +54,26 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Lightweight mode: single consolidated background worker
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default)
# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing
[program:celery_worker_background]
command=celery -A onyx.background.celery.versioned_apps.background worker
--loglevel=INFO
--hostname=background@%%n
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,opensearch_migration
stdout_logfile=/var/log/celery_worker_background.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s
# Standard mode: separate workers for different background tasks
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
[program:celery_worker_heavy]
command=celery -A onyx.background.celery.versioned_apps.heavy worker
--loglevel=INFO
@@ -65,7 +85,9 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Standard mode: Document processing worker
[program:celery_worker_docprocessing]
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
--loglevel=INFO
@@ -77,6 +99,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
[program:celery_worker_user_file_processing]
command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker
@@ -89,7 +112,9 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Standard mode: Document fetching worker
[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
--loglevel=INFO
@@ -101,6 +126,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
@@ -113,6 +139,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Job scheduler for periodic tasks
@@ -170,6 +197,7 @@ command=tail -qF
/var/log/celery_beat.log
/var/log/celery_worker_primary.log
/var/log/celery_worker_light.log
/var/log/celery_worker_background.log
/var/log/celery_worker_heavy.log
/var/log/celery_worker_docprocessing.log
/var/log/celery_worker_monitoring.log

View File

@@ -5,8 +5,6 @@ Verifies that:
1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs
2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch
3. Upserting is idempotent (running twice doesn't duplicate nodes)
4. Document-to-hierarchy-node linkage is updated during pruning
5. link_hierarchy_nodes_to_documents links nodes that are also documents
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
combined with a real PostgreSQL database for verifying persistence.
@@ -29,13 +27,9 @@ from onyx.db.enums import HierarchyNodeType
from onyx.db.hierarchy import ensure_source_node_exists
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import Document as DbDocument
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.kg.models import KGStage
# ---------------------------------------------------------------------------
# Constants
@@ -95,18 +89,8 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]:
]
DOC_PARENT_MAP = {
"msg-001": CHANNEL_A_ID,
"msg-002": CHANNEL_A_ID,
"msg-003": CHANNEL_B_ID,
}
def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]:
return [
SlimDocument(id=doc_id, parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id))
for doc_id in SLIM_DOC_IDS
]
return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS]
class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
@@ -142,31 +126,14 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
# ---------------------------------------------------------------------------
def _cleanup_test_data(db_session: Session) -> None:
"""Remove all test hierarchy nodes and documents to isolate tests."""
for doc_id in SLIM_DOC_IDS:
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
def _cleanup_test_hierarchy_nodes(db_session: Session) -> None:
"""Remove all hierarchy nodes for TEST_SOURCE to isolate tests."""
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == TEST_SOURCE
).delete()
db_session.commit()
def _create_test_documents(db_session: Session) -> list[DbDocument]:
"""Insert minimal Document rows for our test doc IDs."""
docs = []
for doc_id in SLIM_DOC_IDS:
doc = DbDocument(
id=doc_id,
semantic_id=doc_id,
kg_stage=KGStage.NOT_STARTED,
)
db_session.add(doc)
docs.append(doc)
db_session.commit()
return docs
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -180,14 +147,14 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
result = extract_ids_from_runnable_connector(connector, callback=None)
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
# (hierarchy node IDs are added to doc_ids so they aren't pruned)
expected_ids = {
CHANNEL_A_ID,
CHANNEL_B_ID,
CHANNEL_C_ID,
*SLIM_DOC_IDS,
}
assert result.raw_id_to_parent.keys() == expected_ids
assert result.doc_ids == expected_ids
# Hierarchy nodes should be the 3 channels
assert len(result.hierarchy_nodes) == 3
@@ -198,7 +165,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None:
"""Full flow: extract hierarchy nodes from mock connector, upsert to Postgres,
then verify the DB state (node count, parent relationships, permissions)."""
_cleanup_test_data(db_session)
_cleanup_test_hierarchy_nodes(db_session)
# Step 1: ensure the SOURCE node exists (mirrors what the pruning task does)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -263,7 +230,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
) -> None:
"""When the connector's access type is PUBLIC, all hierarchy nodes must be
marked is_public=True regardless of their external_access settings."""
_cleanup_test_data(db_session)
_cleanup_test_hierarchy_nodes(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -290,7 +257,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
"""Upserting the same hierarchy nodes twice must not create duplicates.
The second call should update existing rows in place."""
_cleanup_test_data(db_session)
_cleanup_test_hierarchy_nodes(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -328,7 +295,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None:
"""Upserting a hierarchy node with changed fields should update the existing row."""
_cleanup_test_data(db_session)
_cleanup_test_hierarchy_nodes(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -375,193 +342,3 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No
assert db_node.is_public is True
assert db_node.external_user_emails is not None
assert set(db_node.external_user_emails) == {"new_user@example.com"}
# ---------------------------------------------------------------------------
# Document-to-hierarchy-node linkage tests
# ---------------------------------------------------------------------------
def test_extraction_preserves_parent_hierarchy_raw_node_id(
db_session: Session, # noqa: ARG001
) -> None:
"""extract_ids_from_runnable_connector should carry the
parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict."""
connector = MockSlimConnectorWithPermSync()
result = extract_ids_from_runnable_connector(connector, callback=None)
for doc_id, expected_parent in DOC_PARENT_MAP.items():
assert (
result.raw_id_to_parent[doc_id] == expected_parent
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
# Hierarchy node entries have None parent (they aren't documents)
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
assert result.raw_id_to_parent[channel_id] is None
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
"""update_document_parent_hierarchy_nodes should set
Document.parent_hierarchy_node_id for each document in the mapping."""
_cleanup_test_data(db_session)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
# Create documents with no parent set
docs = _create_test_documents(db_session)
for doc in docs:
assert doc.parent_hierarchy_node_id is None
# Build resolved map (same logic as _resolve_and_update_document_parents)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent in DOC_PARENT_MAP.items():
resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id)
updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert updated == len(SLIM_DOC_IDS)
# Verify each document now points to the correct hierarchy node
db_session.expire_all()
for doc_id, raw_parent in DOC_PARENT_MAP.items():
tmp_doc = db_session.get(DbDocument, doc_id)
assert tmp_doc is not None
doc = tmp_doc
expected_node_id = node_id_by_raw[raw_parent]
assert (
doc.parent_hierarchy_node_id == expected_node_id
), f"Document {doc_id} should point to node for {raw_parent}"
def test_update_document_parent_is_idempotent(db_session: Session) -> None:
"""Running update_document_parent_hierarchy_nodes a second time with the
same mapping should update zero rows."""
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
_create_test_documents(db_session)
resolved: dict[str, int | None] = {
doc_id: node_id_by_raw[raw_parent]
for doc_id, raw_parent in DOC_PARENT_MAP.items()
}
first_updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert first_updated == len(SLIM_DOC_IDS)
second_updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert second_updated == 0
def test_link_hierarchy_nodes_to_documents_for_confluence(
db_session: Session,
) -> None:
"""For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence),
link_hierarchy_nodes_to_documents should set HierarchyNode.document_id
when a hierarchy node's raw_node_id matches a document ID."""
_cleanup_test_data(db_session)
confluence_source = DocumentSource.CONFLUENCE
# Clean up any existing Confluence hierarchy nodes
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == confluence_source
).delete()
db_session.commit()
ensure_source_node_exists(db_session, confluence_source, commit=True)
# Create a hierarchy node whose raw_node_id matches a document ID
page_node_id = "confluence-page-123"
nodes = [
PydanticHierarchyNode(
raw_node_id=page_node_id,
raw_parent_id=None,
display_name="Test Page",
link="https://wiki.example.com/page/123",
node_type=HierarchyNodeType.PAGE,
),
]
upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=nodes,
source=confluence_source,
commit=True,
is_connector_public=False,
)
# Verify the node exists but has no document_id yet
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
assert db_node is not None
assert db_node.document_id is None
# Create a document with the same ID as the hierarchy node
doc = DbDocument(
id=page_node_id,
semantic_id="Test Page",
kg_stage=KGStage.NOT_STARTED,
)
db_session.add(doc)
db_session.commit()
# Link nodes to documents
linked = link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=[page_node_id],
source=confluence_source,
commit=True,
)
assert linked == 1
# Verify the hierarchy node now has document_id set
db_session.expire_all()
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
assert db_node is not None
assert db_node.document_id == page_node_id
# Cleanup
db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete()
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == confluence_source
).delete()
db_session.commit()
def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
db_session: Session,
) -> None:
"""link_hierarchy_nodes_to_documents should return 0 for sources that
don't support hierarchy-node-as-document (e.g. Slack, Google Drive)."""
linked = link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=SLIM_DOC_IDS,
source=TEST_SOURCE, # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS
commit=False,
)
assert linked == 0

View File

@@ -11,6 +11,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
@@ -19,8 +20,6 @@ from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import UserRole
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.server.manage.llm.api import (
@@ -123,16 +122,16 @@ class TestLLMConfigurationEndpoint:
finally:
db_session.rollback()
def test_failed_llm_test_raises_onyx_error(
def test_failed_llm_test_raises_http_exception(
self,
db_session: Session,
provider_name: str, # noqa: ARG002
) -> None:
"""
Test that a failed LLM test raises an OnyxError with VALIDATION_ERROR.
Test that a failed LLM test raises an HTTPException with status 400.
When test_llm returns an error message, the endpoint should raise
an OnyxError with the error details.
an HTTPException with the error details.
"""
error_message = "Invalid API key: Authentication failed"
@@ -144,7 +143,7 @@ class TestLLMConfigurationEndpoint:
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
):
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
provider=LlmProviderNames.OPENAI,
@@ -157,8 +156,9 @@ class TestLLMConfigurationEndpoint:
db_session=db_session,
)
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert exc_info.value.message == error_message
# Verify the exception details
assert exc_info.value.status_code == 400
assert exc_info.value.detail == error_message
finally:
db_session.rollback()
@@ -536,11 +536,11 @@ class TestDefaultProviderEndpoint:
remove_llm_provider(db_session, provider.id)
# Now run_test_default_provider should fail
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
run_test_default_provider(_=_create_mock_admin())
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert "No LLM Provider setup" in exc_info.value.message
assert exc_info.value.status_code == 400
assert "No LLM Provider setup" in exc_info.value.detail
finally:
db_session.rollback()
@@ -581,11 +581,11 @@ class TestDefaultProviderEndpoint:
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure
):
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
run_test_default_provider(_=_create_mock_admin())
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert exc_info.value.message == error_message
assert exc_info.value.status_code == 400
assert exc_info.value.detail == error_message
finally:
db_session.rollback()

View File

@@ -16,14 +16,13 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import UserRole
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.api import _mask_string
from onyx.server.manage.llm.api import put_llm_provider
@@ -101,7 +100,7 @@ class TestLLMProviderChanges:
api_base="https://attacker.example.com",
)
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
put_llm_provider(
llm_provider_upsert_request=update_request,
is_creation=False,
@@ -109,9 +108,9 @@ class TestLLMProviderChanges:
db_session=db_session,
)
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert exc_info.value.status_code == 400
assert "cannot be changed without changing the API key" in str(
exc_info.value.message
exc_info.value.detail
)
finally:
_cleanup_provider(db_session, provider_name)
@@ -237,7 +236,7 @@ class TestLLMProviderChanges:
api_base=None,
)
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
put_llm_provider(
llm_provider_upsert_request=update_request,
is_creation=False,
@@ -245,9 +244,9 @@ class TestLLMProviderChanges:
db_session=db_session,
)
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert exc_info.value.status_code == 400
assert "cannot be changed without changing the API key" in str(
exc_info.value.message
exc_info.value.detail
)
finally:
_cleanup_provider(db_session, provider_name)
@@ -340,7 +339,7 @@ class TestLLMProviderChanges:
custom_config_changed=True,
)
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
put_llm_provider(
llm_provider_upsert_request=update_request,
is_creation=False,
@@ -348,9 +347,9 @@ class TestLLMProviderChanges:
db_session=db_session,
)
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert exc_info.value.status_code == 400
assert "cannot be changed without changing the API key" in str(
exc_info.value.message
exc_info.value.detail
)
finally:
_cleanup_provider(db_session, provider_name)
@@ -376,7 +375,7 @@ class TestLLMProviderChanges:
custom_config_changed=True,
)
with pytest.raises(OnyxError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
put_llm_provider(
llm_provider_upsert_request=update_request,
is_creation=False,
@@ -384,9 +383,9 @@ class TestLLMProviderChanges:
db_session=db_session,
)
assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR
assert exc_info.value.status_code == 400
assert "cannot be changed without changing the API key" in str(
exc_info.value.message
exc_info.value.detail
)
finally:
_cleanup_provider(db_session, provider_name)

View File

@@ -11,7 +11,6 @@ from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_default_contextual_rag_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
@@ -38,8 +37,6 @@ def _create_llm_provider_and_model(
model_name: str,
) -> None:
"""Insert an LLM provider with a single visible model configuration."""
if fetch_existing_llm_provider(name=provider_name, db_session=db_session):
return
upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
@@ -149,8 +146,8 @@ def baseline_search_settings(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -158,7 +155,6 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
@@ -200,8 +196,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -209,7 +205,6 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
@@ -271,7 +266,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
)
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -279,7 +274,6 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
) -> None:


@@ -427,7 +427,7 @@ def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG
headers=admin_user.headers,
)
assert delete_response.status_code == 400
assert "Cannot delete the default LLM provider" in delete_response.json()["message"]
assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]
# Verify provider still exists
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
@@ -673,8 +673,8 @@ def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
headers=admin_user.headers,
json=base_payload,
)
assert response.status_code == 409
assert "already exists" in response.json()["message"]
assert response.status_code == 400
assert "already exists" in response.json()["detail"]
def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
@@ -711,7 +711,7 @@ def test_rename_provider_rejected(reset: None) -> None: # noqa: ARG001
json=update_payload,
)
assert response.status_code == 400
assert "not currently supported" in response.json()["message"]
assert "not currently supported" in response.json()["detail"]
# Verify no duplicate was created — only the original provider should exist
provider = _get_provider_by_id(admin_user, provider_id)


@@ -69,7 +69,7 @@ def test_unauthorized_persona_access_returns_403(
# Should return 403 Forbidden
assert response.status_code == 403
assert "don't have access to this assistant" in response.json()["message"]
assert "don't have access to this assistant" in response.json()["detail"]
def test_authorized_persona_access_returns_filtered_providers(
@@ -245,4 +245,4 @@ def test_nonexistent_persona_returns_404(
# Should return 404
assert response.status_code == 404
assert "Persona not found" in response.json()["message"]
assert "Persona not found" in response.json()["detail"]


@@ -42,78 +42,6 @@ class NightlyProviderConfig(BaseModel):
strict: bool
def _stringify_custom_config_value(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, (dict, list)):
return json.dumps(value)
return str(value)
def _looks_like_vertex_credentials_payload(
raw_custom_config: dict[object, object],
) -> bool:
normalized_keys = {str(key).strip().lower() for key in raw_custom_config}
provider_specific_keys = {
"vertex_credentials",
"credentials_file",
"vertex_credentials_file",
"google_application_credentials",
"vertex_location",
"location",
"vertex_region",
"region",
}
if normalized_keys & provider_specific_keys:
return False
normalized_type = str(raw_custom_config.get("type", "")).strip().lower()
if normalized_type not in {"service_account", "external_account"}:
return False
# Service account JSON usually includes private_key/client_email, while external
# account JSON includes credential_source. Either shape should be accepted.
has_service_account_markers = any(
key in normalized_keys for key in {"private_key", "client_email"}
)
has_external_account_markers = "credential_source" in normalized_keys
return has_service_account_markers or has_external_account_markers
def _normalize_custom_config(
provider: str, raw_custom_config: dict[object, object]
) -> dict[str, str]:
if provider == "vertex_ai" and _looks_like_vertex_credentials_payload(
raw_custom_config
):
return {"vertex_credentials": json.dumps(raw_custom_config)}
normalized: dict[str, str] = {}
for raw_key, raw_value in raw_custom_config.items():
key = str(raw_key).strip()
key_lower = key.lower()
if provider == "vertex_ai":
if key_lower in {
"vertex_credentials",
"credentials_file",
"vertex_credentials_file",
"google_application_credentials",
}:
key = "vertex_credentials"
elif key_lower in {
"vertex_location",
"location",
"vertex_region",
"region",
}:
key = "vertex_location"
normalized[key] = _stringify_custom_config_value(raw_value)
return normalized
def _env_true(env_var: str, default: bool = False) -> bool:
value = os.environ.get(env_var)
if value is None:
@@ -152,9 +80,7 @@ def _load_provider_config() -> NightlyProviderConfig:
parsed = json.loads(custom_config_json)
if not isinstance(parsed, dict):
raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
custom_config = _normalize_custom_config(
provider=provider, raw_custom_config=parsed
)
custom_config = {str(key): str(value) for key, value in parsed.items()}
if provider == "ollama_chat" and api_key and not custom_config:
custom_config = {"OLLAMA_API_KEY": api_key}
@@ -222,23 +148,6 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
),
)
if config.provider == "vertex_ai":
has_vertex_credentials = bool(
config.custom_config and config.custom_config.get("vertex_credentials")
)
if not has_vertex_credentials:
configured_keys = (
sorted(config.custom_config.keys()) if config.custom_config else []
)
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_CUSTOM_CONFIG_JSON} must include 'vertex_credentials' "
f"for provider '{config.provider}'. "
f"Found keys: {configured_keys}"
),
)
def _assert_integration_mode_enabled() -> None:
assert (
@@ -284,7 +193,6 @@ def _create_provider_payload(
return {
"name": provider_name,
"provider": provider,
"model": model_name,
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
@@ -300,23 +208,24 @@ def _create_provider_payload(
}
def _ensure_provider_is_default(
provider_id: int, model_name: str, admin_user: DATestUser
) -> None:
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
list_response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
)
list_response.raise_for_status()
default_text = list_response.json().get("default_text")
assert default_text is not None, "Expected a default provider after setting default"
assert default_text.get("provider_id") == provider_id, (
f"Expected provider {provider_id} to be default, "
f"found {default_text.get('provider_id')}"
providers = list_response.json()
current_default = next(
(provider for provider in providers if provider.get("is_default_provider")),
None,
)
assert (
default_text.get("model_name") == model_name
), f"Expected default model {model_name}, found {default_text.get('model_name')}"
current_default is not None
), "Expected a default provider after setting provider as default"
assert (
current_default["id"] == provider_id
), f"Expected provider {provider_id} to be default, found {current_default['id']}"
def _run_chat_assertions(
@@ -417,9 +326,8 @@ def _create_and_test_provider_for_model(
try:
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/default",
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
headers=admin_user.headers,
json={"provider_id": provider_id, "model_name": model_name},
)
assert set_default_response.status_code == 200, (
f"Setting default provider failed for provider={config.provider} "
@@ -427,9 +335,7 @@ def _create_and_test_provider_for_model(
f"{set_default_response.text}"
)
_ensure_provider_is_default(
provider_id=provider_id, model_name=model_name, admin_user=admin_user
)
_ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
_run_chat_assertions(
admin_user=admin_user,
search_tool_id=search_tool_id,


@@ -1,3 +1,4 @@
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
@@ -364,6 +365,7 @@ def test_update_contextual_rag_missing_model_name(
assert "Provider name and model name are required" in response.json()["detail"]
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_with_contextual_rag(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -392,6 +394,7 @@ def test_set_new_search_settings_with_contextual_rag(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_without_contextual_rag(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -416,6 +419,7 @@ def test_set_new_search_settings_without_contextual_rag(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_then_update_inference_settings(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -453,6 +457,7 @@ def test_set_new_then_update_inference_settings(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_replaces_previous_secondary(
reset: None, # noqa: ARG001
admin_user: DATestUser,


@@ -281,10 +281,9 @@ class TestApplyLicenseStatusToSettings:
}
class TestSettingsDefaults:
"""Verify Settings model defaults for CE deployments."""
class TestSettingsDefaultEEDisabled:
"""Verify the Settings model defaults ee_features_enabled to False."""
def test_default_ee_features_disabled(self) -> None:
"""CE default: ee_features_enabled is False."""
settings = Settings()
assert settings.ee_features_enabled is False


@@ -2,6 +2,7 @@
import pytest
from onyx.chat.llm_loop import _should_keep_bedrock_tool_definitions
from onyx.chat.llm_loop import _try_fallback_tool_extraction
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.models import ChatLoadedFile
@@ -13,11 +14,22 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff
class _StubConfig:
def __init__(self, model_provider: str) -> None:
self.model_provider = model_provider
class _StubLLM:
def __init__(self, model_provider: str) -> None:
self.config = _StubConfig(model_provider=model_provider)
def create_message(
content: str, message_type: MessageType, token_count: int | None = None
) -> ChatMessageSimple:
@@ -934,6 +946,37 @@ class TestForgottenFileMetadata:
assert "moby_dick.txt" in forgotten.message
class TestBedrockToolConfigGuard:
def test_bedrock_with_tool_history_keeps_tool_definitions(self) -> None:
llm = _StubLLM(LlmProviderNames.BEDROCK)
history = [
create_message("Question", MessageType.USER, 5),
create_assistant_with_tool_call("tc_1", "search", 5),
create_tool_response("tc_1", "Tool output", 5),
]
assert _should_keep_bedrock_tool_definitions(llm, history) is True
def test_bedrock_without_tool_history_does_not_keep_tool_definitions(self) -> None:
llm = _StubLLM(LlmProviderNames.BEDROCK)
history = [
create_message("Question", MessageType.USER, 5),
create_message("Answer", MessageType.ASSISTANT, 5),
]
assert _should_keep_bedrock_tool_definitions(llm, history) is False
def test_non_bedrock_with_tool_history_does_not_keep_tool_definitions(self) -> None:
llm = _StubLLM(LlmProviderNames.OPENAI)
history = [
create_message("Question", MessageType.USER, 5),
create_assistant_with_tool_call("tc_1", "search", 5),
create_tool_response("tc_1", "Tool output", 5),
]
assert _should_keep_bedrock_tool_definitions(llm, history) is False
class TestFallbackToolExtraction:
def _tool_defs(self) -> list[dict]:
return [

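The guard exercised by `TestBedrockToolConfigGuard` above reduces to a provider check plus a scan of the history for tool traffic. A minimal sketch of that logic, modeling the provider as a plain string and history entries as plain dicts rather than the real `ChatMessageSimple` objects (names and shapes here are assumptions for self-containment):

```python
def should_keep_bedrock_tool_definitions(provider: str, history: list[dict]) -> bool:
    # Sketch only: the real helper reads llm.config.model_provider and walks
    # ChatMessageSimple history; strings and dicts stand in for both here.
    if provider != "bedrock":
        return False  # only Bedrock needs tool definitions retained
    # Keep tool definitions whenever the history already contains tool traffic.
    return any(m.get("role") == "tool" or m.get("tool_calls") for m in history)
```

This mirrors the three cases the tests pin down: Bedrock with tool history keeps definitions, Bedrock without does not, and non-Bedrock providers never do.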

@@ -1214,218 +1214,3 @@ def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
# The env lock context manager should never have been called
mock_env_lock.assert_not_called()
# ---- Tests for Bedrock tool content stripping ----
def test_messages_contain_tool_content_with_tool_role() -> None:
from onyx.llm.multi_llm import _messages_contain_tool_content
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "I'll search for that."},
{"role": "tool", "content": "search results", "tool_call_id": "tc_1"},
]
assert _messages_contain_tool_content(messages) is True
def test_messages_contain_tool_content_with_tool_calls() -> None:
from onyx.llm.multi_llm import _messages_contain_tool_content
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "tc_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
},
]
assert _messages_contain_tool_content(messages) is True
def test_messages_contain_tool_content_without_tools() -> None:
from onyx.llm.multi_llm import _messages_contain_tool_content
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
assert _messages_contain_tool_content(messages) is False
def test_strip_tool_content_converts_assistant_tool_calls_to_text() -> None:
from onyx.llm.multi_llm import _strip_tool_content_from_messages
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Search for cats"},
{
"role": "assistant",
"content": "Let me search.",
"tool_calls": [
{
"id": "tc_1",
"type": "function",
"function": {
"name": "search",
"arguments": '{"query": "cats"}',
},
}
],
},
{
"role": "tool",
"content": "Found 3 results about cats.",
"tool_call_id": "tc_1",
},
{"role": "assistant", "content": "Here are the results."},
]
result = _strip_tool_content_from_messages(messages)
assert len(result) == 4
# First message unchanged
assert result[0] == {"role": "user", "content": "Search for cats"}
# Assistant with tool calls → plain text
assert result[1]["role"] == "assistant"
assert "tool_calls" not in result[1]
assert "Let me search." in result[1]["content"]
assert "[Tool Call]" in result[1]["content"]
assert "search" in result[1]["content"]
assert "tc_1" in result[1]["content"]
# Tool response → user message
assert result[2]["role"] == "user"
assert "[Tool Result]" in result[2]["content"]
assert "tc_1" in result[2]["content"]
assert "Found 3 results about cats." in result[2]["content"]
# Final assistant message unchanged
assert result[3] == {"role": "assistant", "content": "Here are the results."}
def test_strip_tool_content_handles_assistant_with_no_text_content() -> None:
from onyx.llm.multi_llm import _strip_tool_content_from_messages
messages: list[dict[str, Any]] = [
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "tc_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
},
]
result = _strip_tool_content_from_messages(messages)
assert result[0]["role"] == "assistant"
assert "[Tool Call]" in result[0]["content"]
assert "tool_calls" not in result[0]
def test_strip_tool_content_passes_through_non_tool_messages() -> None:
from onyx.llm.multi_llm import _strip_tool_content_from_messages
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi!"},
]
result = _strip_tool_content_from_messages(messages)
assert result == messages
def test_strip_tool_content_handles_list_content_blocks() -> None:
from onyx.llm.multi_llm import _strip_tool_content_from_messages
messages: list[dict[str, Any]] = [
{
"role": "assistant",
"content": [{"type": "text", "text": "Searching now."}],
"tool_calls": [
{
"id": "tc_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
},
{
"role": "tool",
"content": [
{"type": "text", "text": "result A"},
{"type": "text", "text": "result B"},
],
"tool_call_id": "tc_1",
},
]
result = _strip_tool_content_from_messages(messages)
# Assistant: list content flattened + tool call appended
assert result[0]["role"] == "assistant"
assert "Searching now." in result[0]["content"]
assert "[Tool Call]" in result[0]["content"]
assert isinstance(result[0]["content"], str)
# Tool: list content flattened into user message
assert result[1]["role"] == "user"
assert "result A" in result[1]["content"]
assert "result B" in result[1]["content"]
assert isinstance(result[1]["content"], str)
def test_strip_tool_content_merges_consecutive_tool_results() -> None:
"""Bedrock requires strict user/assistant alternation. Multiple parallel
tool results must be merged into a single user message."""
from onyx.llm.multi_llm import _strip_tool_content_from_messages
messages: list[dict[str, Any]] = [
{"role": "user", "content": "weather and news?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "tc_1",
"type": "function",
"function": {"name": "search_weather", "arguments": "{}"},
},
{
"id": "tc_2",
"type": "function",
"function": {"name": "search_news", "arguments": "{}"},
},
],
},
{"role": "tool", "content": "sunny 72F", "tool_call_id": "tc_1"},
{"role": "tool", "content": "headline news", "tool_call_id": "tc_2"},
{"role": "assistant", "content": "Here are the results."},
]
result = _strip_tool_content_from_messages(messages)
# user, assistant (flattened), user (merged tool results), assistant
assert len(result) == 4
roles = [m["role"] for m in result]
assert roles == ["user", "assistant", "user", "assistant"]
# Both tool results merged into one user message
merged = result[2]["content"]
assert "tc_1" in merged
assert "sunny 72F" in merged
assert "tc_2" in merged
assert "headline news" in merged

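The deleted multi_llm tests above fully specify the Bedrock tool-content fallback: assistant tool calls flatten to text, tool results become user turns, and consecutive results merge to preserve strict user/assistant alternation. A self-contained sketch consistent with those assertions (function names and the dict message shape are assumptions; the real helpers live in `onyx.llm.multi_llm`):

```python
def _flatten(content) -> str:
    # Content may be None, a string, or a list of {"type": "text", "text": ...} blocks.
    if isinstance(content, list):
        return "\n".join(block.get("text", "") for block in content)
    return content or ""


def strip_tool_content(messages: list[dict]) -> list[dict]:
    out: list[dict] = []
    for msg in messages:
        if msg.get("tool_calls"):
            # Assistant tool calls become plain text on an assistant turn.
            text = _flatten(msg.get("content"))
            for tc in msg["tool_calls"]:
                fn = tc.get("function", {})
                text += f"\n[Tool Call] {fn.get('name')} ({tc.get('id')}): {fn.get('arguments')}"
            out.append({"role": "assistant", "content": text.lstrip("\n")})
        elif msg.get("role") == "tool":
            # Tool results become user turns; consecutive results merge into one
            # user message so Bedrock's strict alternation is preserved.
            text = f"[Tool Result] {msg.get('tool_call_id')}: {_flatten(msg.get('content'))}"
            if out and out[-1].get("_tool_result"):
                out[-1]["content"] += "\n" + text
            else:
                out.append({"role": "user", "content": text, "_tool_result": True})
        else:
            out.append(dict(msg))
    # Drop the internal merge marker before returning.
    return [{k: v for k, v in m.items() if k != "_tool_result"} for m in out]
```

Running the parallel-tool-call scenario from the tests through this sketch yields the expected `user, assistant, user, assistant` role sequence with both results merged into the single user turn.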

@@ -104,102 +104,3 @@ def test_format_slack_message_ampersand_not_double_escaped() -> None:
assert "&amp;" in formatted
assert "&quot;" not in formatted
# -- Table rendering tests --
def test_table_renders_as_vertical_cards() -> None:
message = (
"| Feature | Status | Owner |\n"
"|---------|--------|-------|\n"
"| Auth | Done | Alice |\n"
"| Search | In Progress | Bob |\n"
)
formatted = format_slack_message(message)
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
# Cards separated by blank line
assert "Owner: Alice\n\n*Search*" in formatted
# No raw pipe-and-dash table syntax
assert "---|" not in formatted
def test_table_single_column() -> None:
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
formatted = format_slack_message(message)
assert "*Alice*" in formatted
assert "*Bob*" in formatted
def test_table_embedded_in_text() -> None:
message = (
"Here are the results:\n\n"
"| Item | Count |\n"
"|------|-------|\n"
"| Apples | 5 |\n"
"\n"
"That's all."
)
formatted = format_slack_message(message)
assert "Here are the results:" in formatted
assert "*Apples*\n • Count: 5" in formatted
assert "That's all." in formatted
def test_table_with_formatted_cells() -> None:
message = (
"| Name | Link |\n"
"|------|------|\n"
"| **Alice** | [profile](https://example.com) |\n"
)
formatted = format_slack_message(message)
# Bold cell should not double-wrap: *Alice* not **Alice**
assert "*Alice*" in formatted
assert "**Alice**" not in formatted
assert "<https://example.com|profile>" in formatted
def test_table_with_alignment_specifiers() -> None:
message = (
"| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
)
formatted = format_slack_message(message)
assert "*a*\n • Center: b\n • Right: c" in formatted
def test_two_tables_in_same_message_use_independent_headers() -> None:
message = (
"| A | B |\n"
"|---|---|\n"
"| 1 | 2 |\n"
"\n"
"| X | Y | Z |\n"
"|---|---|---|\n"
"| p | q | r |\n"
)
formatted = format_slack_message(message)
assert "*1*\n • B: 2" in formatted
assert "*p*\n • Y: q\n • Z: r" in formatted
def test_table_empty_first_column_no_bare_asterisks() -> None:
message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
formatted = format_slack_message(message)
# Empty title should not produce "**" (bare asterisks)
assert "**" not in formatted
assert " • Status: Done" in formatted

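The table tests above describe rendering each markdown table row as a vertical card: the first column becomes a bold title and the remaining columns become bullets. A minimal sketch of that transformation, assuming the table lines have already been isolated from the surrounding message (the real `format_slack_message` also handles cell formatting, link conversion, and escaping):

```python
def table_to_cards(table_lines: list[str]) -> str:
    # table_lines: markdown header row, separator row, then data rows.
    def cells(row: str) -> list[str]:
        return [c.strip() for c in row.strip().strip("|").split("|")]

    header = cells(table_lines[0])
    cards = []
    for row in table_lines[2:]:  # skip the |---|---| separator row
        values = cells(row)
        # First column is the bold card title; an empty title emits no bare asterisks.
        lines = [f"*{values[0]}*"] if values[0] else []
        for name, value in zip(header[1:], values[1:]):
            lines.append(f" • {name}: {value}")
        cards.append("\n".join(lines))
    # Cards are separated by a blank line.
    return "\n\n".join(cards)
```

Single-column tables degenerate to a list of bold titles, and an empty first cell produces only the bullet lines, matching the no-bare-asterisks assertion above.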

@@ -87,8 +87,7 @@ def test_python_tool_available_when_health_check_passes(
mock_client = MagicMock()
mock_client.health.return_value = True
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
mock_client_cls.return_value = mock_client
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is True
@@ -110,8 +109,7 @@ def test_python_tool_unavailable_when_health_check_fails(
mock_client = MagicMock()
mock_client.health.return_value = False
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
mock_client_cls.return_value = mock_client
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False

cli/.gitignore vendored

@@ -1,3 +0,0 @@
onyx-cli
cli
onyx.cli


@@ -1,118 +0,0 @@
# Onyx CLI
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
## Installation
```shell
pip install onyx-cli
```
Or with uv:
```shell
uv pip install onyx-cli
```
## Setup
Run the interactive setup:
```shell
onyx-cli configure
```
This prompts for your Onyx server URL and API key, tests the connection, and saves config to `~/.config/onyx-cli/config.json`.
Environment variables override config file values:
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Server base URL (default: `http://localhost:3000`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
## Usage
### Interactive chat (default)
```shell
onyx-cli
```
### One-shot question
```shell
onyx-cli ask "What is our company's PTO policy?"
onyx-cli ask --agent-id 5 "Summarize this topic"
onyx-cli ask --json "Hello"
```
| Flag | Description |
|------|-------------|
| `--agent-id <int>` | Agent ID to use (overrides default) |
| `--json` | Output raw NDJSON events instead of plain text |
### List agents
```shell
onyx-cli agents
onyx-cli agents --json
```
## Commands
| Command | Description |
|---------|-------------|
| `chat` | Launch the interactive chat TUI (default) |
| `ask` | Ask a one-shot question (non-interactive) |
| `agents` | List available agents |
| `configure` | Configure server URL and API key |
## Slash Commands (in TUI)
| Command | Description |
|---------|-------------|
| `/help` | Show help message |
| `/new` | Start a new chat session |
| `/agent` | List and switch agents |
| `/attach <path>` | Attach a file to next message |
| `/sessions` | List recent chat sessions |
| `/clear` | Clear the chat display |
| `/configure` | Re-run connection setup |
| `/connectors` | Open connectors in browser |
| `/settings` | Open settings in browser |
| `/quit` | Exit Onyx CLI |
## Keyboard Shortcuts
| Key | Action |
|-----|--------|
| `Enter` | Send message |
| `Escape` | Cancel current generation |
| `Ctrl+O` | Toggle source citations |
| `Ctrl+D` | Quit (press twice) |
| `Scroll` / `Shift+Up/Down` | Scroll chat history |
| `Page Up` / `Page Down` | Scroll half page |
## Building from Source
Requires [Go 1.24+](https://go.dev/dl/).
```shell
cd cli
go build -o onyx-cli .
```
## Development
```shell
# Run tests
go test ./...
# Build
go build -o onyx-cli .
# Lint
staticcheck ./...
```


@@ -1,63 +0,0 @@
package cmd
import (
"encoding/json"
"fmt"
"text/tabwriter"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/spf13/cobra"
)
func newAgentsCmd() *cobra.Command {
var agentsJSON bool
cmd := &cobra.Command{
Use: "agents",
Short: "List available agents",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
if !cfg.IsConfigured() {
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
}
client := api.NewClient(cfg)
agents, err := client.ListAgents()
if err != nil {
return fmt.Errorf("failed to list agents: %w", err)
}
if agentsJSON {
data, err := json.MarshalIndent(agents, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal agents: %w", err)
}
fmt.Println(string(data))
return nil
}
if len(agents) == 0 {
fmt.Println("No agents available.")
return nil
}
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 4, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "ID\tNAME\tDESCRIPTION")
for _, a := range agents {
desc := a.Description
if len(desc) > 60 {
desc = desc[:57] + "..."
}
_, _ = fmt.Fprintf(w, "%d\t%s\t%s\n", a.ID, a.Name, desc)
}
_ = w.Flush()
return nil
},
}
cmd.Flags().BoolVar(&agentsJSON, "json", false, "Output agents as JSON")
return cmd
}


@@ -1,103 +0,0 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/spf13/cobra"
)
func newAskCmd() *cobra.Command {
var (
askAgentID int
askJSON bool
)
cmd := &cobra.Command{
Use: "ask [question]",
Short: "Ask a one-shot question (non-interactive)",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
if !cfg.IsConfigured() {
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
}
question := args[0]
agentID := cfg.DefaultAgentID
if cmd.Flags().Changed("agent-id") {
agentID = askAgentID
}
client := api.NewClient(cfg)
parentID := -1
ch := client.SendMessageStream(
context.Background(),
question,
nil,
agentID,
&parentID,
nil,
)
var lastErr error
gotStop := false
for event := range ch {
if askJSON {
wrapped := struct {
Type string `json:"type"`
Event models.StreamEvent `json:"event"`
}{
Type: event.EventType(),
Event: event,
}
data, err := json.Marshal(wrapped)
if err != nil {
return fmt.Errorf("error marshaling event: %w", err)
}
fmt.Println(string(data))
if _, ok := event.(models.ErrorEvent); ok {
lastErr = fmt.Errorf("%s", event.(models.ErrorEvent).Error)
}
if _, ok := event.(models.StopEvent); ok {
gotStop = true
}
continue
}
switch e := event.(type) {
case models.MessageDeltaEvent:
fmt.Print(e.Content)
case models.ErrorEvent:
return fmt.Errorf("%s", e.Error)
case models.StopEvent:
fmt.Println()
return nil
}
}
if lastErr != nil {
return lastErr
}
if !gotStop {
if !askJSON {
fmt.Println()
}
return fmt.Errorf("stream ended unexpectedly")
}
if !askJSON {
fmt.Println()
}
return nil
},
}
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
// Suppress cobra's default error/usage on RunE errors
return cmd
}


@@ -1,33 +0,0 @@
package cmd
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/spf13/cobra"
)
func newChatCmd() *cobra.Command {
return &cobra.Command{
Use: "chat",
Short: "Launch the interactive chat TUI (default)",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
// First-run: onboarding
if !config.ConfigExists() || !cfg.IsConfigured() {
result := onboarding.Run(&cfg)
if result == nil {
return nil
}
cfg = *result
}
m := tui.NewModel(cfg)
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
_, err := p.Run()
return err
},
}
}


@@ -1,19 +0,0 @@
package cmd
import (
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
"github.com/spf13/cobra"
)
func newConfigureCmd() *cobra.Command {
return &cobra.Command{
Use: "configure",
Short: "Configure server URL and API key",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
onboarding.Run(&cfg)
return nil
},
}
}


@@ -1,40 +0,0 @@
// Package cmd implements Cobra CLI commands for the Onyx CLI.
package cmd
import "github.com/spf13/cobra"
// Version and Commit are set via ldflags at build time.
var (
Version string
Commit string
)
func fullVersion() string {
if Commit != "" && Commit != "none" && len(Commit) > 7 {
return Version + " (" + Commit[:7] + ")"
}
return Version
}
// Execute creates and runs the root command.
func Execute() error {
rootCmd := &cobra.Command{
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
Version: fullVersion(),
}
// Register subcommands
chatCmd := newChatCmd()
rootCmd.AddCommand(chatCmd)
rootCmd.AddCommand(newAskCmd())
rootCmd.AddCommand(newAgentsCmd())
rootCmd.AddCommand(newConfigureCmd())
rootCmd.AddCommand(newValidateConfigCmd())
// Default command is chat
rootCmd.RunE = chatCmd.RunE
return rootCmd.Execute()
}


@@ -1,41 +0,0 @@
package cmd
import (
"fmt"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/spf13/cobra"
)
func newValidateConfigCmd() *cobra.Command {
return &cobra.Command{
Use: "validate-config",
Short: "Validate configuration and test server connection",
RunE: func(cmd *cobra.Command, args []string) error {
// Check config file
if !config.ConfigExists() {
return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
}
cfg := config.Load()
// Check API key
if !cfg.IsConfigured() {
return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server: %s\n", cfg.ServerURL)
// Test connection
client := api.NewClient(cfg)
if err := client.TestConnection(); err != nil {
return fmt.Errorf("connection failed: %w", err)
}
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
return nil
},
}
}


@@ -1,45 +0,0 @@
module github.com/onyx-dot-app/onyx/cli
go 1.26.0
require (
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.4
github.com/charmbracelet/glamour v0.8.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/spf13/cobra v1.9.1
golang.org/x/term v0.22.0
golang.org/x/text v0.34.0
)
require (
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yuin/goldmark v1.7.4 // indirect
github.com/yuin/goldmark-emoji v1.0.3 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.30.0 // indirect
)


@@ -1,94 +0,0 @@
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -1,279 +0,0 @@
// Package api provides the HTTP client for communicating with the Onyx server.
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
// Client is the Onyx API client.
type Client struct {
baseURL string
apiKey string
httpClient *http.Client // default 30s timeout for quick requests
longHTTPClient *http.Client // 5min timeout for streaming/uploads
}
// NewClient creates a new API client from config.
func NewClient(cfg config.OnyxCliConfig) *Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
return &Client{
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
apiKey: cfg.APIKey,
httpClient: &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
},
longHTTPClient: &http.Client{
Timeout: 5 * time.Minute,
Transport: transport,
},
}
}
// UpdateConfig replaces the client's config.
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
c.apiKey = cfg.APIKey
}
func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(context.Background(), method, c.baseURL+path, body)
if err != nil {
return nil, err
}
if c.apiKey != "" {
bearer := "Bearer " + c.apiKey
req.Header.Set("Authorization", bearer)
req.Header.Set("X-Onyx-Authorization", bearer)
}
return req, nil
}
func (c *Client) doJSON(method, path string, reqBody any, result any) error {
var body io.Reader
if reqBody != nil {
data, err := json.Marshal(reqBody)
if err != nil {
return err
}
body = bytes.NewReader(data)
}
req, err := c.newRequest(method, path, body)
if err != nil {
return err
}
if reqBody != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
}
if result != nil {
return json.NewDecoder(resp.Body).Decode(result)
}
return nil
}
// TestConnection checks if the server is reachable and credentials are valid.
// Returns nil on success, or an error with a descriptive message on failure.
func (c *Client) TestConnection() error {
// Step 1: Basic reachability
req, err := c.newRequest("GET", "/", nil)
if err != nil {
return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("cannot connect to %s — is the server running?", c.baseURL)
}
_ = resp.Body.Close()
serverHeader := strings.ToLower(resp.Header.Get("Server"))
if resp.StatusCode == 403 {
if strings.Contains(serverHeader, "awselb") || strings.Contains(serverHeader, "amazons3") {
return fmt.Errorf("blocked by AWS load balancer (HTTP 403 on all requests).\n Your IP address may not be in the ALB's security group or WAF allowlist")
}
return fmt.Errorf("HTTP 403 on base URL — the server is blocking all traffic.\n This is likely a firewall, WAF, or IP allowlist restriction")
}
// Step 2: Authenticated check
req2, err := c.newRequest("GET", "/api/me", nil)
if err != nil {
return fmt.Errorf("server reachable but API error: %w", err)
}
resp2, err := c.longHTTPClient.Do(req2)
if err != nil {
return fmt.Errorf("server reachable but API error: %w", err)
}
defer func() { _ = resp2.Body.Close() }()
if resp2.StatusCode == 200 {
return nil
}
bodyBytes, _ := io.ReadAll(io.LimitReader(resp2.Body, 300))
body := string(bodyBytes)
isHTML := strings.HasPrefix(strings.TrimSpace(body), "<")
respServer := strings.ToLower(resp2.Header.Get("Server"))
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
if isHTML || strings.Contains(respServer, "awselb") {
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
}
if resp2.StatusCode == 401 {
return fmt.Errorf("invalid API key or token.\n %s", body)
}
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
}
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
if body != "" {
detail += fmt.Sprintf("\n Response: %s", body)
}
return fmt.Errorf("%s", detail)
}
// ListAgents returns visible agents.
func (c *Client) ListAgents() ([]models.AgentSummary, error) {
var raw []models.AgentSummary
if err := c.doJSON("GET", "/api/persona", nil, &raw); err != nil {
return nil, err
}
var result []models.AgentSummary
for _, p := range raw {
if p.IsVisible {
result = append(result, p)
}
}
return result, nil
}
// ListChatSessions returns recent chat sessions.
func (c *Client) ListChatSessions() ([]models.ChatSessionDetails, error) {
var resp struct {
Sessions []models.ChatSessionDetails `json:"sessions"`
}
if err := c.doJSON("GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
return nil, err
}
return resp.Sessions, nil
}
// GetChatSession returns full details for a session.
func (c *Client) GetChatSession(sessionID string) (*models.ChatSessionDetailResponse, error) {
var resp models.ChatSessionDetailResponse
if err := c.doJSON("GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// RenameChatSession renames a session. If name is empty, the backend auto-generates one.
func (c *Client) RenameChatSession(sessionID string, name *string) (string, error) {
payload := map[string]any{
"chat_session_id": sessionID,
}
if name != nil {
payload["name"] = *name
}
var resp struct {
NewName string `json:"new_name"`
}
if err := c.doJSON("PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
return "", err
}
return resp.NewName, nil
}
// UploadFile uploads a file and returns a file descriptor.
func (c *Client) UploadFile(filePath string) (*models.FileDescriptorPayload, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer func() { _ = file.Close() }()
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("files", filepath.Base(filePath))
if err != nil {
return nil, err
}
if _, err := io.Copy(part, file); err != nil {
return nil, err
}
_ = writer.Close()
req, err := c.newRequest("POST", "/api/user/projects/file/upload", &buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := c.longHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return nil, &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(body)}
}
var snapshot models.CategorizedFilesSnapshot
if err := json.NewDecoder(resp.Body).Decode(&snapshot); err != nil {
return nil, err
}
if len(snapshot.UserFiles) == 0 {
return nil, &OnyxAPIError{StatusCode: 400, Detail: "File upload returned no files"}
}
uf := snapshot.UserFiles[0]
return &models.FileDescriptorPayload{
ID: uf.FileID,
Type: uf.ChatFileType,
Name: filepath.Base(filePath),
}, nil
}
// StopChatSession sends a stop signal for a streaming session (best-effort).
func (c *Client) StopChatSession(sessionID string) {
req, err := c.newRequest("POST", "/api/chat/stop-chat-session/"+sessionID, nil)
if err != nil {
return
}
resp, err := c.httpClient.Do(req)
if err != nil {
return
}
_ = resp.Body.Close()
}


@@ -1,13 +0,0 @@
package api
import "fmt"
// OnyxAPIError is returned when an Onyx API call fails.
type OnyxAPIError struct {
StatusCode int
Detail string
}
func (e *OnyxAPIError) Error() string {
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
}


@@ -1,136 +0,0 @@
package api
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/onyx-dot-app/onyx/cli/internal/parser"
)
// StreamEventMsg wraps a StreamEvent for Bubble Tea.
type StreamEventMsg struct {
Event models.StreamEvent
}
// StreamDoneMsg signals the stream has ended.
type StreamDoneMsg struct {
Err error
}
// SendMessageStream starts streaming a chat message response.
// It reads NDJSON lines, parses them, and sends events on the returned channel.
// The goroutine stops when ctx is cancelled or the stream ends.
func (c *Client) SendMessageStream(
ctx context.Context,
message string,
chatSessionID *string,
agentID int,
parentMessageID *int,
fileDescriptors []models.FileDescriptorPayload,
) <-chan models.StreamEvent {
ch := make(chan models.StreamEvent, 64)
go func() {
defer close(ch)
payload := models.SendMessagePayload{
Message: message,
ParentMessageID: parentMessageID,
FileDescriptors: fileDescriptors,
Origin: "api",
IncludeCitations: true,
Stream: true,
}
if payload.FileDescriptors == nil {
payload.FileDescriptors = []models.FileDescriptorPayload{}
}
if chatSessionID != nil {
payload.ChatSessionID = chatSessionID
} else {
payload.ChatSessionInfo = &models.ChatSessionCreationInfo{AgentID: agentID}
}
body, err := json.Marshal(payload)
if err != nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("marshal error: %v", err), IsRetryable: false}
return
}
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", nil)
if err != nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
return
}
req.Body = io.NopCloser(bytes.NewReader(body))
req.ContentLength = int64(len(body))
req.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
bearer := "Bearer " + c.apiKey
req.Header.Set("Authorization", bearer)
req.Header.Set("X-Onyx-Authorization", bearer)
}
resp, err := c.longHTTPClient.Do(req)
if err != nil {
if ctx.Err() != nil {
return // cancelled
}
ch <- models.ErrorEvent{Error: fmt.Sprintf("connection error: %v", err), IsRetryable: true}
return
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) // read the full (capped) body; a single Read may return partial data
ch <- models.ErrorEvent{
Error: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody)),
IsRetryable: resp.StatusCode >= 500,
}
return
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024)
for scanner.Scan() {
if ctx.Err() != nil {
return
}
event := parser.ParseStreamLine(scanner.Text())
if event != nil {
select {
case ch <- event:
case <-ctx.Done():
return
}
}
}
if err := scanner.Err(); err != nil && ctx.Err() == nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("stream read error: %v", err), IsRetryable: true}
}
}()
return ch
}
// WaitForStreamEvent returns a tea.Cmd that reads one event from the channel.
// On channel close, it returns StreamDoneMsg.
func WaitForStreamEvent(ch <-chan models.StreamEvent) tea.Cmd {
return func() tea.Msg {
event, ok := <-ch
if !ok {
return StreamDoneMsg{}
}
return StreamEventMsg{Event: event}
}
}


@@ -1,101 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
)
const (
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvAgentID = "ONYX_PERSONA_ID"
)
// OnyxCliConfig holds the CLI configuration.
type OnyxCliConfig struct {
ServerURL string `json:"server_url"`
APIKey string `json:"api_key"`
DefaultAgentID int `json:"default_persona_id"`
}
// DefaultConfig returns a config with default values.
func DefaultConfig() OnyxCliConfig {
return OnyxCliConfig{
ServerURL: "https://cloud.onyx.app",
APIKey: "",
DefaultAgentID: 0,
}
}
// IsConfigured returns true if the config has an API key.
func (c OnyxCliConfig) IsConfigured() bool {
return c.APIKey != ""
}
// configDir returns ~/.config/onyx-cli
func configDir() string {
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
return filepath.Join(xdg, "onyx-cli")
}
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".", ".config", "onyx-cli")
}
return filepath.Join(home, ".config", "onyx-cli")
}
// ConfigFilePath returns the full path to the config file.
func ConfigFilePath() string {
return filepath.Join(configDir(), "config.json")
}
// ConfigExists checks if the config file exists on disk.
func ConfigExists() bool {
_, err := os.Stat(ConfigFilePath())
return err == nil
}
// Load reads config from file and applies environment variable overrides.
func Load() OnyxCliConfig {
cfg := DefaultConfig()
data, err := os.ReadFile(ConfigFilePath())
if err == nil {
if jsonErr := json.Unmarshal(data, &cfg); jsonErr != nil {
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", ConfigFilePath(), jsonErr)
}
}
// Environment overrides
if v := os.Getenv(EnvServerURL); v != "" {
cfg.ServerURL = v
}
if v := os.Getenv(EnvAPIKey); v != "" {
cfg.APIKey = v
}
if v := os.Getenv(EnvAgentID); v != "" {
if id, err := strconv.Atoi(v); err == nil {
cfg.DefaultAgentID = id
}
}
return cfg
}
// Save writes the config to disk, creating parent directories if needed.
func Save(cfg OnyxCliConfig) error {
dir := configDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(ConfigFilePath(), data, 0o600)
}


@@ -1,215 +0,0 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func clearEnvVars(t *testing.T) {
t.Helper()
for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAgentID} {
t.Setenv(key, "")
if err := os.Unsetenv(key); err != nil {
t.Fatal(err)
}
}
}
func writeConfig(t *testing.T, dir string, data []byte) {
t.Helper()
onyxDir := filepath.Join(dir, "onyx-cli")
if err := os.MkdirAll(onyxDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(onyxDir, "config.json"), data, 0o644); err != nil {
t.Fatal(err)
}
}
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.ServerURL != "https://cloud.onyx.app" {
t.Errorf("expected default server URL, got %s", cfg.ServerURL)
}
if cfg.APIKey != "" {
t.Errorf("expected empty API key, got %s", cfg.APIKey)
}
if cfg.DefaultAgentID != 0 {
t.Errorf("expected default agent ID 0, got %d", cfg.DefaultAgentID)
}
}
func TestIsConfigured(t *testing.T) {
cfg := DefaultConfig()
if cfg.IsConfigured() {
t.Error("empty config should not be configured")
}
cfg.APIKey = "some-key"
if !cfg.IsConfigured() {
t.Error("config with API key should be configured")
}
}
func TestLoadDefaults(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
cfg := Load()
if cfg.ServerURL != "https://cloud.onyx.app" {
t.Errorf("expected default URL, got %s", cfg.ServerURL)
}
if cfg.APIKey != "" {
t.Errorf("expected empty key, got %s", cfg.APIKey)
}
}
func TestLoadFromFile(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
data, _ := json.Marshal(map[string]interface{}{
"server_url": "https://my-onyx.example.com",
"api_key": "test-key-123",
"default_persona_id": 5,
})
writeConfig(t, dir, data)
cfg := Load()
if cfg.ServerURL != "https://my-onyx.example.com" {
t.Errorf("got %s", cfg.ServerURL)
}
if cfg.APIKey != "test-key-123" {
t.Errorf("got %s", cfg.APIKey)
}
if cfg.DefaultAgentID != 5 {
t.Errorf("got %d", cfg.DefaultAgentID)
}
}
func TestLoadCorruptFile(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
writeConfig(t, dir, []byte("not valid json {{{"))
cfg := Load()
if cfg.ServerURL != "https://cloud.onyx.app" {
t.Errorf("expected default URL on corrupt file, got %s", cfg.ServerURL)
}
}
func TestEnvOverrideServerURL(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvServerURL, "https://env-override.com")
cfg := Load()
if cfg.ServerURL != "https://env-override.com" {
t.Errorf("got %s", cfg.ServerURL)
}
}
func TestEnvOverrideAPIKey(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvAPIKey, "env-key")
cfg := Load()
if cfg.APIKey != "env-key" {
t.Errorf("got %s", cfg.APIKey)
}
}
func TestEnvOverrideAgentID(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvAgentID, "42")
cfg := Load()
if cfg.DefaultAgentID != 42 {
t.Errorf("got %d", cfg.DefaultAgentID)
}
}
func TestEnvOverrideInvalidAgentID(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvAgentID, "not-a-number")
cfg := Load()
if cfg.DefaultAgentID != 0 {
t.Errorf("got %d", cfg.DefaultAgentID)
}
}
func TestEnvOverridesFileValues(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
data, _ := json.Marshal(map[string]interface{}{
"server_url": "https://file-url.com",
"api_key": "file-key",
})
writeConfig(t, dir, data)
t.Setenv(EnvServerURL, "https://env-url.com")
cfg := Load()
if cfg.ServerURL != "https://env-url.com" {
t.Errorf("env should override file, got %s", cfg.ServerURL)
}
if cfg.APIKey != "file-key" {
t.Errorf("file value should be kept, got %s", cfg.APIKey)
}
}
func TestSaveAndReload(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
cfg := OnyxCliConfig{
ServerURL: "https://saved.example.com",
APIKey: "saved-key",
DefaultAgentID: 10,
}
if err := Save(cfg); err != nil {
t.Fatal(err)
}
loaded := Load()
if loaded.ServerURL != "https://saved.example.com" {
t.Errorf("got %s", loaded.ServerURL)
}
if loaded.APIKey != "saved-key" {
t.Errorf("got %s", loaded.APIKey)
}
if loaded.DefaultAgentID != 10 {
t.Errorf("got %d", loaded.DefaultAgentID)
}
}
func TestSaveCreatesParentDirs(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
nested := filepath.Join(dir, "deep", "nested")
t.Setenv("XDG_CONFIG_HOME", nested)
if err := Save(OnyxCliConfig{APIKey: "test"}); err != nil {
t.Fatal(err)
}
if !ConfigExists() {
t.Error("config file should exist after save")
}
}


@@ -1,193 +0,0 @@
package models
// StreamEvent is the interface for all parsed stream events.
type StreamEvent interface {
EventType() string
}
// Event type constants matching the Python StreamEventType enum.
const (
EventSessionCreated = "session_created"
EventMessageIDInfo = "message_id_info"
EventStop = "stop"
EventError = "error"
EventMessageStart = "message_start"
EventMessageDelta = "message_delta"
EventSearchStart = "search_tool_start"
EventSearchQueries = "search_tool_queries_delta"
EventSearchDocuments = "search_tool_documents_delta"
EventReasoningStart = "reasoning_start"
EventReasoningDelta = "reasoning_delta"
EventReasoningDone = "reasoning_done"
EventCitationInfo = "citation_info"
EventOpenURLStart = "open_url_start"
EventImageGenStart = "image_generation_start"
EventPythonToolStart = "python_tool_start"
EventCustomToolStart = "custom_tool_start"
EventFileReaderStart = "file_reader_start"
EventDeepResearchPlan = "deep_research_plan_start"
EventDeepResearchDelta = "deep_research_plan_delta"
EventResearchAgentStart = "research_agent_start"
EventIntermediateReport = "intermediate_report_start"
EventIntermediateReportDt = "intermediate_report_delta"
EventUnknown = "unknown"
)
// SessionCreatedEvent is emitted when a new chat session is created.
type SessionCreatedEvent struct {
ChatSessionID string
}
func (e SessionCreatedEvent) EventType() string { return EventSessionCreated }
// MessageIDEvent carries the user and agent message IDs.
type MessageIDEvent struct {
UserMessageID *int
ReservedAgentMessageID int
}
func (e MessageIDEvent) EventType() string { return EventMessageIDInfo }
// StopEvent signals the end of a stream.
type StopEvent struct {
Placement *Placement
StopReason *string
}
func (e StopEvent) EventType() string { return EventStop }
// ErrorEvent signals an error.
type ErrorEvent struct {
Placement *Placement
Error string
StackTrace *string
IsRetryable bool
}
func (e ErrorEvent) EventType() string { return EventError }
// MessageStartEvent signals the beginning of an agent message.
type MessageStartEvent struct {
Placement *Placement
Documents []SearchDoc
}
func (e MessageStartEvent) EventType() string { return EventMessageStart }
// MessageDeltaEvent carries a token of agent content.
type MessageDeltaEvent struct {
Placement *Placement
Content string
}
func (e MessageDeltaEvent) EventType() string { return EventMessageDelta }
// SearchStartEvent signals the beginning of a search.
type SearchStartEvent struct {
Placement *Placement
IsInternetSearch bool
}
func (e SearchStartEvent) EventType() string { return EventSearchStart }
// SearchQueriesEvent carries search queries.
type SearchQueriesEvent struct {
Placement *Placement
Queries []string
}
func (e SearchQueriesEvent) EventType() string { return EventSearchQueries }
// SearchDocumentsEvent carries found documents.
type SearchDocumentsEvent struct {
Placement *Placement
Documents []SearchDoc
}
func (e SearchDocumentsEvent) EventType() string { return EventSearchDocuments }
// ReasoningStartEvent signals the beginning of a reasoning block.
type ReasoningStartEvent struct {
Placement *Placement
}
func (e ReasoningStartEvent) EventType() string { return EventReasoningStart }
// ReasoningDeltaEvent carries reasoning text.
type ReasoningDeltaEvent struct {
Placement *Placement
Reasoning string
}
func (e ReasoningDeltaEvent) EventType() string { return EventReasoningDelta }
// ReasoningDoneEvent signals the end of reasoning.
type ReasoningDoneEvent struct {
Placement *Placement
}
func (e ReasoningDoneEvent) EventType() string { return EventReasoningDone }
// CitationEvent carries citation info.
type CitationEvent struct {
Placement *Placement
CitationNumber int
DocumentID string
}
func (e CitationEvent) EventType() string { return EventCitationInfo }
// ToolStartEvent signals the start of a tool usage.
type ToolStartEvent struct {
Placement *Placement
Type string // The specific event type (e.g. "open_url_start")
ToolName string
}
func (e ToolStartEvent) EventType() string { return e.Type }
// DeepResearchPlanStartEvent signals the start of a deep research plan.
type DeepResearchPlanStartEvent struct {
Placement *Placement
}
func (e DeepResearchPlanStartEvent) EventType() string { return EventDeepResearchPlan }
// DeepResearchPlanDeltaEvent carries deep research plan content.
type DeepResearchPlanDeltaEvent struct {
Placement *Placement
Content string
}
func (e DeepResearchPlanDeltaEvent) EventType() string { return EventDeepResearchDelta }
// ResearchAgentStartEvent signals a research sub-task.
type ResearchAgentStartEvent struct {
Placement *Placement
ResearchTask string
}
func (e ResearchAgentStartEvent) EventType() string { return EventResearchAgentStart }
// IntermediateReportStartEvent signals the start of an intermediate report.
type IntermediateReportStartEvent struct {
Placement *Placement
}
func (e IntermediateReportStartEvent) EventType() string { return EventIntermediateReport }
// IntermediateReportDeltaEvent carries intermediate report content.
type IntermediateReportDeltaEvent struct {
Placement *Placement
Content string
}
func (e IntermediateReportDeltaEvent) EventType() string { return EventIntermediateReportDt }
// UnknownEvent is a catch-all for unrecognized stream data.
type UnknownEvent struct {
Placement *Placement
RawData map[string]any
}
func (e UnknownEvent) EventType() string { return EventUnknown }


@@ -1,112 +0,0 @@
// Package models defines API request/response types for the Onyx CLI.
package models
import "time"
// AgentSummary represents an agent from the API.
type AgentSummary struct {
ID int `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
IsDefaultPersona bool `json:"is_default_persona"`
IsVisible bool `json:"is_visible"`
}
// ChatSessionSummary is a brief session listing.
type ChatSessionSummary struct {
ID string `json:"id"`
Name *string `json:"name"`
AgentID *int `json:"persona_id"`
Created time.Time `json:"time_created"`
}
// ChatSessionDetails is a session with timestamps as strings.
type ChatSessionDetails struct {
ID string `json:"id"`
Name *string `json:"name"`
AgentID *int `json:"persona_id"`
Created string `json:"time_created"`
Updated string `json:"time_updated"`
}
// ChatMessageDetail is a single message in a session.
type ChatMessageDetail struct {
MessageID int `json:"message_id"`
ParentMessage *int `json:"parent_message"`
LatestChildMessage *int `json:"latest_child_message"`
Message string `json:"message"`
MessageType string `json:"message_type"`
TimeSent string `json:"time_sent"`
Error *string `json:"error"`
}
// ChatSessionDetailResponse is the full session detail from the API.
type ChatSessionDetailResponse struct {
ChatSessionID string `json:"chat_session_id"`
Description *string `json:"description"`
AgentID *int `json:"persona_id"`
AgentName *string `json:"persona_name"`
Messages []ChatMessageDetail `json:"messages"`
}
// ChatFileType represents a file type for uploads.
type ChatFileType string
const (
ChatFileImage ChatFileType = "image"
ChatFileDoc ChatFileType = "document"
ChatFilePlainText ChatFileType = "plain_text"
ChatFileCSV ChatFileType = "csv"
)
// FileDescriptorPayload is a file descriptor for send-message requests.
type FileDescriptorPayload struct {
ID string `json:"id"`
Type ChatFileType `json:"type"`
Name string `json:"name,omitempty"`
}
// UserFileSnapshot represents an uploaded file.
type UserFileSnapshot struct {
ID string `json:"id"`
Name string `json:"name"`
FileID string `json:"file_id"`
ChatFileType ChatFileType `json:"chat_file_type"`
}
// CategorizedFilesSnapshot is the response from file upload.
type CategorizedFilesSnapshot struct {
UserFiles []UserFileSnapshot `json:"user_files"`
}
// ChatSessionCreationInfo is included when creating a new session inline.
type ChatSessionCreationInfo struct {
AgentID int `json:"persona_id"`
}
// SendMessagePayload is the request body for POST /api/chat/send-chat-message.
type SendMessagePayload struct {
Message string `json:"message"`
ChatSessionID *string `json:"chat_session_id,omitempty"`
ChatSessionInfo *ChatSessionCreationInfo `json:"chat_session_info,omitempty"`
ParentMessageID *int `json:"parent_message_id"`
FileDescriptors []FileDescriptorPayload `json:"file_descriptors"`
Origin string `json:"origin"`
IncludeCitations bool `json:"include_citations"`
Stream bool `json:"stream"`
}
// SearchDoc represents a document found during search.
type SearchDoc struct {
DocumentID string `json:"document_id"`
SemanticIdentifier string `json:"semantic_identifier"`
Link *string `json:"link"`
SourceType string `json:"source_type"`
}
// Placement indicates where a stream event belongs in the conversation.
type Placement struct {
TurnIndex int `json:"turn_index"`
TabIndex int `json:"tab_index"`
SubTurnIndex *int `json:"sub_turn_index"`
}


@@ -1,169 +0,0 @@
// Package onboarding handles the first-run setup flow for Onyx CLI.
package onboarding
import (
"bufio"
"fmt"
"os"
"strings"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/onyx-dot-app/onyx/cli/internal/util"
"golang.org/x/term"
)
// Aliases for shared styles.
var (
boldStyle = util.BoldStyle
dimStyle = util.DimStyle
greenStyle = util.GreenStyle
redStyle = util.RedStyle
yellowStyle = util.YellowStyle
)
func getTermSize() (int, int) {
w, h, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
return 80, 24
}
return w, h
}
// Run executes the interactive onboarding flow.
// Returns the validated config, or nil if the user cancels or setup fails.
func Run(existing *config.OnyxCliConfig) *config.OnyxCliConfig {
cfg := config.DefaultConfig()
if existing != nil {
cfg = *existing
}
w, h := getTermSize()
fmt.Print(tui.RenderSplashOnboarding(w, h))
fmt.Println()
fmt.Println(" Welcome to " + boldStyle.Render("Onyx CLI") + ".")
fmt.Println()
reader := bufio.NewReader(os.Stdin)
// Server URL
serverURL := prompt(reader, " Onyx server URL", cfg.ServerURL)
if serverURL == "" {
return nil
}
if !strings.HasPrefix(serverURL, "http://") && !strings.HasPrefix(serverURL, "https://") {
fmt.Println(" " + redStyle.Render("Server URL must start with http:// or https://"))
return nil
}
// API Key
fmt.Println()
fmt.Println(" " + dimStyle.Render("Need an API key? Press Enter to open the admin panel in your browser,"))
fmt.Println(" " + dimStyle.Render("or paste your key below."))
fmt.Println()
apiKey := promptSecret(" API key", cfg.APIKey)
if apiKey == "" {
// Open browser to API key page
url := strings.TrimRight(serverURL, "/") + "/app/settings/accounts-access"
fmt.Printf("\n Opening %s ...\n", url)
util.OpenBrowser(url)
fmt.Println(" " + dimStyle.Render("Copy your API key, then paste it here."))
fmt.Println()
apiKey = promptSecret(" API key", "")
if apiKey == "" {
fmt.Println("\n " + redStyle.Render("No API key provided. Exiting."))
return nil
}
}
// Test connection
cfg = config.OnyxCliConfig{
ServerURL: serverURL,
APIKey: apiKey,
DefaultAgentID: cfg.DefaultAgentID,
}
fmt.Println("\n " + yellowStyle.Render("Testing connection..."))
client := api.NewClient(cfg)
if err := client.TestConnection(); err != nil {
fmt.Println(" " + redStyle.Render("Connection failed.") + " " + err.Error())
fmt.Println()
fmt.Println(" " + dimStyle.Render("Run ") + boldStyle.Render("onyx-cli configure") + dimStyle.Render(" to try again."))
return nil
}
if err := config.Save(cfg); err != nil {
fmt.Println(" " + redStyle.Render("Could not save config: "+err.Error()))
return nil
}
fmt.Println(" " + greenStyle.Render("Connected and authenticated."))
fmt.Println()
printQuickStart()
return &cfg
}
func promptSecret(label, defaultVal string) string {
if defaultVal != "" {
fmt.Printf("%s %s: ", label, dimStyle.Render("[hidden]"))
} else {
fmt.Printf("%s: ", label)
}
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println() // ReadPassword doesn't echo a newline
if err != nil {
return defaultVal
}
line := strings.TrimSpace(string(password))
if line == "" {
return defaultVal
}
return line
}
func prompt(reader *bufio.Reader, label, defaultVal string) string {
if defaultVal != "" {
fmt.Printf("%s %s: ", label, dimStyle.Render("["+defaultVal+"]"))
} else {
fmt.Printf("%s: ", label)
}
// ReadString may return partial data along with an error (e.g. EOF without
// a trailing newline), so check the trimmed line before consulting the error.
line, _ := reader.ReadString('\n')
if line = strings.TrimSpace(line); line != "" {
return line
}
return defaultVal
}
func printQuickStart() {
fmt.Println(" " + boldStyle.Render("Quick start"))
fmt.Println()
fmt.Println(" Just type to chat with your Onyx agent.")
fmt.Println()
rows := [][2]string{
{"/help", "Show all commands"},
{"/attach", "Attach a file"},
{"/agent", "Switch agent"},
{"/new", "New conversation"},
{"/sessions", "Browse previous chats"},
{"Esc", "Cancel generation"},
{"Ctrl+D", "Quit"},
}
for _, r := range rows {
fmt.Printf(" %-12s %s\n", boldStyle.Render(r[0]), dimStyle.Render(r[1]))
}
fmt.Println()
}


@@ -1,248 +0,0 @@
// Package parser handles NDJSON stream parsing for Onyx chat responses.
package parser
import (
"encoding/json"
"fmt"
"strings"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// ParseStreamLine parses a single NDJSON line into a typed StreamEvent.
// Returns nil for empty lines; malformed JSON yields an ErrorEvent.
func ParseStreamLine(line string) models.StreamEvent {
line = strings.TrimSpace(line)
if line == "" {
return nil
}
var data map[string]any
if err := json.Unmarshal([]byte(line), &data); err != nil {
return models.ErrorEvent{Error: fmt.Sprintf("malformed stream data: %v", err), IsRetryable: false}
}
// Case 1: CreateChatSessionID
if _, ok := data["chat_session_id"]; ok {
if _, hasPlacement := data["placement"]; !hasPlacement {
sid, _ := data["chat_session_id"].(string)
return models.SessionCreatedEvent{ChatSessionID: sid}
}
}
// Case 2: MessageResponseIDInfo
if _, ok := data["reserved_assistant_message_id"]; ok {
reservedID := jsonInt(data["reserved_assistant_message_id"])
var userMsgID *int
if v, ok := data["user_message_id"]; ok && v != nil {
id := jsonInt(v)
userMsgID = &id
}
return models.MessageIDEvent{
UserMessageID: userMsgID,
ReservedAgentMessageID: reservedID,
}
}
// Case 3: StreamingError (top-level error without placement)
if _, ok := data["error"]; ok {
if _, hasPlacement := data["placement"]; !hasPlacement {
errStr, _ := data["error"].(string)
var stackTrace *string
if st, ok := data["stack_trace"].(string); ok {
stackTrace = &st
}
isRetryable := true
if v, ok := data["is_retryable"].(bool); ok {
isRetryable = v
}
return models.ErrorEvent{
Error: errStr,
StackTrace: stackTrace,
IsRetryable: isRetryable,
}
}
}
// Case 4: Packet with placement + obj
if rawPlacement, ok := data["placement"]; ok {
if rawObj, ok := data["obj"]; ok {
placement := parsePlacement(rawPlacement)
obj, _ := rawObj.(map[string]any)
if obj == nil {
return models.UnknownEvent{Placement: placement, RawData: data}
}
return parsePacketObj(obj, placement)
}
}
// Fallback
return models.UnknownEvent{RawData: data}
}
func parsePlacement(raw any) *models.Placement {
m, ok := raw.(map[string]any)
if !ok {
return nil
}
p := &models.Placement{
TurnIndex: jsonInt(m["turn_index"]),
TabIndex: jsonInt(m["tab_index"]),
}
if v, ok := m["sub_turn_index"]; ok && v != nil {
st := jsonInt(v)
p.SubTurnIndex = &st
}
return p
}
func parsePacketObj(obj map[string]any, placement *models.Placement) models.StreamEvent {
objType, _ := obj["type"].(string)
switch objType {
case "stop":
var reason *string
if r, ok := obj["stop_reason"].(string); ok {
reason = &r
}
return models.StopEvent{Placement: placement, StopReason: reason}
case "error":
errMsg := "Unknown error"
if e, ok := obj["exception"]; ok {
errMsg = toString(e)
}
return models.ErrorEvent{Placement: placement, Error: errMsg, IsRetryable: true}
case "message_start":
var docs []models.SearchDoc
if rawDocs, ok := obj["final_documents"].([]any); ok {
docs = parseSearchDocs(rawDocs)
}
return models.MessageStartEvent{Placement: placement, Documents: docs}
case "message_delta":
content, _ := obj["content"].(string)
return models.MessageDeltaEvent{Placement: placement, Content: content}
case "search_tool_start":
isInternet, _ := obj["is_internet_search"].(bool)
return models.SearchStartEvent{Placement: placement, IsInternetSearch: isInternet}
case "search_tool_queries_delta":
var queries []string
if raw, ok := obj["queries"].([]any); ok {
for _, q := range raw {
if s, ok := q.(string); ok {
queries = append(queries, s)
}
}
}
return models.SearchQueriesEvent{Placement: placement, Queries: queries}
case "search_tool_documents_delta":
var docs []models.SearchDoc
if rawDocs, ok := obj["documents"].([]any); ok {
docs = parseSearchDocs(rawDocs)
}
return models.SearchDocumentsEvent{Placement: placement, Documents: docs}
case "reasoning_start":
return models.ReasoningStartEvent{Placement: placement}
case "reasoning_delta":
reasoning, _ := obj["reasoning"].(string)
return models.ReasoningDeltaEvent{Placement: placement, Reasoning: reasoning}
case "reasoning_done":
return models.ReasoningDoneEvent{Placement: placement}
case "citation_info":
return models.CitationEvent{
Placement: placement,
CitationNumber: jsonInt(obj["citation_number"]),
DocumentID: jsonString(obj["document_id"]),
}
case "open_url_start", "image_generation_start", "python_tool_start", "file_reader_start":
toolName := strings.ReplaceAll(strings.TrimSuffix(objType, "_start"), "_", " ")
toolName = cases.Title(language.English).String(toolName)
return models.ToolStartEvent{Placement: placement, Type: objType, ToolName: toolName}
case "custom_tool_start":
toolName := jsonString(obj["tool_name"])
if toolName == "" {
toolName = "Custom Tool"
}
return models.ToolStartEvent{Placement: placement, Type: models.EventCustomToolStart, ToolName: toolName}
case "deep_research_plan_start":
return models.DeepResearchPlanStartEvent{Placement: placement}
case "deep_research_plan_delta":
content, _ := obj["content"].(string)
return models.DeepResearchPlanDeltaEvent{Placement: placement, Content: content}
case "research_agent_start":
task, _ := obj["research_task"].(string)
return models.ResearchAgentStartEvent{Placement: placement, ResearchTask: task}
case "intermediate_report_start":
return models.IntermediateReportStartEvent{Placement: placement}
case "intermediate_report_delta":
content, _ := obj["content"].(string)
return models.IntermediateReportDeltaEvent{Placement: placement, Content: content}
default:
return models.UnknownEvent{Placement: placement, RawData: obj}
}
}
func parseSearchDocs(raw []any) []models.SearchDoc {
var docs []models.SearchDoc
for _, item := range raw {
m, ok := item.(map[string]any)
if !ok {
continue
}
doc := models.SearchDoc{
DocumentID: jsonString(m["document_id"]),
SemanticIdentifier: jsonString(m["semantic_identifier"]),
SourceType: jsonString(m["source_type"]),
}
if link, ok := m["link"].(string); ok {
doc.Link = &link
}
docs = append(docs, doc)
}
return docs
}
func jsonInt(v any) int {
switch n := v.(type) {
case float64:
return int(n)
case int:
return n
default:
return 0
}
}
func jsonString(v any) string {
s, _ := v.(string)
return s
}
func toString(v any) string {
switch s := v.(type) {
case string:
return s
default:
b, _ := json.Marshal(v)
return string(b)
}
}


@@ -1,419 +0,0 @@
package parser
import (
"encoding/json"
"testing"
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
func TestEmptyLineReturnsNil(t *testing.T) {
for _, line := range []string{"", " ", "\n"} {
if ParseStreamLine(line) != nil {
t.Errorf("expected nil for %q", line)
}
}
}
func TestInvalidJSONReturnsErrorEvent(t *testing.T) {
for _, line := range []string{"not json", "{broken"} {
event := ParseStreamLine(line)
if event == nil {
t.Errorf("expected ErrorEvent for %q, got nil", line)
continue
}
if _, ok := event.(models.ErrorEvent); !ok {
t.Errorf("expected ErrorEvent for %q, got %T", line, event)
}
}
}
func TestSessionCreated(t *testing.T) {
line := mustJSON(map[string]interface{}{
"chat_session_id": "550e8400-e29b-41d4-a716-446655440000",
})
event := ParseStreamLine(line)
e, ok := event.(models.SessionCreatedEvent)
if !ok {
t.Fatalf("expected SessionCreatedEvent, got %T", event)
}
if e.ChatSessionID != "550e8400-e29b-41d4-a716-446655440000" {
t.Errorf("got %s", e.ChatSessionID)
}
}
func TestMessageIDInfo(t *testing.T) {
line := mustJSON(map[string]interface{}{
"user_message_id": 1,
"reserved_assistant_message_id": 2,
})
event := ParseStreamLine(line)
e, ok := event.(models.MessageIDEvent)
if !ok {
t.Fatalf("expected MessageIDEvent, got %T", event)
}
if e.UserMessageID == nil || *e.UserMessageID != 1 {
t.Errorf("expected user_message_id=1")
}
if e.ReservedAgentMessageID != 2 {
t.Errorf("got %d", e.ReservedAgentMessageID)
}
}
func TestMessageIDInfoNullUserID(t *testing.T) {
line := mustJSON(map[string]interface{}{
"user_message_id": nil,
"reserved_assistant_message_id": 5,
})
event := ParseStreamLine(line)
e, ok := event.(models.MessageIDEvent)
if !ok {
t.Fatalf("expected MessageIDEvent, got %T", event)
}
if e.UserMessageID != nil {
t.Error("expected nil user_message_id")
}
if e.ReservedAgentMessageID != 5 {
t.Errorf("got %d", e.ReservedAgentMessageID)
}
}
func TestTopLevelError(t *testing.T) {
line := mustJSON(map[string]interface{}{
"error": "Rate limit exceeded",
"stack_trace": "...",
"is_retryable": true,
})
event := ParseStreamLine(line)
e, ok := event.(models.ErrorEvent)
if !ok {
t.Fatalf("expected ErrorEvent, got %T", event)
}
if e.Error != "Rate limit exceeded" {
t.Errorf("got %s", e.Error)
}
if e.StackTrace == nil || *e.StackTrace != "..." {
t.Error("expected stack_trace")
}
if !e.IsRetryable {
t.Error("expected retryable")
}
}
func TestTopLevelErrorMinimal(t *testing.T) {
line := mustJSON(map[string]interface{}{
"error": "Something broke",
})
event := ParseStreamLine(line)
e, ok := event.(models.ErrorEvent)
if !ok {
t.Fatalf("expected ErrorEvent, got %T", event)
}
if e.Error != "Something broke" {
t.Errorf("got %s", e.Error)
}
if !e.IsRetryable {
t.Error("expected default retryable=true")
}
}
func makePacket(obj map[string]interface{}, turnIndex, tabIndex int) string {
return mustJSON(map[string]interface{}{
"placement": map[string]interface{}{"turn_index": turnIndex, "tab_index": tabIndex},
"obj": obj,
})
}
func TestStopPacket(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "stop", "stop_reason": "completed"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.StopEvent)
if !ok {
t.Fatalf("expected StopEvent, got %T", event)
}
if e.StopReason == nil || *e.StopReason != "completed" {
t.Error("expected stop_reason=completed")
}
if e.Placement == nil || e.Placement.TurnIndex != 0 {
t.Error("expected placement")
}
}
func TestStopPacketNoReason(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "stop"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.StopEvent)
if !ok {
t.Fatalf("expected StopEvent, got %T", event)
}
if e.StopReason != nil {
t.Error("expected nil stop_reason")
}
}
func TestMessageStart(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "message_start"}, 0, 0)
event := ParseStreamLine(line)
_, ok := event.(models.MessageStartEvent)
if !ok {
t.Fatalf("expected MessageStartEvent, got %T", event)
}
}
func TestMessageStartWithDocuments(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "message_start",
"final_documents": []interface{}{
map[string]interface{}{"document_id": "doc1", "semantic_identifier": "Doc 1"},
},
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.MessageStartEvent)
if !ok {
t.Fatalf("expected MessageStartEvent, got %T", event)
}
if len(e.Documents) != 1 || e.Documents[0].DocumentID != "doc1" {
t.Error("expected 1 document with id doc1")
}
}
func TestMessageDelta(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "message_delta", "content": "Hello"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.MessageDeltaEvent)
if !ok {
t.Fatalf("expected MessageDeltaEvent, got %T", event)
}
if e.Content != "Hello" {
t.Errorf("got %s", e.Content)
}
}
func TestMessageDeltaEmpty(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "message_delta", "content": ""}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.MessageDeltaEvent)
if !ok {
t.Fatalf("expected MessageDeltaEvent, got %T", event)
}
if e.Content != "" {
t.Errorf("expected empty, got %s", e.Content)
}
}
func TestSearchToolStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "search_tool_start", "is_internet_search": true,
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.SearchStartEvent)
if !ok {
t.Fatalf("expected SearchStartEvent, got %T", event)
}
if !e.IsInternetSearch {
t.Error("expected internet search")
}
}
func TestSearchToolQueries(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "search_tool_queries_delta",
"queries": []interface{}{"query 1", "query 2"},
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.SearchQueriesEvent)
if !ok {
t.Fatalf("expected SearchQueriesEvent, got %T", event)
}
if len(e.Queries) != 2 || e.Queries[0] != "query 1" {
t.Error("unexpected queries")
}
}
func TestSearchToolDocuments(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "search_tool_documents_delta",
"documents": []interface{}{
map[string]interface{}{"document_id": "d1", "semantic_identifier": "First Doc", "link": "http://example.com"},
map[string]interface{}{"document_id": "d2", "semantic_identifier": "Second Doc"},
},
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.SearchDocumentsEvent)
if !ok {
t.Fatalf("expected SearchDocumentsEvent, got %T", event)
}
if len(e.Documents) != 2 {
t.Errorf("expected 2 docs, got %d", len(e.Documents))
}
if e.Documents[0].Link == nil || *e.Documents[0].Link != "http://example.com" {
t.Error("expected link on first doc")
}
}
func TestReasoningStart(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "reasoning_start"}, 0, 0)
event := ParseStreamLine(line)
if _, ok := event.(models.ReasoningStartEvent); !ok {
t.Fatalf("expected ReasoningStartEvent, got %T", event)
}
}
func TestReasoningDelta(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "reasoning_delta", "reasoning": "Let me think...",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ReasoningDeltaEvent)
if !ok {
t.Fatalf("expected ReasoningDeltaEvent, got %T", event)
}
if e.Reasoning != "Let me think..." {
t.Errorf("got %s", e.Reasoning)
}
}
func TestReasoningDone(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "reasoning_done"}, 0, 0)
event := ParseStreamLine(line)
if _, ok := event.(models.ReasoningDoneEvent); !ok {
t.Fatalf("expected ReasoningDoneEvent, got %T", event)
}
}
func TestCitationInfo(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "citation_info", "citation_number": 1, "document_id": "doc_abc",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.CitationEvent)
if !ok {
t.Fatalf("expected CitationEvent, got %T", event)
}
if e.CitationNumber != 1 || e.DocumentID != "doc_abc" {
t.Errorf("got %d, %s", e.CitationNumber, e.DocumentID)
}
}
func TestOpenURLStart(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "open_url_start"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ToolStartEvent)
if !ok {
t.Fatalf("expected ToolStartEvent, got %T", event)
}
if e.Type != "open_url_start" {
t.Errorf("got type %s", e.Type)
}
}
func TestPythonToolStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "python_tool_start", "code": "print('hi')",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ToolStartEvent)
if !ok {
t.Fatalf("expected ToolStartEvent, got %T", event)
}
if e.ToolName != "Python Tool" {
t.Errorf("got %s", e.ToolName)
}
}
func TestCustomToolStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "custom_tool_start", "tool_name": "MyTool",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ToolStartEvent)
if !ok {
t.Fatalf("expected ToolStartEvent, got %T", event)
}
if e.ToolName != "MyTool" {
t.Errorf("got %s", e.ToolName)
}
}
func TestDeepResearchPlanDelta(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "deep_research_plan_delta", "content": "Step 1: ...",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.DeepResearchPlanDeltaEvent)
if !ok {
t.Fatalf("expected DeepResearchPlanDeltaEvent, got %T", event)
}
if e.Content != "Step 1: ..." {
t.Errorf("got %s", e.Content)
}
}
func TestResearchAgentStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "research_agent_start", "research_task": "Find info about X",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ResearchAgentStartEvent)
if !ok {
t.Fatalf("expected ResearchAgentStartEvent, got %T", event)
}
if e.ResearchTask != "Find info about X" {
t.Errorf("got %s", e.ResearchTask)
}
}
func TestIntermediateReportDelta(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "intermediate_report_delta", "content": "Report text",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.IntermediateReportDeltaEvent)
if !ok {
t.Fatalf("expected IntermediateReportDeltaEvent, got %T", event)
}
if e.Content != "Report text" {
t.Errorf("got %s", e.Content)
}
}
func TestUnknownPacketType(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "section_end"}, 0, 0)
event := ParseStreamLine(line)
if _, ok := event.(models.UnknownEvent); !ok {
t.Fatalf("expected UnknownEvent, got %T", event)
}
}
func TestUnknownTopLevel(t *testing.T) {
line := mustJSON(map[string]interface{}{"some_unknown_field": "value"})
event := ParseStreamLine(line)
if _, ok := event.(models.UnknownEvent); !ok {
t.Fatalf("expected UnknownEvent, got %T", event)
}
}
func TestPlacementPreserved(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "message_delta", "content": "x",
}, 3, 1)
event := ParseStreamLine(line)
e, ok := event.(models.MessageDeltaEvent)
if !ok {
t.Fatalf("expected MessageDeltaEvent, got %T", event)
}
if e.Placement == nil {
t.Fatal("expected placement")
}
if e.Placement.TurnIndex != 3 || e.Placement.TabIndex != 1 {
t.Errorf("got turn=%d tab=%d", e.Placement.TurnIndex, e.Placement.TabIndex)
}
}
func mustJSON(v interface{}) string {
b, err := json.Marshal(v)
if err != nil {
panic(err)
}
return string(b)
}


@@ -1,627 +0,0 @@
// Package tui implements the Bubble Tea TUI for Onyx CLI.
package tui
import (
"context"
"fmt"
"strconv"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
// Model is the root Bubble Tea model.
type Model struct {
config config.OnyxCliConfig
client *api.Client
viewport *viewport
input inputModel
status statusBar
width int
height int
// Chat state
chatSessionID *string
agentID int
agentName string
agents []models.AgentSummary
parentMessageID *int
isStreaming bool
streamCancel context.CancelFunc
streamCh <-chan models.StreamEvent
citations map[int]string
attachedFiles []models.FileDescriptorPayload
needsRename bool
agentStarted bool
// Quit state
quitPending bool
splashShown bool
initInputReady bool // true once terminal init responses have passed
}
// NewModel creates a new TUI model.
func NewModel(cfg config.OnyxCliConfig) Model {
client := api.NewClient(cfg)
parentID := -1
return Model{
config: cfg,
client: client,
viewport: newViewport(80),
input: newInputModel(),
status: newStatusBar(),
agentID: cfg.DefaultAgentID,
agentName: "Default",
parentMessageID: &parentID,
citations: make(map[int]string),
}
}
// Init initializes the model.
func (m Model) Init() tea.Cmd {
return loadAgentsCmd(m.client)
}
// Update handles messages.
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Filter out terminal query responses (OSC 11 background color, cursor
// position reports, etc.) that arrive as key events with raw escape content.
// These arrive split across multiple key events, so we use a brief window
// after startup to swallow them all.
if _, ok := msg.(tea.KeyMsg); ok && !m.initInputReady {
// During init, drop all key events; they're terminal query responses.
return m, nil
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.viewport.setWidth(msg.Width)
m.status.setWidth(msg.Width)
m.input.textInput.Width = msg.Width - 4
if !m.splashShown {
m.splashShown = true
// bottomHeight = sep + input + sep + status = 4 (approx)
viewportHeight := msg.Height - 4
if viewportHeight < 1 {
viewportHeight = msg.Height
}
m.viewport.addSplash(viewportHeight)
// Delay input focus to let terminal query responses flush
return m, tea.Tick(100*time.Millisecond, func(time.Time) tea.Msg {
return inputReadyMsg{}
})
}
return m, nil
case tea.MouseMsg:
switch msg.Button {
case tea.MouseButtonWheelUp:
m.viewport.scrollUp(3)
return m, nil
case tea.MouseButtonWheelDown:
m.viewport.scrollDown(3)
return m, nil
}
case tea.KeyMsg:
return m.handleKey(msg)
case submitMsg:
return m.handleSubmit(msg.text)
case fileDropMsg:
return m.handleFileDrop(msg.path)
case InitDoneMsg:
return m.handleInitDone(msg)
case api.StreamEventMsg:
return m.handleStreamEvent(msg)
case api.StreamDoneMsg:
return m.handleStreamDone(msg)
case AgentsLoadedMsg:
return m.handleAgentsLoaded(msg)
case SessionsLoadedMsg:
return m.handleSessionsLoaded(msg)
case SessionResumedMsg:
return m.handleSessionResumed(msg)
case FileUploadedMsg:
return m.handleFileUploaded(msg)
case inputReadyMsg:
m.initInputReady = true
m.input.textInput.Focus()
m.input.textInput.SetValue("")
return m, m.input.textInput.Cursor.BlinkCmd()
case resetQuitMsg:
m.quitPending = false
return m, nil
}
// Only forward messages to the text input after it's been focused
if m.splashShown {
var cmd tea.Cmd
m.input, cmd = m.input.update(msg)
return m, cmd
}
return m, nil
}
// viewportHeight returns the number of visible chat rows, accounting for the
// dynamic bottom area (separator, menu, file badges, input, status bar).
func (m Model) viewportHeight() int {
menuHeight := 0
if m.input.menuVisible {
menuHeight = len(m.input.menuItems)
}
fileHeight := 0
if len(m.input.attachedFiles) > 0 {
fileHeight = 1
}
h := m.height - (1 + menuHeight + fileHeight + 1 + 1 + 1)
if h < 1 {
return 1
}
return h
}
// View renders the UI.
func (m Model) View() string {
if m.width == 0 || m.height == 0 {
return ""
}
separator := lipgloss.NewStyle().Foreground(separatorColor).Render(
strings.Repeat("─", m.width),
)
menuView := m.input.viewMenu(m.width)
viewportHeight := m.viewportHeight()
var parts []string
parts = append(parts, m.viewport.view(viewportHeight))
parts = append(parts, separator)
if menuView != "" {
parts = append(parts, menuView)
}
parts = append(parts, m.input.viewInput())
parts = append(parts, separator)
parts = append(parts, m.status.view())
return strings.Join(parts, "\n")
}
// handleKey processes keyboard input.
func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.Type {
case tea.KeyEscape:
// Cancel streaming or close menu
if m.input.menuVisible {
m.input.menuVisible = false
return m, nil
}
if m.isStreaming {
return m.cancelStream()
}
// Dismiss picker
if m.viewport.pickerActive {
m.viewport.pickerActive = false
return m, nil
}
return m, nil
case tea.KeyCtrlD:
// If streaming, cancel first; require a fresh Ctrl+D pair to quit
if m.isStreaming {
return m.cancelStream()
}
if m.quitPending {
return m, tea.Quit
}
m.quitPending = true
m.viewport.addInfo("Press Ctrl+D again to quit.")
return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
return resetQuitMsg{}
})
case tea.KeyCtrlO:
m.viewport.showSources = !m.viewport.showSources
return m, nil
case tea.KeyEnter:
// If picker is active, handle selection
if m.viewport.pickerActive && len(m.viewport.pickerItems) > 0 {
item := m.viewport.pickerItems[m.viewport.pickerIndex]
m.viewport.pickerActive = false
switch m.viewport.pickerType {
case pickerSession:
return cmdResume(m, item.id)
case pickerAgent:
return cmdSelectAgent(m, item.id)
}
}
case tea.KeyUp:
if m.viewport.pickerActive {
if m.viewport.pickerIndex > 0 {
m.viewport.pickerIndex--
}
return m, nil
}
case tea.KeyDown:
if m.viewport.pickerActive {
if m.viewport.pickerIndex < len(m.viewport.pickerItems)-1 {
m.viewport.pickerIndex++
}
return m, nil
}
case tea.KeyPgUp:
m.viewport.scrollUp(m.viewportHeight() / 2)
return m, nil
case tea.KeyPgDown:
m.viewport.scrollDown(m.viewportHeight() / 2)
return m, nil
case tea.KeyShiftUp:
m.viewport.scrollUp(3)
return m, nil
case tea.KeyShiftDown:
m.viewport.scrollDown(3)
return m, nil
}
// Pass to input
var cmd tea.Cmd
m.input, cmd = m.input.update(msg)
return m, cmd
}
func (m Model) handleSubmit(text string) (tea.Model, tea.Cmd) {
if strings.HasPrefix(text, "/") {
return handleSlashCommand(m, text)
}
return m.sendMessage(text)
}
func (m Model) handleFileDrop(path string) (tea.Model, tea.Cmd) {
return cmdAttach(m, path)
}
func (m Model) cancelStream() (Model, tea.Cmd) {
if m.streamCancel != nil {
m.streamCancel()
}
if m.chatSessionID != nil {
sid := *m.chatSessionID
go m.client.StopChatSession(sid)
}
m, cmd := m.finishStream(nil)
m.viewport.addInfo("Generation stopped.")
return m, cmd
}
func (m Model) sendMessage(message string) (Model, tea.Cmd) {
if m.isStreaming {
return m, nil
}
m.viewport.addUserMessage(message)
m.viewport.startAgent()
// Prepare file descriptors
fileDescs := make([]models.FileDescriptorPayload, len(m.attachedFiles))
copy(fileDescs, m.attachedFiles)
m.attachedFiles = nil
m.input.clearFiles()
m.isStreaming = true
m.agentStarted = false
m.citations = make(map[int]string)
m.status.setStreaming(true)
ctx, cancel := context.WithCancel(context.Background())
m.streamCancel = cancel
ch := m.client.SendMessageStream(
ctx,
message,
m.chatSessionID,
m.agentID,
m.parentMessageID,
fileDescs,
)
m.streamCh = ch
return m, api.WaitForStreamEvent(ch)
}
func (m Model) handleStreamEvent(msg api.StreamEventMsg) (tea.Model, tea.Cmd) {
// Ignore stale events after cancellation
if !m.isStreaming {
return m, nil
}
switch e := msg.Event.(type) {
case models.SessionCreatedEvent:
m.chatSessionID = &e.ChatSessionID
m.needsRename = true
m.status.setSession(e.ChatSessionID)
case models.MessageIDEvent:
m.parentMessageID = &e.ReservedAgentMessageID
case models.MessageStartEvent:
m.agentStarted = true
case models.MessageDeltaEvent:
m.agentStarted = true
m.viewport.appendToken(e.Content)
case models.SearchStartEvent:
if e.IsInternetSearch {
m.viewport.addInfo("Web search…")
} else {
m.viewport.addInfo("Searching…")
}
case models.SearchQueriesEvent:
if len(e.Queries) > 0 {
queries := e.Queries
if len(queries) > 3 {
queries = queries[:3]
}
parts := make([]string, len(queries))
for i, q := range queries {
parts[i] = "\"" + q + "\""
}
m.viewport.addInfo("Searching: " + strings.Join(parts, ", "))
}
case models.SearchDocumentsEvent:
count := len(e.Documents)
suffix := "s"
if count == 1 {
suffix = ""
}
m.viewport.addInfo("Found " + strconv.Itoa(count) + " document" + suffix)
case models.ReasoningStartEvent:
m.viewport.addInfo("Thinking…")
case models.ReasoningDeltaEvent:
// We don't display reasoning text, just the indicator
case models.ReasoningDoneEvent:
// No-op
case models.CitationEvent:
m.citations[e.CitationNumber] = e.DocumentID
case models.ToolStartEvent:
m.viewport.addInfo("Using " + e.ToolName + "…")
case models.ResearchAgentStartEvent:
m.viewport.addInfo("Researching: " + e.ResearchTask)
case models.DeepResearchPlanDeltaEvent:
m.viewport.appendToken(e.Content)
case models.IntermediateReportDeltaEvent:
m.viewport.appendToken(e.Content)
case models.StopEvent:
return m.finishStream(nil)
case models.ErrorEvent:
m.viewport.addError(e.Error)
return m.finishStream(nil)
}
return m, api.WaitForStreamEvent(m.streamCh)
}
func (m Model) handleStreamDone(msg api.StreamDoneMsg) (tea.Model, tea.Cmd) {
// Ignore if already cancelled
if !m.isStreaming {
return m, nil
}
return m.finishStream(msg.Err)
}
func (m Model) finishStream(err error) (Model, tea.Cmd) {
m.viewport.finishAgent()
if m.agentStarted && len(m.citations) > 0 {
m.viewport.addCitations(m.citations)
}
m.isStreaming = false
m.agentStarted = false
m.status.setStreaming(false)
if m.streamCancel != nil {
m.streamCancel()
}
m.streamCancel = nil
m.streamCh = nil
// Auto-rename new sessions
if m.needsRename && m.chatSessionID != nil {
m.needsRename = false
sessionID := *m.chatSessionID
client := m.client
go func() {
_, _ = client.RenameChatSession(sessionID, nil)
}()
}
return m, nil
}
func (m Model) handleInitDone(msg InitDoneMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addWarning("Could not load agents. Using default.")
} else {
m.agents = msg.Agents
for _, p := range m.agents {
if p.ID == m.agentID {
m.agentName = p.Name
break
}
}
}
m.status.setServer(m.config.ServerURL)
m.status.setAgent(m.agentName)
return m, nil
}
func (m Model) handleAgentsLoaded(msg AgentsLoadedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Could not load agents: " + msg.Err.Error())
return m, nil
}
m.agents = msg.Agents
if len(m.agents) == 0 {
m.viewport.addInfo("No agents available.")
return m, nil
}
m.viewport.addInfo("Select an agent (Enter to select, Esc to cancel):")
var items []pickerItem
for _, p := range m.agents {
label := fmt.Sprintf("%d: %s", p.ID, p.Name)
if p.ID == m.agentID {
label += " *"
}
desc := p.Description
if len(desc) > 50 {
desc = desc[:50] + "..."
}
if desc != "" {
label += " - " + desc
}
items = append(items, pickerItem{
id: strconv.Itoa(p.ID),
label: label,
})
}
m.viewport.showPicker(pickerAgent, items)
return m, nil
}
func (m Model) handleSessionsLoaded(msg SessionsLoadedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Could not load sessions: " + msg.Err.Error())
return m, nil
}
if len(msg.Sessions) == 0 {
m.viewport.addInfo("No previous sessions found.")
return m, nil
}
m.viewport.addInfo("Select a session to resume (Enter to select, Esc to cancel):")
var items []pickerItem
for i, s := range msg.Sessions {
if i >= 15 {
break
}
name := "Untitled"
if s.Name != nil && *s.Name != "" {
name = *s.Name
}
sid := s.ID
if len(sid) > 8 {
sid = sid[:8]
}
items = append(items, pickerItem{
id: s.ID,
label: sid + " " + name + " (" + s.Created + ")",
})
}
m.viewport.showPicker(pickerSession, items)
return m, nil
}
func (m Model) handleSessionResumed(msg SessionResumedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Could not load session: " + msg.Err.Error())
return m, nil
}
// Cancel any in-progress stream before replacing the session
if m.isStreaming {
m, _ = m.cancelStream()
}
detail := msg.Detail
m.chatSessionID = &detail.ChatSessionID
m.viewport.clearDisplay()
m.status.setSession(detail.ChatSessionID)
if detail.AgentName != nil {
m.agentName = *detail.AgentName
m.status.setAgent(*detail.AgentName)
}
if detail.AgentID != nil {
m.agentID = *detail.AgentID
}
// Replay messages
for _, chatMsg := range detail.Messages {
switch chatMsg.MessageType {
case "user":
m.viewport.addUserMessage(chatMsg.Message)
case "assistant":
m.viewport.startAgent()
m.viewport.appendToken(chatMsg.Message)
m.viewport.finishAgent()
}
}
// Set parent to last message
if len(detail.Messages) > 0 {
lastID := detail.Messages[len(detail.Messages)-1].MessageID
m.parentMessageID = &lastID
}
desc := "Untitled"
if detail.Description != nil && *detail.Description != "" {
desc = *detail.Description
}
m.viewport.addInfo("Resumed session: " + desc)
return m, nil
}
func (m Model) handleFileUploaded(msg FileUploadedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Upload failed: " + msg.Err.Error())
return m, nil
}
m.attachedFiles = append(m.attachedFiles, *msg.Descriptor)
m.input.addFile(msg.FileName)
m.viewport.addInfo("Attached: " + msg.FileName)
return m, nil
}
type inputReadyMsg struct{}
type resetQuitMsg struct{}
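`handleStreamEvent` above processes exactly one event per invocation and then returns `api.WaitForStreamEvent(m.streamCh)` so Bubble Tea delivers the next one. A minimal sketch of that one-event-at-a-time pattern with plain channels (the event types here are hypothetical stand-ins, not the real `models.*` types):

```go
package main

import "fmt"

// Hypothetical stand-ins for the stream event types.
type event interface{}
type deltaEvent struct{ content string }
type stopEvent struct{}

// waitForEvent models api.WaitForStreamEvent: it blocks for exactly one
// event, so the caller must re-arm it after every message it processes.
func waitForEvent(ch <-chan event) event {
	e, ok := <-ch
	if !ok {
		return nil // channel closed; treated like a stop
	}
	return e
}

// drain consumes events one at a time, the way handleStreamEvent returns
// WaitForStreamEvent after each case, accumulating streamed text until a
// stop event (or channel close) arrives.
func drain(ch <-chan event) string {
	var buf string
	for {
		switch e := waitForEvent(ch).(type) {
		case deltaEvent:
			buf += e.content
		case stopEvent:
			return buf
		case nil:
			return buf
		}
	}
}

func main() {
	ch := make(chan event, 3)
	ch <- deltaEvent{"Hello, "}
	ch <- deltaEvent{"world"}
	ch <- stopEvent{}
	close(ch)
	fmt.Println(drain(ch)) // prints "Hello, world"
}
```

The same shape is why stale events are ignored once `m.isStreaming` is false: cancellation simply stops re-arming the wait.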


@@ -1,205 +0,0 @@
package tui
import (
"fmt"
"strconv"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/onyx-dot-app/onyx/cli/internal/util"
)
// handleSlashCommand dispatches slash commands and returns updated model + cmd.
func handleSlashCommand(m Model, text string) (Model, tea.Cmd) {
parts := strings.SplitN(text, " ", 2)
command := strings.ToLower(parts[0])
arg := ""
if len(parts) > 1 {
arg = parts[1]
}
switch command {
case "/help":
m.viewport.addInfo(helpText)
return m, nil
case "/new":
return cmdNew(m)
case "/agent":
if arg != "" {
return cmdSelectAgent(m, arg)
}
return cmdShowAgents(m)
case "/attach":
return cmdAttach(m, arg)
case "/sessions", "/resume":
if strings.TrimSpace(arg) != "" {
return cmdResume(m, arg)
}
return cmdSessions(m)
case "/configure":
m.viewport.addInfo("Run 'onyx-cli configure' to change connection settings.")
return m, nil
case "/clear":
return cmdNew(m)
case "/connectors":
url := m.config.ServerURL + "/admin/indexing/status"
if util.OpenBrowser(url) {
m.viewport.addInfo("Opened " + url + " in browser")
} else {
m.viewport.addWarning("Failed to open browser. Visit: " + url)
}
return m, nil
case "/settings":
url := m.config.ServerURL + "/app/settings/general"
if util.OpenBrowser(url) {
m.viewport.addInfo("Opened " + url + " in browser")
} else {
m.viewport.addWarning("Failed to open browser. Visit: " + url)
}
return m, nil
case "/quit":
return m, tea.Quit
default:
m.viewport.addWarning(fmt.Sprintf("Unknown command: %s. Type /help for available commands.", command))
return m, nil
}
}
func cmdNew(m Model) (Model, tea.Cmd) {
if m.isStreaming {
m, _ = m.cancelStream()
}
m.chatSessionID = nil
parentID := -1
m.parentMessageID = &parentID
m.needsRename = false
m.citations = nil
m.viewport.clearAll()
// Re-add splash as a scrollable entry
viewportHeight := m.viewportHeight()
if viewportHeight < 1 {
viewportHeight = m.height
}
m.viewport.addSplash(viewportHeight)
m.status.setSession("")
return m, nil
}
func cmdShowAgents(m Model) (Model, tea.Cmd) {
m.viewport.addInfo("Loading agents...")
client := m.client
return m, func() tea.Msg {
agents, err := client.ListAgents()
return AgentsLoadedMsg{Agents: agents, Err: err}
}
}
func cmdSelectAgent(m Model, idStr string) (Model, tea.Cmd) {
pid, err := strconv.Atoi(strings.TrimSpace(idStr))
if err != nil {
m.viewport.addWarning("Invalid agent ID. Use a number.")
return m, nil
}
var target *models.AgentSummary
for i := range m.agents {
if m.agents[i].ID == pid {
target = &m.agents[i]
break
}
}
if target == nil {
m.viewport.addWarning(fmt.Sprintf("Agent %d not found. Use /agent to see available agents.", pid))
return m, nil
}
m.agentID = target.ID
m.agentName = target.Name
m.status.setAgent(target.Name)
m.viewport.addInfo("Switched to agent: " + target.Name)
// Save preference
m.config.DefaultAgentID = target.ID
_ = config.Save(m.config)
return m, nil
}
func cmdAttach(m Model, pathStr string) (Model, tea.Cmd) {
if pathStr == "" {
m.viewport.addWarning("Usage: /attach <file_path>")
return m, nil
}
m.viewport.addInfo("Uploading " + pathStr + "...")
client := m.client
return m, func() tea.Msg {
fd, err := client.UploadFile(pathStr)
if err != nil {
return FileUploadedMsg{Err: err, FileName: pathStr}
}
return FileUploadedMsg{Descriptor: fd, FileName: pathStr}
}
}
func cmdSessions(m Model) (Model, tea.Cmd) {
m.viewport.addInfo("Loading sessions...")
client := m.client
return m, func() tea.Msg {
sessions, err := client.ListChatSessions()
return SessionsLoadedMsg{Sessions: sessions, Err: err}
}
}
func cmdResume(m Model, sessionIDStr string) (Model, tea.Cmd) {
client := m.client
return m, func() tea.Msg {
// Try to find session by prefix match
sessions, err := client.ListChatSessions()
if err != nil {
return SessionResumedMsg{Err: err}
}
var targetID string
for _, s := range sessions {
if strings.HasPrefix(s.ID, sessionIDStr) {
targetID = s.ID
break
}
}
if targetID == "" {
// Try as full UUID
targetID = sessionIDStr
}
detail, err := client.GetChatSession(targetID)
if err != nil {
return SessionResumedMsg{Err: fmt.Errorf("session not found: %s: %w", sessionIDStr, err)}
}
return SessionResumedMsg{Detail: detail}
}
}
// loadAgentsCmd returns a tea.Cmd that loads agents from the API.
func loadAgentsCmd(client *api.Client) tea.Cmd {
return func() tea.Msg {
agents, err := client.ListAgents()
return InitDoneMsg{Agents: agents, Err: err}
}
}
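`cmdResume` above resolves a user-typed prefix against the session list before falling back to treating the input as a full ID. That lookup in isolation (the IDs below are made up):

```go
package main

import (
	"fmt"
	"strings"
)

// resolveSessionID mirrors cmdResume's lookup: the first session whose ID
// starts with the user-supplied prefix wins; otherwise the input is passed
// through unchanged and treated as a full ID.
func resolveSessionID(ids []string, prefix string) string {
	for _, id := range ids {
		if strings.HasPrefix(id, prefix) {
			return id
		}
	}
	return prefix
}

func main() {
	ids := []string{
		"0a1b2c3d-1111-4444-8888-000000000001",
		"ffee0011-2222-4444-8888-000000000002",
	}
	fmt.Println(resolveSessionID(ids, "ffee")) // full ID of the second session
	fmt.Println(resolveSessionID(ids, "zzzz")) // no match: passed through as-is
}
```

Because the pass-through case still calls `GetChatSession`, an unknown prefix surfaces as a "session not found" error rather than silently doing nothing.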


@@ -1,24 +0,0 @@
package tui
const helpText = `Onyx CLI Commands
/help Show this help message
/new Start a new chat session
/agent List and switch agents
/attach <path> Attach a file to next message
/sessions Browse and resume previous sessions
/clear Clear the chat display
/configure Re-run connection setup
/connectors Open connectors page in browser
/settings Open Onyx settings in browser
/quit Exit Onyx CLI
Keyboard Shortcuts
Enter Send message
Escape Cancel current generation
Ctrl+O Toggle source citations
Ctrl+D Quit (press twice)
Scroll Up/Down Mouse wheel or Shift+Up/Down
Page Up/Down Scroll half page
`


@@ -1,242 +0,0 @@
package tui
import (
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
// slashCommand defines a slash command with its description.
type slashCommand struct {
command string
description string
}
var slashCommands = []slashCommand{
{"/help", "Show help message"},
{"/new", "Start a new chat session"},
{"/agent", "List and switch agents"},
{"/attach", "Attach a file to next message"},
{"/sessions", "Browse and resume previous sessions"},
{"/clear", "Clear the chat display"},
{"/configure", "Re-run connection setup"},
{"/connectors", "Open connectors in browser"},
{"/settings", "Open settings in browser"},
{"/quit", "Exit Onyx CLI"},
}
// Commands that take arguments (filled in with trailing space on Tab/Enter).
var argCommands = map[string]bool{
"/attach": true,
}
// inputModel manages the text input and slash command menu.
type inputModel struct {
textInput textinput.Model
menuVisible bool
menuItems []slashCommand
menuIndex int
attachedFiles []string
}
func newInputModel() inputModel {
ti := textinput.New()
ti.Prompt = "" // We render our own prompt in viewInput()
ti.Placeholder = "Send a message…"
ti.CharLimit = 10000
// Don't focus here — focus after first WindowSizeMsg to avoid
// capturing terminal init escape sequences as input.
return inputModel{
textInput: ti,
}
}
func (m inputModel) update(msg tea.Msg) (inputModel, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
return m.handleKey(msg)
}
var cmd tea.Cmd
m.textInput, cmd = m.textInput.Update(msg)
m = m.updateMenu()
return m, cmd
}
func (m inputModel) handleKey(msg tea.KeyMsg) (inputModel, tea.Cmd) {
switch msg.Type {
case tea.KeyUp:
if m.menuVisible && m.menuIndex > 0 {
m.menuIndex--
return m, nil
}
case tea.KeyDown:
if m.menuVisible && m.menuIndex < len(m.menuItems)-1 {
m.menuIndex++
return m, nil
}
case tea.KeyTab:
if m.menuVisible && len(m.menuItems) > 0 {
cmd := m.menuItems[m.menuIndex].command
if argCommands[cmd] {
m.textInput.SetValue(cmd + " ")
m.textInput.SetCursor(len(cmd) + 1)
} else {
m.textInput.SetValue(cmd)
m.textInput.SetCursor(len(cmd))
}
m.menuVisible = false
return m, nil
}
case tea.KeyEnter:
if m.menuVisible && len(m.menuItems) > 0 {
cmd := m.menuItems[m.menuIndex].command
if argCommands[cmd] {
m.textInput.SetValue(cmd + " ")
m.textInput.SetCursor(len(cmd) + 1)
m.menuVisible = false
return m, nil
}
// Execute immediately
m.textInput.SetValue("")
m.menuVisible = false
return m, func() tea.Msg { return submitMsg{text: cmd} }
}
text := strings.TrimSpace(m.textInput.Value())
if text == "" {
return m, nil
}
// Check for file path (drag-and-drop)
if dropped := detectFileDrop(text); dropped != "" {
m.textInput.SetValue("")
return m, func() tea.Msg { return fileDropMsg{path: dropped} }
}
m.textInput.SetValue("")
m.menuVisible = false
return m, func() tea.Msg { return submitMsg{text: text} }
case tea.KeyEscape:
if m.menuVisible {
m.menuVisible = false
return m, nil
}
}
var cmd tea.Cmd
m.textInput, cmd = m.textInput.Update(msg)
m = m.updateMenu()
return m, cmd
}
func (m inputModel) updateMenu() inputModel {
val := strings.TrimSpace(m.textInput.Value())
if strings.HasPrefix(val, "/") && !strings.Contains(val, " ") {
needle := strings.ToLower(val)
var filtered []slashCommand
for _, sc := range slashCommands {
if strings.HasPrefix(sc.command, needle) {
filtered = append(filtered, sc)
}
}
if len(filtered) > 0 {
m.menuVisible = true
m.menuItems = filtered
if m.menuIndex >= len(filtered) {
m.menuIndex = 0
}
} else {
m.menuVisible = false
}
} else {
m.menuVisible = false
}
return m
}
func (m *inputModel) addFile(name string) {
m.attachedFiles = append(m.attachedFiles, name)
}
func (m *inputModel) clearFiles() {
m.attachedFiles = nil
}
// submitMsg is sent when user submits text.
type submitMsg struct {
text string
}
// fileDropMsg is sent when a file path is detected.
type fileDropMsg struct {
path string
}
// detectFileDrop checks if the text looks like a file path.
func detectFileDrop(text string) string {
cleaned := strings.Trim(text, "'\"")
if cleaned == "" {
return ""
}
// Only treat as a file drop if it looks explicitly path-like
if !strings.HasPrefix(cleaned, "/") && !strings.HasPrefix(cleaned, "~") &&
!strings.HasPrefix(cleaned, "./") && !strings.HasPrefix(cleaned, "../") {
return ""
}
	// Expand only bare "~" or "~/" to the home dir; "~user" forms are left
	// alone (they fall through to filepath.Abs and typically fail the Stat).
	if cleaned == "~" || strings.HasPrefix(cleaned, "~/") {
		home, err := os.UserHomeDir()
		if err == nil {
			cleaned = filepath.Join(home, cleaned[1:])
		}
	}
abs, err := filepath.Abs(cleaned)
if err != nil {
return ""
}
info, err := os.Stat(abs)
if err != nil {
return ""
}
if info.IsDir() {
return ""
}
return abs
}
// viewMenu renders the slash command menu.
func (m inputModel) viewMenu(width int) string {
if !m.menuVisible || len(m.menuItems) == 0 {
return ""
}
var lines []string
for i, item := range m.menuItems {
prefix := " "
if i == m.menuIndex {
prefix = "> "
}
line := prefix + item.command + " " + statusMsgStyle.Render(item.description)
lines = append(lines, line)
}
return strings.Join(lines, "\n")
}
// viewInput renders the input line with prompt and optional file badges.
func (m inputModel) viewInput() string {
var parts []string
if len(m.attachedFiles) > 0 {
badges := strings.Join(m.attachedFiles, "] [")
parts = append(parts, statusMsgStyle.Render("Attached: ["+badges+"]"))
}
parts = append(parts, inputPrompt+m.textInput.View())
return strings.Join(parts, "\n")
}
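`updateMenu` above opens the completion menu only while the input is a single slash token (no space yet) and filters commands by case-insensitive prefix. A standalone sketch of that filter:

```go
package main

import (
	"fmt"
	"strings"
)

// filterCommands mirrors updateMenu: nil means the menu stays closed, either
// because the input is not a slash command or because a space after the
// command name has already been typed.
func filterCommands(commands []string, input string) []string {
	val := strings.TrimSpace(input)
	if !strings.HasPrefix(val, "/") || strings.Contains(val, " ") {
		return nil
	}
	needle := strings.ToLower(val)
	var out []string
	for _, c := range commands {
		if strings.HasPrefix(c, needle) {
			out = append(out, c)
		}
	}
	return out
}

func main() {
	cmds := []string{"/help", "/new", "/agent", "/attach", "/sessions"}
	fmt.Println(filterCommands(cmds, "/a"))        // [/agent /attach]
	fmt.Println(filterCommands(cmds, "/attach x")) // [] — a space closes the menu
}
```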


@@ -1,36 +0,0 @@
package tui
import (
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
// InitDoneMsg signals that async initialization is complete.
type InitDoneMsg struct {
Agents []models.AgentSummary
Err error
}
// SessionsLoadedMsg carries loaded chat sessions.
type SessionsLoadedMsg struct {
Sessions []models.ChatSessionDetails
Err error
}
// SessionResumedMsg carries a loaded session detail.
type SessionResumedMsg struct {
Detail *models.ChatSessionDetailResponse
Err error
}
// FileUploadedMsg carries an uploaded file descriptor.
type FileUploadedMsg struct {
Descriptor *models.FileDescriptorPayload
FileName string
Err error
}
// AgentsLoadedMsg carries freshly fetched agents from the API.
type AgentsLoadedMsg struct {
Agents []models.AgentSummary
Err error
}


@@ -1,79 +0,0 @@
package tui
import (
"strings"
"github.com/charmbracelet/lipgloss"
)
const onyxLogo = ` ██████╗ ███╗ ██╗██╗ ██╗██╗ ██╗
██╔═══██╗████╗ ██║╚██╗ ██╔╝╚██╗██╔╝
██║ ██║██╔██╗ ██║ ╚████╔╝ ╚███╔╝
██║ ██║██║╚██╗██║ ╚██╔╝ ██╔██╗
╚██████╔╝██║ ╚████║ ██║ ██╔╝ ██╗
╚═════╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝`
const tagline = "Your terminal interface for Onyx"
const splashHint = "Type a message to begin · /help for commands"
// renderSplash renders the splash screen centered for the given dimensions.
func renderSplash(width, height int) string {
// Render the logo as a single block (don't center individual lines)
logo := splashStyle.Render(onyxLogo)
// Center tagline and hint relative to the logo block width
logoWidth := lipgloss.Width(logo)
tag := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
taglineStyle.Render(tagline),
)
hint := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
hintStyle.Render(splashHint),
)
block := lipgloss.JoinVertical(lipgloss.Left, logo, "", tag, hint)
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, block)
}
// RenderSplashOnboarding renders splash for the terminal onboarding screen.
func RenderSplashOnboarding(width, height int) string {
// Render the logo as a styled block, then center it as a unit
styledLogo := splashStyle.Render(onyxLogo)
logoWidth := lipgloss.Width(styledLogo)
logoLines := strings.Split(styledLogo, "\n")
logoHeight := len(logoLines)
contentHeight := logoHeight + 2 // logo + blank + tagline
topPad := (height - contentHeight) / 2
if topPad < 1 {
topPad = 1
}
// Center the entire logo block horizontally
blockPad := (width - logoWidth) / 2
if blockPad < 0 {
blockPad = 0
}
var b strings.Builder
for i := 0; i < topPad; i++ {
b.WriteByte('\n')
}
for _, line := range logoLines {
b.WriteString(strings.Repeat(" ", blockPad))
b.WriteString(line)
b.WriteByte('\n')
}
b.WriteByte('\n')
tagPad := (width - len(tagline)) / 2
if tagPad < 0 {
tagPad = 0
}
b.WriteString(strings.Repeat(" ", tagPad))
b.WriteString(taglineStyle.Render(tagline))
b.WriteByte('\n')
return b.String()
}
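The padding math in `RenderSplashOnboarding` above centers a block by splitting the free space in half and clamping at zero, since `strings.Repeat` panics on a negative count. The clamp in isolation:

```go
package main

import "fmt"

// centerPad mirrors the onboarding splash math: left padding is half the
// free space, never negative, so narrow terminals degrade to left-aligned
// output instead of panicking.
func centerPad(width, contentWidth int) int {
	pad := (width - contentWidth) / 2
	if pad < 0 {
		pad = 0
	}
	return pad
}

func main() {
	fmt.Println(centerPad(80, 38)) // 21
	fmt.Println(centerPad(20, 38)) // 0 — content wider than the terminal
}
```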


@@ -1,60 +0,0 @@
package tui
import (
"strings"
"github.com/charmbracelet/lipgloss"
)
// statusBar manages the footer status display.
type statusBar struct {
agentName string
serverURL string
sessionID string
streaming bool
width int
}
func newStatusBar() statusBar {
return statusBar{
agentName: "Default",
}
}
func (s *statusBar) setAgent(name string) { s.agentName = name }
func (s *statusBar) setServer(url string) { s.serverURL = url }
func (s *statusBar) setSession(id string) {
if len(id) > 8 {
id = id[:8]
}
s.sessionID = id
}
func (s *statusBar) setStreaming(v bool) { s.streaming = v }
func (s *statusBar) setWidth(w int) { s.width = w }
func (s statusBar) view() string {
var leftParts []string
if s.serverURL != "" {
leftParts = append(leftParts, s.serverURL)
}
name := s.agentName
if name == "" {
name = "Default"
}
leftParts = append(leftParts, name)
left := statusBarStyle.Render(strings.Join(leftParts, " · "))
right := "Ctrl+D to quit"
if s.streaming {
right = "Esc to cancel"
}
rightRendered := statusBarStyle.Render(right)
// Fill space between left and right
gap := s.width - lipgloss.Width(left) - lipgloss.Width(rightRendered)
if gap < 1 {
gap = 1
}
return left + strings.Repeat(" ", gap) + rightRendered
}
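`statusBar.view` above lays out the footer as left text, a space-filled gap sized to the remaining width (minimum 1), then right text. A simplified sketch using plain `len`, which is only valid for the unstyled ASCII strings in this example (the real code uses `lipgloss.Width` to ignore ANSI codes):

```go
package main

import (
	"fmt"
	"strings"
)

// renderBar fills the space between two segments, keeping at least one
// space so the segments never touch on narrow terminals.
func renderBar(left, right string, width int) string {
	gap := width - len(left) - len(right)
	if gap < 1 {
		gap = 1
	}
	return left + strings.Repeat(" ", gap) + right
}

func main() {
	fmt.Println(renderBar("onyx / Default", "Ctrl+D to quit", 40))
}
```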


@@ -1,29 +0,0 @@
package tui
import "github.com/charmbracelet/lipgloss"
var (
// Colors
accentColor = lipgloss.Color("#6c8ebf")
dimColor = lipgloss.Color("#555577")
errorColor = lipgloss.Color("#ff5555")
splashColor = lipgloss.Color("#7C6AEF")
separatorColor = lipgloss.Color("#333355")
citationColor = lipgloss.Color("#666688")
// Styles
userPrefixStyle = lipgloss.NewStyle().Foreground(dimColor)
agentDot = lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("◉")
infoStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#b0b0cc"))
dimInfoStyle = lipgloss.NewStyle().Foreground(dimColor)
statusMsgStyle = dimInfoStyle // used for slash menu descriptions, file badges
errorStyle = lipgloss.NewStyle().Foreground(errorColor).Bold(true)
warnStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
citationStyle = lipgloss.NewStyle().Foreground(citationColor)
statusBarStyle = lipgloss.NewStyle().Foreground(dimColor)
inputPrompt = lipgloss.NewStyle().Foreground(accentColor).Render(" ")
splashStyle = lipgloss.NewStyle().Foreground(splashColor).Bold(true)
taglineStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#A0A0A0"))
hintStyle = lipgloss.NewStyle().Foreground(dimColor)
)


@@ -1,447 +0,0 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/glamour"
"github.com/charmbracelet/glamour/styles"
"github.com/charmbracelet/lipgloss"
)
// entryKind is the type of chat entry.
type entryKind int
const (
entryUser entryKind = iota
entryAgent
entryInfo
entryError
entryCitation
)
// chatEntry is a single rendered entry in the chat history.
type chatEntry struct {
kind entryKind
content string // raw content (for agent: the markdown source)
rendered string // pre-rendered output
citations []string // citation lines (for citation entries)
}
// pickerKind distinguishes what the picker is selecting.
type pickerKind int
const (
pickerSession pickerKind = iota
pickerAgent
)
// pickerItem is a selectable item in the picker.
type pickerItem struct {
id string
label string
}
// viewport manages the chat display.
type viewport struct {
entries []chatEntry
width int
streaming bool
streamBuf string
showSources bool
renderer *glamour.TermRenderer
pickerItems []pickerItem
pickerActive bool
pickerIndex int
pickerType pickerKind
scrollOffset int // lines scrolled up from bottom (0 = pinned to bottom)
lastHeight int // viewport height from last render
}
// newMarkdownRenderer creates a Glamour renderer with zero left margin.
func newMarkdownRenderer(width int) *glamour.TermRenderer {
style := styles.DarkStyleConfig
zero := uint(0)
style.Document.Margin = &zero
r, _ := glamour.NewTermRenderer(
glamour.WithStyles(style),
glamour.WithWordWrap(width-4),
)
return r
}
func newViewport(width int) *viewport {
return &viewport{
width: width,
renderer: newMarkdownRenderer(width),
}
}
func (v *viewport) addSplash(height int) {
splash := renderSplash(v.width, height)
v.entries = append(v.entries, chatEntry{
kind: entryInfo,
rendered: splash,
})
}
func (v *viewport) setWidth(w int) {
v.width = w
v.renderer = newMarkdownRenderer(w)
}
func (v *viewport) addUserMessage(msg string) {
rendered := "\n" + userPrefixStyle.Render(" ") + msg
v.entries = append(v.entries, chatEntry{
kind: entryUser,
content: msg,
rendered: rendered,
})
}
func (v *viewport) startAgent() {
v.streaming = true
v.streamBuf = ""
// Add a blank-line spacer entry before the agent message
v.entries = append(v.entries, chatEntry{kind: entryInfo, rendered: ""})
}
func (v *viewport) appendToken(token string) {
v.streamBuf += token
}
func (v *viewport) finishAgent() {
if v.streamBuf == "" {
v.streaming = false
// Remove the blank spacer entry added by startAgent()
if len(v.entries) > 0 && v.entries[len(v.entries)-1].kind == entryInfo && v.entries[len(v.entries)-1].rendered == "" {
v.entries = v.entries[:len(v.entries)-1]
}
return
}
// Render markdown with Glamour (zero left margin style)
rendered := v.renderMarkdown(v.streamBuf)
rendered = strings.TrimLeft(rendered, "\n")
rendered = strings.TrimRight(rendered, "\n")
lines := strings.Split(rendered, "\n")
// Prefix first line with dot, indent continuation lines
if len(lines) > 0 {
lines[0] = agentDot + " " + lines[0]
for i := 1; i < len(lines); i++ {
lines[i] = " " + lines[i]
}
}
rendered = strings.Join(lines, "\n")
v.entries = append(v.entries, chatEntry{
kind: entryAgent,
content: v.streamBuf,
rendered: rendered,
})
v.streaming = false
v.streamBuf = ""
}
func (v *viewport) renderMarkdown(md string) string {
if v.renderer == nil {
return md
}
out, err := v.renderer.Render(md)
if err != nil {
return md
}
return out
}
func (v *viewport) addInfo(msg string) {
rendered := infoStyle.Render("● " + msg)
v.entries = append(v.entries, chatEntry{
kind: entryInfo,
content: msg,
rendered: rendered,
})
}
func (v *viewport) addWarning(msg string) {
rendered := warnStyle.Render("● " + msg)
v.entries = append(v.entries, chatEntry{
kind: entryError,
content: msg,
rendered: rendered,
})
}
func (v *viewport) addError(msg string) {
rendered := errorStyle.Render("● Error: ") + msg
v.entries = append(v.entries, chatEntry{
kind: entryError,
content: msg,
rendered: rendered,
})
}
func (v *viewport) addCitations(citations map[int]string) {
if len(citations) == 0 {
return
}
keys := make([]int, 0, len(citations))
for k := range citations {
keys = append(keys, k)
}
sort.Ints(keys)
var parts []string
for _, num := range keys {
parts = append(parts, fmt.Sprintf("[%d] %s", num, citations[num]))
}
text := fmt.Sprintf("Sources (%d): %s", len(citations), strings.Join(parts, " "))
var citLines []string
citLines = append(citLines, text)
v.entries = append(v.entries, chatEntry{
kind: entryCitation,
content: text,
rendered: citationStyle.Render("● "+text),
citations: citLines,
})
}
func (v *viewport) showPicker(kind pickerKind, items []pickerItem) {
v.pickerItems = items
v.pickerType = kind
v.pickerActive = true
v.pickerIndex = 0
}
func (v *viewport) maxScroll() int {
ms := v.totalLines() - v.lastHeight
if ms < 0 {
return 0
}
return ms
}
func (v *viewport) scrollUp(n int) {
v.scrollOffset += n
if ms := v.maxScroll(); v.scrollOffset > ms {
v.scrollOffset = ms
}
}
func (v *viewport) scrollDown(n int) {
v.scrollOffset -= n
if v.scrollOffset < 0 {
v.scrollOffset = 0
}
}
func (v *viewport) clearAll() {
v.entries = nil
v.streaming = false
v.streamBuf = ""
v.pickerItems = nil
v.pickerActive = false
v.scrollOffset = 0
}
func (v *viewport) clearDisplay() {
v.entries = nil
v.scrollOffset = 0
v.streaming = false
v.streamBuf = ""
}
// pickerTitle returns a title for the current picker kind.
func (v *viewport) pickerTitle() string {
switch v.pickerType {
case pickerAgent:
return "Select Agent"
case pickerSession:
return "Resume Session"
default:
return "Select"
}
}
// renderPicker renders the picker as a bordered overlay.
func (v *viewport) renderPicker(width, height int) string {
title := v.pickerTitle()
// Determine picker dimensions
maxItems := len(v.pickerItems)
panelWidth := width - 4
if panelWidth < 30 {
panelWidth = 30
}
if panelWidth > 70 {
panelWidth = 70
}
innerWidth := panelWidth - 4 // border + padding
// Visible window of items (scroll if too many)
maxVisible := height - 6 // room for border, title, hint
if maxVisible < 3 {
maxVisible = 3
}
if maxVisible > maxItems {
maxVisible = maxItems
}
// Calculate scroll window around current index
startIdx := 0
if v.pickerIndex >= maxVisible {
startIdx = v.pickerIndex - maxVisible + 1
}
endIdx := startIdx + maxVisible
if endIdx > maxItems {
endIdx = maxItems
startIdx = endIdx - maxVisible
if startIdx < 0 {
startIdx = 0
}
}
var itemLines []string
for i := startIdx; i < endIdx; i++ {
item := v.pickerItems[i]
label := item.label
labelRunes := []rune(label)
if len(labelRunes) > innerWidth-4 {
label = string(labelRunes[:innerWidth-7]) + "..."
}
if i == v.pickerIndex {
line := lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("> " + label)
itemLines = append(itemLines, line)
} else {
itemLines = append(itemLines, " "+label)
}
}
hint := lipgloss.NewStyle().Foreground(dimColor).Render("↑↓ navigate • enter select • esc cancel")
body := strings.Join(itemLines, "\n") + "\n\n" + hint
panel := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accentColor).
Padding(1, 2).
Width(panelWidth).
Render(body)
titleRendered := lipgloss.NewStyle().
Foreground(accentColor).
Bold(true).
Render(" " + title + " ")
// Build top border manually to avoid ANSI-corrupted rune slicing.
// panelWidth+2 accounts for the left and right border characters.
borderColor := lipgloss.NewStyle().Foreground(accentColor)
titleWidth := lipgloss.Width(titleRendered)
rightDashes := panelWidth + 2 - 3 - titleWidth // total - "╭─" - "╮" - title
if rightDashes < 0 {
rightDashes = 0
}
topBorder := borderColor.Render("╭─") + titleRendered +
borderColor.Render(strings.Repeat("─", rightDashes)+"╮")
panelLines := strings.Split(panel, "\n")
if len(panelLines) > 0 {
panelLines[0] = topBorder
}
panel = strings.Join(panelLines, "\n")
// Center the panel in the viewport
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, panel)
}
// totalLines computes the total number of rendered content lines.
func (v *viewport) totalLines() int {
var lines []string
for _, e := range v.entries {
if e.kind == entryCitation && !v.showSources {
continue
}
lines = append(lines, e.rendered)
}
if v.streaming && v.streamBuf != "" {
bufLines := strings.Split(v.streamBuf, "\n")
if len(bufLines) > 0 {
bufLines[0] = agentDot + " " + bufLines[0]
for i := 1; i < len(bufLines); i++ {
bufLines[i] = " " + bufLines[i]
}
}
lines = append(lines, strings.Join(bufLines, "\n"))
} else if v.streaming {
lines = append(lines, agentDot+" ")
}
content := strings.Join(lines, "\n")
return len(strings.Split(content, "\n"))
}
// view renders the full viewport content.
func (v *viewport) view(height int) string {
// If picker is active, render it as an overlay
if v.pickerActive && len(v.pickerItems) > 0 {
return v.renderPicker(v.width, height)
}
var lines []string
for _, e := range v.entries {
if e.kind == entryCitation && !v.showSources {
continue
}
lines = append(lines, e.rendered)
}
// Streaming buffer (plain text, not markdown)
if v.streaming && v.streamBuf != "" {
bufLines := strings.Split(v.streamBuf, "\n")
if len(bufLines) > 0 {
bufLines[0] = agentDot + " " + bufLines[0]
for i := 1; i < len(bufLines); i++ {
bufLines[i] = " " + bufLines[i]
}
}
lines = append(lines, strings.Join(bufLines, "\n"))
} else if v.streaming {
lines = append(lines, agentDot+" ")
}
content := strings.Join(lines, "\n")
contentLines := strings.Split(content, "\n")
total := len(contentLines)
v.lastHeight = height
maxScroll := total - height
if maxScroll < 0 {
maxScroll = 0
}
scrollOffset := v.scrollOffset
if scrollOffset > maxScroll {
scrollOffset = maxScroll
}
if total <= height {
// Content fits — pad with empty lines at top to push content down
padding := make([]string, height-total)
for i := range padding {
padding[i] = ""
}
contentLines = append(padding, contentLines...)
} else {
// Show a window: end is (total - scrollOffset), start is (end - height)
end := total - scrollOffset
start := end - height
if start < 0 {
start = 0
}
contentLines = contentLines[start:end]
}
return strings.Join(contentLines, "\n")
}
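The windowing arithmetic at the bottom of `view` (scroll offset measured in lines up from the bottom, clamped to the scrollable range) can be sketched in isolation. `visibleWindow` below is a hypothetical helper for illustration, not part of the package; it mirrors the same clamping steps.

```go
package main

import "fmt"

// visibleWindow mirrors the windowing math in view: scrollOffset counts
// lines up from the bottom, so the window ends at (total - scrollOffset)
// and starts height lines above that, with both bounds clamped.
func visibleWindow(total, height, scrollOffset int) (start, end int) {
	maxScroll := total - height
	if maxScroll < 0 {
		maxScroll = 0
	}
	if scrollOffset > maxScroll {
		scrollOffset = maxScroll
	}
	end = total - scrollOffset
	start = end - height
	if start < 0 {
		start = 0
	}
	return start, end
}

func main() {
	// 30 content lines, 10-line viewport, scrolled up 5:
	// the visible slice is [15, 25), i.e. lines 15..24.
	s, e := visibleWindow(30, 10, 5)
	fmt.Println(s, e)
}
```

When content is shorter than the viewport the window degenerates to the whole content, which is why `view` handles that case separately by top-padding instead of slicing.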


@@ -1,264 +0,0 @@
package tui
import (
"regexp"
"strings"
"testing"
)
// stripANSI removes ANSI escape sequences for test comparisons.
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
func stripANSI(s string) string {
return ansiRegex.ReplaceAllString(s, "")
}
func TestAddUserMessage(t *testing.T) {
v := newViewport(80)
v.addUserMessage("hello world")
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryUser {
t.Errorf("expected entryUser, got %d", e.kind)
}
if e.content != "hello world" {
t.Errorf("expected content 'hello world', got %q", e.content)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "") {
t.Errorf("expected rendered to contain , got %q", plain)
}
if !strings.Contains(plain, "hello world") {
t.Errorf("expected rendered to contain message text, got %q", plain)
}
}
func TestStartAndFinishAgent(t *testing.T) {
v := newViewport(80)
v.startAgent()
if !v.streaming {
t.Error("expected streaming to be true after startAgent")
}
if len(v.entries) != 1 {
t.Fatalf("expected 1 spacer entry, got %d", len(v.entries))
}
if v.entries[0].rendered != "" {
t.Errorf("expected empty spacer, got %q", v.entries[0].rendered)
}
v.appendToken("Hello ")
v.appendToken("world")
if v.streamBuf != "Hello world" {
t.Errorf("expected streamBuf 'Hello world', got %q", v.streamBuf)
}
v.finishAgent()
if v.streaming {
t.Error("expected streaming to be false after finishAgent")
}
if v.streamBuf != "" {
t.Errorf("expected empty streamBuf after finish, got %q", v.streamBuf)
}
if len(v.entries) != 2 {
t.Fatalf("expected 2 entries (spacer + agent), got %d", len(v.entries))
}
e := v.entries[1]
if e.kind != entryAgent {
t.Errorf("expected entryAgent, got %d", e.kind)
}
if e.content != "Hello world" {
t.Errorf("expected content 'Hello world', got %q", e.content)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "Hello world") {
t.Errorf("expected rendered to contain message text, got %q", plain)
}
}
func TestFinishAgentNoPadding(t *testing.T) {
v := newViewport(80)
v.startAgent()
v.appendToken("Test message")
v.finishAgent()
e := v.entries[1]
// First line should not start with plain spaces (ANSI codes are OK)
plain := stripANSI(e.rendered)
lines := strings.Split(plain, "\n")
if strings.HasPrefix(lines[0], " ") {
t.Errorf("first line should not start with spaces, got %q", lines[0])
}
}
func TestFinishAgentMultiline(t *testing.T) {
v := newViewport(80)
v.startAgent()
v.appendToken("Line one\n\nLine three")
v.finishAgent()
e := v.entries[1]
plain := stripANSI(e.rendered)
// Glamour may merge or reformat lines; just check content is present
if !strings.Contains(plain, "Line one") {
t.Errorf("expected 'Line one' in rendered, got %q", plain)
}
if !strings.Contains(plain, "Line three") {
t.Errorf("expected 'Line three' in rendered, got %q", plain)
}
}
func TestFinishAgentEmpty(t *testing.T) {
v := newViewport(80)
v.startAgent()
v.finishAgent()
if v.streaming {
t.Error("expected streaming to be false")
}
if len(v.entries) != 0 {
t.Errorf("expected 0 entries (spacer removed), got %d", len(v.entries))
}
}
func TestAddInfo(t *testing.T) {
v := newViewport(80)
v.addInfo("test info")
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryInfo {
t.Errorf("expected entryInfo, got %d", e.kind)
}
plain := stripANSI(e.rendered)
if strings.HasPrefix(plain, " ") {
t.Errorf("info should not have leading spaces, got %q", plain)
}
}
func TestAddError(t *testing.T) {
v := newViewport(80)
v.addError("something broke")
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryError {
t.Errorf("expected entryError, got %d", e.kind)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "something broke") {
t.Errorf("expected error message in rendered, got %q", plain)
}
}
func TestAddCitations(t *testing.T) {
v := newViewport(80)
v.addCitations(map[int]string{1: "doc-a", 2: "doc-b"})
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryCitation {
t.Errorf("expected entryCitation, got %d", e.kind)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "Sources (2)") {
t.Errorf("expected sources count in rendered, got %q", plain)
}
if strings.HasPrefix(plain, " ") {
t.Errorf("citation should not have leading spaces, got %q", plain)
}
}
func TestAddCitationsEmpty(t *testing.T) {
v := newViewport(80)
v.addCitations(map[int]string{})
if len(v.entries) != 0 {
t.Errorf("expected no entries for empty citations, got %d", len(v.entries))
}
}
func TestCitationVisibility(t *testing.T) {
v := newViewport(80)
v.addInfo("hello")
v.addCitations(map[int]string{1: "doc"})
v.showSources = false
view := v.view(20)
plain := stripANSI(view)
if strings.Contains(plain, "Sources") {
t.Error("expected citations hidden when showSources=false")
}
v.showSources = true
view = v.view(20)
plain = stripANSI(view)
if !strings.Contains(plain, "Sources") {
t.Error("expected citations visible when showSources=true")
}
}
func TestClearAll(t *testing.T) {
v := newViewport(80)
v.addUserMessage("test")
v.startAgent()
v.appendToken("response")
v.clearAll()
if len(v.entries) != 0 {
t.Errorf("expected no entries after clearAll, got %d", len(v.entries))
}
if v.streaming {
t.Error("expected streaming=false after clearAll")
}
if v.streamBuf != "" {
t.Errorf("expected empty streamBuf after clearAll, got %q", v.streamBuf)
}
}
func TestClearDisplay(t *testing.T) {
v := newViewport(80)
v.addUserMessage("test")
v.clearDisplay()
if len(v.entries) != 0 {
t.Errorf("expected no entries after clearDisplay, got %d", len(v.entries))
}
}
func TestViewPadsShortContent(t *testing.T) {
v := newViewport(80)
v.addInfo("hello")
view := v.view(10)
lines := strings.Split(view, "\n")
if len(lines) != 10 {
t.Errorf("expected 10 lines (padded), got %d", len(lines))
}
}
func TestViewTruncatesTallContent(t *testing.T) {
v := newViewport(80)
for i := 0; i < 20; i++ {
v.addInfo("line")
}
view := v.view(5)
lines := strings.Split(view, "\n")
if len(lines) != 5 {
t.Errorf("expected 5 lines (truncated), got %d", len(lines))
}
}
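The `stripANSI` helper above relies on the pattern `\x1b\[[0-9;]*m`, which matches SGR (color/style) escape sequences but not cursor-movement codes. A minimal standalone sketch of that behavior:

```go
package main

import (
	"fmt"
	"regexp"
)

// Same pattern as the test helper: matches SGR sequences such as
// "\x1b[1m" and "\x1b[38;5;201m" (digits and semicolons ending in 'm').
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)

func main() {
	styled := "\x1b[1mbold\x1b[0m plain"
	fmt.Println(ansiRegex.ReplaceAllString(styled, "")) // bold plain
}
```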


@@ -1,29 +0,0 @@
// Package util provides shared utility functions.
package util
import (
"os/exec"
"runtime"
)
// OpenBrowser opens the given URL in the user's default browser.
// Returns true if the browser was launched successfully.
func OpenBrowser(url string) bool {
var cmd *exec.Cmd
switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
case "linux":
cmd = exec.Command("xdg-open", url)
case "windows":
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
}
if cmd != nil {
if err := cmd.Start(); err == nil {
// Reap the child process to avoid zombies.
go func() { _ = cmd.Wait() }()
return true
}
}
return false
}
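The start-then-reap pattern in `OpenBrowser` (fire `Start`, then `Wait` in a goroutine so the child is reaped rather than left as a zombie) generalizes to any fire-and-forget subprocess. A sketch, using a hypothetical `launch` helper rather than the package's own API:

```go
package main

import (
	"fmt"
	"os/exec"
)

// launch starts a command without blocking and reaps it in the
// background, mirroring the pattern used by OpenBrowser.
func launch(name string, args ...string) bool {
	cmd := exec.Command(name, args...)
	if err := cmd.Start(); err != nil {
		return false
	}
	// Reap the child process to avoid zombies.
	go func() { _ = cmd.Wait() }()
	return true
}

func main() {
	// "true" is a no-op command available on Unix-like systems.
	fmt.Println(launch("true"))
}
```

Without the `Wait` goroutine, the exited child would remain in the process table until the parent itself exits.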


@@ -1,13 +0,0 @@
// Package util provides shared utilities for the Onyx CLI.
package util
import "github.com/charmbracelet/lipgloss"
// Shared text styles used across the CLI.
var (
BoldStyle = lipgloss.NewStyle().Bold(true)
DimStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#555577"))
GreenStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#00cc66")).Bold(true)
RedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ff5555")).Bold(true)
YellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
)


@@ -1,23 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/onyx-dot-app/onyx/cli/cmd"
)
var (
version = "dev"
commit = "none"
)
func main() {
cmd.Version = version
cmd.Commit = commit
if err := cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}


@@ -138,6 +138,7 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
- MULTI_TENANT=true
- LOG_LEVEL=DEBUG


@@ -52,6 +52,7 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=${AUTH_TYPE:-oidc}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index


@@ -65,6 +65,7 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=${AUTH_TYPE:-oidc}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index


@@ -70,6 +70,7 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=${AUTH_TYPE:-oidc}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index

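Each of the compose hunks above wires the worker flag through `${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}`, so the container gets `true` unless the variable is set in the invoking environment. A quick sketch of how that `:-` default expansion behaves in the shell:

```shell
# ${VAR:-true} substitutes "true" when VAR is unset or empty.
unset USE_LIGHTWEIGHT_BACKGROUND_WORKER
echo "${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}"

# An explicit value overrides the default, e.g. for one compose run:
#   USE_LIGHTWEIGHT_BACKGROUND_WORKER=false docker compose up -d
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
echo "${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}"
```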
Some files were not shown because too many files have changed in this diff.