Compare commits


23 Commits

Author SHA1 Message Date
rohoswagger
6c6f3cebd8 fix(cli): add JSON type discriminator and URL validation in onboarding
Wrap --json events with {"type": "...", "event": {...}} so consumers
can distinguish event types without inspecting payload fields.

Validate server URL scheme (http/https) during onboarding setup.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 16:28:27 -08:00
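The envelope described in the commit above can be consumed like this. The `{"type": ..., "event": {...}}` keys come from the commit message itself; the inner event payload shown is illustrative, not the CLI's real schema:

```python
import json

def event_type(line: str) -> str:
    # Dispatch one --json output line on its top-level "type" field,
    # without inspecting the wrapped event payload.
    wrapped = json.loads(line)
    return wrapped["type"]

# Example envelope; the inner "content" field is an assumption.
sample = '{"type": "MessageDeltaEvent", "event": {"content": "Hello"}}'
```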
rohoswagger
25538f7b61 feat(cli): add validate-config command, fix errcheck lint violations
- Add `onyx-cli validate-config` to check config and test connection
- Fix all errcheck violations caught by expanded golangci-lint hook
- Change default server URL to https://cloud.onyx.app

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 16:11:30 -08:00
rohoswagger
651f70c8c0 fix(cli): address review feedback — rename to onyx-cli, remove legacy env, fix TUI bugs
- Rename binary from onyx to onyx-cli across root.go, README, SKILL.md, error messages
- Remove DANSWER_API_KEY legacy env var from config, tests, README, SKILL.md
- Change default server URL to https://cloud.onyx.app
- Fix file drop detection: require explicit path prefix (/, ~, ./, ../)
- Fix cmdNew hardcoded viewport height: use m.viewportHeight()
- Fix auto-scroll wiping user scroll position during streaming
- Fix clearDisplay not resetting streaming state
- Fix top border dash count off-by-one in picker
- Fix view() state mutation: move scroll clamping out of render path
- Fix arrow keys falling through to input at picker boundaries
- Cancel in-progress stream when session is resumed
- Improve SKILL.md config check to verify api_key is non-empty

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 15:32:31 -08:00
rohoswagger
5e2072dbef refactor(cli): address jmelahman review feedback
- Move version/commit vars to main.go with exported cmd.Version/Commit
  (matches ODS pattern)
- Replace init() with constructor functions (newChatCmd, newAskCmd, etc.)
  and explicit AddCommand in root.go
- TestConnection uses /api/me instead of /api/chat/get-user-chat-sessions
- /clear now starts a new chat session (not just viewport clear)
- Extract shared lipgloss styles to internal/util/styles.go
- API key onboarding URL points to /app/settings/accounts-access
  (personal access keys, no admin privilege required)
- Marshal error in ask --json returns error instead of continuing
- Update go.mod to go 1.26.0

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 15:00:29 -08:00
rohoswagger
0f92287880 fix(cli): address review feedback — PgUp/PgDown, picker title, ask, go.mod
- Extract viewportHeight() method so PgUp/PgDown scroll distance
  accounts for the dynamic bottom area (menu, file badges)
- Build picker top border manually to avoid ANSI-corrupted rune slicing
- Validate gotStop in JSON mode so incomplete streams fail loudly
- Run go mod tidy to mark golang.org/x/text as direct dependency

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 13:23:12 -08:00
rohoswagger
491c3bf906 fix(cli): avoid leaking API key in SKILL.md config check
Replace `cat ~/.config/onyx-cli/config.json` with a `test -s` check
so coding agents don't print the API key to stdout when verifying
configuration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 13:17:31 -08:00
roshan
1d898e343f Update .cursor/skills/onyx-cli/SKILL.md
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-03 13:16:33 -08:00
rohoswagger
0b129e24ea feat(cli): rename binary to onyx, add agents command and SKILL.md
Rename the CLI binary from `onyx-cli` to `onyx` so the command is
`onyx ask "..."` instead of `onyx-cli ask "..."`. The pip package
name (`onyx-cli`) and config directory (`~/.config/onyx-cli/`) are
unchanged.

Also adds `onyx agents` subcommand for listing available agents and
a SKILL.md for coding agents (Claude Code, Cursor, Codex) to use
the CLI as a tool.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 13:09:58 -08:00
rohoswagger
267ac5ac60 fix(cli): address PR review feedback — security, error handling, and robustness
- Config file permissions: 0o644 → 0o600 (owner-only, protects API key)
- API key input: use term.ReadPassword to hide keystrokes during onboarding
- Config Load(): warn on malformed JSON instead of silent fallback
- Config save failure in onboarding: now fatal instead of continuing
- scanner.Err() checked after stream loop to surface read errors
- Malformed NDJSON returns ErrorEvent instead of silent nil
- ask --json: ErrorEvent now causes non-zero exit code
- ask: channel close without StopEvent treated as unexpected error
- OpenBrowser returns bool so callers report launch failures
- Picker label truncation: rune-based to prevent UTF-8 corruption
- Picker title: replaces border runes instead of inserting (fixes width)
- Empty agent responses: spacer entry cleaned up
- chat.go: remove duplicate error logging (cobra already prints it)
- Fix loop variable shadowing in session resume handler
- prompt(): handle EOF with partial data

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 23:06:00 -08:00
rohoswagger
bf80211eae feat(cli): add stop generation on Escape/Ctrl+D and guard against stale events
- Escape and Ctrl+D during streaming now immediately cancel generation,
  render partial markdown, and show "Generation stopped."
- Ctrl+D during streaming cancels first; requires fresh double-press to quit
- Discard stale StreamEventMsg/StreamDoneMsg after cancellation
- Fix info message ordering so it appears after the agent response

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 20:56:15 -08:00
roshan
135385e57b Apply suggestions from code review
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-02 20:41:43 -08:00
rohoswagger
f06630bc1b refactor(cli): apply Go best practices and remove dead code
- Fix data race: capture chatSessionID value before goroutine launch
- Replace interface{} with any (Go 1.18+ idiom) throughout codebase
- Use typed FileDescriptorPayload instead of map[string]any for type safety
- Share http.Transport across Client for connection pooling
- Return error from TestConnection instead of (bool, string)
- Return errors from cobra RunE instead of calling os.Exit
- Handle json.Marshal error in ask command
- Replace deprecated strings.Title with golang.org/x/text/cases
- Replace deprecated tea.MouseWheelUp/Down with tea.MouseButtonWheelUp/Down
- Extract duplicated openBrowser into internal/util with zombie process fix
- Replace custom itoa with strconv.Itoa
- Use sorted map keys for citations instead of magic +100 bound
- Remove unused withTempConfig, addDimInfo, scrollToBottom
- Remove duplicate RenderSplashOnboarding call in onboarding
- Add context.Background() to newRequest via NewRequestWithContext
- Lowercase error string per Go convention (ST1005)
- Fix var alignment in ask.go
- Update README: --persona-id → --agent-id, /persona → /agent

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 20:35:06 -08:00
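The data-race fix in the first bullet (capture the value before the goroutine starts) guards against late binding. The CLI is Go, but the same hazard is easy to demonstrate with Python closures; a minimal sketch:

```python
# Buggy: every closure reads the loop variable after the loop has finished,
# so all three see the final value.
buggy = [lambda: i for i in range(3)]

# Fixed: bind the current value at definition time, analogous to copying
# the chatSessionID value into a local before launching the goroutine.
fixed = [lambda i=i: i for i in range(3)]
```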
rohoswagger
4495df98cf feat(cli): add picker overlay for agents/sessions, improve message styles
- Agent and session selection now opens a centered bordered overlay
  with arrow navigation instead of inline chat messages
- /agent fetches agents fresh from API each time
- /resume without args opens session picker (same as /sessions)
- Merge /sessions and /resume into one flow, remove duplicate menu entry
- Add tiered message styles: info (visible), warning (yellow), error (red)
- Remove /agent and /resume from argCommands so they execute immediately

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 19:58:37 -08:00
rohoswagger
0124937aa8 refactor(cli): rename persona/assistant to agent throughout codebase
Standardize terminology to "agent" in all Go identifiers, user-facing
strings, slash commands, CLI flags, and comments. API wire format
(endpoints, JSON field names) remains unchanged for compatibility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 19:35:47 -08:00
rohoswagger
aec2d24706 docs(cli): add scroll shortcuts to README
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 19:27:02 -08:00
rohoswagger
16ebb55362 feat(cli): rewrite CLI in Go with Bubble Tea TUI
Port the Onyx CLI from Python/Textual to Go/Bubble Tea for single-binary
distribution and better performance. Includes full feature parity:
config, streaming, markdown rendering, slash commands, session management,
file upload, scrolling viewport, and onboarding flow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 19:22:42 -08:00
rohoswagger
ab6c11319e feat(cli): add splash screen, file drop, /clear command, and UX improvements
- Show Onyx ASCII art splash on empty/new chat screens
- Auto-attach files on drag-and-drop (intercept paste events)
- Add /clear command to clear chat display
- Fix /settings URL to /app/settings/general
- Make Ctrl+D, Escape, Ctrl+O priority bindings (work regardless of focus)
- Update README to use uv instead of pip

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 17:43:22 -08:00
rohoswagger
05f5b96964 feat(cli): add slash command menu, session naming, interactive session picker, inline assistant prefix, and citation toggle
- Add slash command dropdown that appears when typing / with filtering, arrow key navigation, Tab/Enter selection
- Auto-name new chat sessions via backend LLM rename API after first message exchange
- Replace static /sessions list with interactive OptionList picker (arrow keys + Enter to resume, Esc to cancel)
- Change AssistantMessage to Horizontal layout with inline dot prefix instead of separate line
- Hide citation sources by default, add Ctrl+O keybinding to toggle visibility
- Fix InputArea/StatusBar overlap by removing dock:bottom from InputArea
- Simplify UserMessage to use dimmed prefix style instead of bordered panel
- Fix pre-existing ruff F541 warnings (extraneous f-string prefixes)
- Update tests for new AssistantMessage container compose cycle and CitationBlock widget

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 16:41:18 -08:00
rohoswagger
f525aa175b feat(cli): move connection info to status bar, add bottom divider
Move "Connected to <url> · Assistant" from chat display to the status
bar below the input. Add a second divider line below the input row
for visual framing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 00:08:41 +00:00
rohoswagger
4ba6e5f735 fix(cli): fix input focus blocked by slow API init, add UI tests
The on_mount handler was awaiting list_personas() which blocked the
entire Textual event loop — focus never reached ChatInput until the
HTTP call completed. Now focus is set immediately and API init runs
via run_worker() in the background.

Adds 28 Textual pilot tests covering focus, typing, message submission,
chat display, streaming, status bar, and file badges.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 22:42:43 +00:00
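The blocking-await problem this commit describes is general to single-threaded event loops. Textual's `run_worker()` is the actual fix used; this plain-asyncio sketch only illustrates the difference between awaiting inline and scheduling in the background:

```python
import asyncio

async def slow_api_init(events: list[str]) -> None:
    await asyncio.sleep(0.01)   # stands in for the list_personas() HTTP call
    events.append("api ready")

async def on_mount(events: list[str]) -> None:
    events.append("focus set")  # focus no longer waits on the network
    # Schedule the init in the background instead of awaiting it inline,
    # analogous to Textual's run_worker().
    asyncio.create_task(slow_api_init(events))

async def main() -> list[str]:
    events: list[str] = []
    await on_mount(events)
    await asyncio.sleep(0.05)   # give the background task time to finish
    return events
```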
rohoswagger
992ad3b8d4 refactor(cli): redesign TUI for minimal Claude Code-style aesthetic
Replace RichLog-based chat display with VerticalScroll + individual
message widgets (UserMessage, AssistantMessage, StatusMessage,
ErrorMessage) to eliminate full-history replay on each streaming token.
Remove Header widget, add prompt prefix and separator, simplify status
bar to assistant name + contextual hint.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 22:24:57 +00:00
rohoswagger
a6404f8b3e chore(cli): add .gitignore, remove build artifacts from tracking
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 21:59:54 +00:00
rohoswagger
efc49c9f6b feat(cli): add onyx-cli terminal chat interface
Self-contained Python package (pip install onyx-cli) providing a TUI for
chatting with Onyx from the terminal. Communicates purely over HTTP with
zero imports from the backend.

Includes:
- Textual-based chat TUI with streaming markdown responses
- Rich terminal onboarding flow (server URL, API key, connection test)
- NDJSON stream parser for all backend packet types
- Slash commands (/help, /new, /persona, /attach, /sessions, /resume, etc.)
- File upload support
- One-shot mode (onyx-cli ask "question")
- Dual auth headers (Authorization + X-Onyx-Authorization) for proxy compat
- Smart connection diagnostics (detects AWS ALB blocks, proxy issues)
- Unit tests for stream parser (29 tests) and config (14 tests)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 21:59:37 +00:00
163 changed files with 6272 additions and 4053 deletions



@@ -0,0 +1,153 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---
# Onyx CLI — Agent Tool
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
## Prerequisites
### 1. Check if installed
```bash
which onyx-cli
```
### 2. Install (if needed)
**Primary — pip:**
```bash
pip install onyx-cli
```
**From source (Go):**
```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```
### 3. Check if configured
The CLI is configured when `~/.config/onyx-cli/config.json` exists and contains an `api_key`. Check with:
```bash
test -s ~/.config/onyx-cli/config.json && echo "configured" || echo "not configured"
```
If unconfigured, you have two options:
**Option A — Interactive setup (requires user input):**
```bash
onyx-cli configure
```
This prompts for the Onyx server URL and API key, tests the connection, and saves the config.
**Option B — Environment variables (non-interactive, preferred for agents):**
```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: http://localhost:3000
export ONYX_API_KEY="your-api-key"
```
Environment variables override the config file. If these are set, no config file is needed.
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `http://localhost:3000`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
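The precedence above (environment variable, then config file, then documented default) can be sketched as follows; the `server_url` JSON key is an assumption about the config file's layout:

```python
import json
import os
from pathlib import Path

def resolve_server_url(config_path: Path) -> str:
    # 1. Environment variable wins.
    if url := os.environ.get("ONYX_SERVER_URL"):
        return url
    # 2. Fall back to the config file ("server_url" key assumed).
    if config_path.is_file():
        url = json.loads(config_path.read_text()).get("server_url")
        if url:
            return url
    # 3. Documented default.
    return "http://localhost:3000"
```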
## Commands
### List available agents
```bash
onyx-cli agents
```
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
```bash
onyx-cli agents --json
```
Use agent IDs with `ask --agent-id` to query a specific agent.
### Basic query (plain text output)
```bash
onyx-cli ask "What is our company's PTO policy?"
```
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
### JSON output (structured events)
```bash
onyx-cli ask --json "What authentication methods do we support?"
```
Outputs parsed stream events as JSON, one object per line. Key event objects include message deltas, stop, errors, search-start, and citation payloads.
| Event Type | Description |
|------------|-------------|
| `MessageDeltaEvent` | Content token — concatenate all `content` fields for the full answer |
| `StopEvent` | Stream complete |
| `ErrorEvent` | Error with `error` message field |
| `SearchStartEvent` | Onyx started searching documents |
| `CitationEvent` | Source citation with `citation_number` and `document_id` |
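A consumer can fold these events into a final answer roughly like this; the exact field layout is an assumption for illustration:

```python
import json

def collect_answer(ndjson_lines: list[str]) -> str:
    # Concatenate MessageDeltaEvent content, stop on StopEvent, and fail
    # loudly on ErrorEvent, per the event table above.
    parts: list[str] = []
    for line in ndjson_lines:
        event = json.loads(line)
        kind = event.get("type")
        if kind == "MessageDeltaEvent":
            parts.append(event.get("content", ""))
        elif kind == "ErrorEvent":
            raise RuntimeError(event.get("error", "stream error"))
        elif kind == "StopEvent":
            break
    return "".join(parts)
```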
### Specify an agent
```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```
Uses a specific Onyx agent/persona instead of the default.
### All flags
| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## When to Use
Use `onyx-cli ask` when:
- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
Do NOT use when:
- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data
## Examples
```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"
# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"
# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```


@@ -114,10 +114,8 @@ jobs:
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:


@@ -122,7 +122,7 @@ repos:
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
entry: bash -c "find . -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.


@@ -20,7 +20,6 @@ from ee.onyx.server.enterprise_settings.store import (
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -118,38 +117,15 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
if not llm_upsert_requests:
return
logger.notice("Seeding LLMs")
for request in llm_upsert_requests:
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None
)
if not default_provider:
return
visible_configs = [
mc for mc in default_provider.model_configurations if mc.is_visible
]
default_config = (
visible_configs[0]
if visible_configs
else default_provider.model_configurations[0]
)
update_default_provider(
provider_id=default_provider.id,
model_name=default_config.name,
db_session=db_session,
)
if llm_upsert_requests:
logger.notice("Seeding LLMs")
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:


@@ -109,12 +109,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
settings.ee_features_enabled = False
elif metadata.used_seats > metadata.seats:
# License is valid but seat limit exceeded
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
settings.seat_count = metadata.seats
settings.used_seats = metadata.used_seats
settings.ee_features_enabled = True
else:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
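The branch ordering this hunk introduces (blocking status first, then seat overage, then valid license) can be sketched standalone. The status strings and return shape below are simplified stand-ins for the real `Settings`/`ApplicationStatus` types:

```python
from dataclasses import dataclass

_BLOCKING_STATUS = "GATED_ACCESS"  # stand-in for the real constant

@dataclass
class Metadata:
    status: str
    seats: int
    used_seats: int

def apply_license_status(metadata: Metadata) -> tuple[str, bool]:
    # Returns (application_status, ee_features_enabled), mirroring the hunk:
    # a blocking status disables EE features; a seat overage flags the
    # status but keeps EE features on; otherwise the license is valid.
    if metadata.status == _BLOCKING_STATUS:
        return metadata.status, False
    if metadata.used_seats > metadata.seats:
        return "SEAT_LIMIT_EXCEEDED", True
    return metadata.status, True
```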


@@ -33,7 +33,6 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -303,17 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
def _upsert(request: LLMProviderUpsertRequest) -> None:
nonlocal has_set_default_provider
try:
existing = fetch_existing_llm_provider(
name=request.name, db_session=db_session
)
if existing:
request.id = existing.id
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, default_model, db_session)
update_default_provider(provider.id, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -331,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider, default_model_name)
_upsert(openai_provider)
# Create default image generation config using the OpenAI API key
try:
@@ -366,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider, default_model_name)
_upsert(anthropic_provider)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -397,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider, default_model_name)
_upsert(vertexai_provider)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -435,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider, default_model_name)
_upsert(openrouter_provider)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"


@@ -202,6 +202,7 @@ def create_default_image_gen_config_from_api_key(
api_key=api_key,
api_base=None,
api_version=None,
default_model_name=model_name,
deployment_name=None,
is_public=True,
)


@@ -213,29 +213,11 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider: LLMProviderModel | None = None
if llm_provider_upsert_request.id:
existing_llm_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
if not existing_llm_provider:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} not found"
)
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_llm_provider.name != llm_provider_upsert_request.name:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
)
else:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_llm_provider:
raise ValueError(
f"LLM provider with name '{llm_provider_upsert_request.name}'"
" already exists"
)
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
db_session.add(existing_llm_provider)
@@ -256,7 +238,11 @@ def upsert_llm_provider(
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -320,6 +306,15 @@ def upsert_llm_provider(
display_name=model_config.display_name,
)
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
_update_default_model(
db_session=db_session,
provider_id=existing_llm_provider.id,
model=existing_llm_provider.default_model_name,
flow_type=LLMModelFlowType.CHAT,
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
@@ -493,22 +488,6 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -625,13 +604,22 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
def update_default_provider(provider_id: int, db_session: Session) -> None:
# Attempt to get the default_model_name from the provider first
# TODO: Remove default_model_name check
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.id == provider_id,
)
)
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
_update_default_model(
db_session,
provider_id,
model_name,
provider.default_model_name, # type: ignore[arg-type]
LLMModelFlowType.CHAT,
)
@@ -817,6 +805,12 @@ def sync_auto_mode_models(
)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes


@@ -97,6 +97,7 @@ def _build_llm_provider_request(
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -135,6 +136,7 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -166,6 +168,7 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,


@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -55,12 +52,11 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -237,9 +233,12 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -269,7 +268,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.model,
model=test_llm_request.default_model_name,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -304,7 +303,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderView]:
) -> list[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -329,15 +328,7 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
return llm_provider_list
@admin_router.put("/provider")
@@ -353,44 +344,18 @@ def put_llm_provider(
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = None
if llm_provider_upsert_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
# Check name constraints
# TODO: Once the port from name to id is complete, a unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
existing_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
)
elif not existing_provider and not is_creation:
raise HTTPException(
status_code=400,
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -428,6 +393,22 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
# TODO: Remove this logic on api change
# Believed to be a dead pathway but we want to be safe for now
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -457,8 +438,8 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
updated_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if updated_provider:
sync_auto_mode_models(
@@ -488,29 +469,28 @@ def delete_llm_provider(
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/default")
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
default_model_request: DefaultModel,
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
update_default_provider(provider_id=provider_id, db_session=db_session)
@admin_router.post("/default-vision")
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
default_model: DefaultModel,
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if vision_model is None:
raise HTTPException(status_code=404, detail="Vision model not provided")
update_default_vision_provider(
provider_id=default_model.provider_id,
vision_model=default_model.model_name,
db_session=db_session,
provider_id=provider_id, vision_model=vision_model, db_session=db_session
)
@@ -536,7 +516,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[VisionProviderResponse]:
) -> list[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -565,13 +545,7 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
return vision_provider_response
"""Endpoints for all"""
@@ -581,7 +555,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderDescriptor]:
) -> list[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -618,15 +592,7 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
return accessible_providers
def get_valid_model_names_for_persona(
@@ -669,7 +635,7 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderDescriptor]:
) -> list[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
@@ -716,51 +682,7 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
# Get the default model and vision model for the persona
# TODO: Port persona's over to use ID
persona_default_provider = persona.llm_model_provider_override
persona_default_model = persona.llm_model_version_override
default_text_model = fetch_default_llm_model(db_session)
default_vision_model = fetch_default_vision_model(db_session)
# Build default_text and default_vision using persona overrides when available,
# falling back to the global defaults.
default_text = DefaultModel.from_model_config(default_text_model)
default_vision = DefaultModel.from_model_config(default_vision_model)
if persona_default_provider:
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
if provider and can_user_access_llm_provider(
provider, user_group_ids, persona, is_admin=is_admin
):
if persona_default_model:
# Persona specifies both provider and model — use them directly
default_text = DefaultModel(
provider_id=provider.id,
model_name=persona_default_model,
)
else:
# Persona specifies only the provider — pick a visible (public) model,
# falling back to any model on this provider
visible_model = next(
(mc for mc in provider.model_configurations if mc.is_visible),
None,
)
fallback_model = visible_model or next(
iter(provider.model_configurations), None
)
if fallback_model:
default_text = DefaultModel(
provider_id=provider.id,
model_name=fallback_model.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=default_text,
default_vision=default_vision,
)
return llm_provider_list
@admin_router.get("/provider-contextual-cost")
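The diff above replaces the body-based `/default` and `/default-vision` routes with path-parameter routes keyed on the provider id. A minimal client-side sketch of the new URL shapes, assuming a `/api/admin/llm` mount point (the actual prefix is not shown in this diff, and these helper names are hypothetical):

```python
from urllib.parse import urlencode

# Assumed mount point for the admin router; not confirmed by the diff.
ADMIN_BASE = "/api/admin/llm"

def default_provider_url(provider_id: int) -> str:
    # The default model is no longer sent in the request body; only the
    # provider id appears, as a path segment.
    return f"{ADMIN_BASE}/provider/{provider_id}/default"

def default_vision_url(provider_id: int, vision_model: str) -> str:
    # The vision model moved from the request body to a query parameter.
    query = urlencode({"vision_model": vision_model})
    return f"{ADMIN_BASE}/provider/{provider_id}/default-vision?{query}"
```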


@@ -1,9 +1,5 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -25,22 +21,50 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
# TODO: Clear this up on api refactor
# There is still logic that requires sending each provider's default model name
# There is no logic that requires sending the provider's default vision model name
# We only send it for the one that is actually the default
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
"""Find the default conversation model name for a provider.
Returns the model name if found, otherwise returns empty string.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
return model_config.name
return ""
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
"""Find the default vision model name for a provider.
Returns the model name if found, otherwise returns None.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
return model_config.name
return None
class TestLLMRequest(BaseModel):
# provider level
id: int | None = None
name: str | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# whether to reuse the existing stored API key / custom config
api_key_changed: bool
custom_config_changed: bool
@@ -56,10 +80,13 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -72,12 +99,24 @@ class LLMProviderDescriptor(BaseModel):
)
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -91,17 +130,18 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
id: int | None = None
api_key_changed: bool = False
custom_config_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@@ -117,6 +157,8 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -138,6 +180,16 @@ class LLMProviderView(LLMProvider):
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -150,6 +202,10 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -369,38 +425,3 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)
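With the `LLMProviderResponse` generic wrapper deleted, the list endpoints now return a bare JSON array instead of `{"providers": [...], "default_text": ..., "default_vision": ...}`. A sketch of a tolerant client-side shim that accepts both shapes during a migration window (the shim itself is an illustration, not part of the diff):

```python
def extract_providers(payload):
    # Old shape: wrapper dict with a "providers" key plus default models.
    # New shape: the provider list itself.
    if isinstance(payload, dict):
        return payload.get("providers", [])
    return payload

old_shape = {"providers": [{"name": "OpenAI"}], "default_text": None}
new_shape = [{"name": "OpenAI"}]
```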


@@ -19,7 +19,6 @@ class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GRACE_PERIOD = "grace_period"
GATED_ACCESS = "gated_access"
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
class Notification(BaseModel):
@@ -83,10 +82,6 @@ class Settings(BaseModel):
# Default Assistant settings
disable_default_assistant: bool | None = False
# Seat usage - populated by license enforcement when seat limit is exceeded
seat_count: int | None = None
used_seats: int | None = None
# OpenSearch migration
opensearch_indexing_enabled: bool = False


@@ -25,7 +25,6 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -255,18 +254,14 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
provider_name = "DevEnvPresetOpenAI"
existing = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
model_req = LLMProviderUpsertRequest(
id=existing.id if existing else None,
name=provider_name,
name="DevEnvPresetOpenAI",
provider=LlmProviderNames.OPENAI,
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -278,9 +273,7 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
def update_default_multipass_indexing(db_session: Session) -> None:


@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"model": _DEFAULT_BEDROCK_MODEL,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"model": _DEFAULT_BEDROCK_MODEL,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,


@@ -28,6 +28,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
@@ -40,7 +41,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
update_default_provider(provider.id, db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()


@@ -47,6 +47,7 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -58,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
update_default_provider(anthropic_provider.id, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(


@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -45,14 +44,15 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> LLMProviderView:
) -> None:
"""Helper to create a test LLM provider in the database."""
return upsert_llm_provider(
upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=api_key,
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -102,11 +102,17 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -146,11 +152,17 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -182,9 +194,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
_create_test_provider(db_session, provider_name, api_key=original_api_key)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -192,12 +202,17 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name, # Existing provider
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -231,9 +246,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
_create_test_provider(db_session, provider_name, api_key=original_api_key)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -241,12 +254,17 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name, # Existing provider
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -279,7 +297,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
provider = upsert_llm_provider(
upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -287,6 +305,7 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -302,13 +321,18 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -344,11 +368,17 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model=model_name,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -412,6 +442,7 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -421,7 +452,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
update_default_provider(provider_1.id, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -441,6 +472,7 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -467,11 +499,11 @@ class TestDefaultProviderEndpoint:
# Step 5: Update provider 1's default model
upsert_llm_provider(
LLMProviderUpsertRequest(
id=provider_1.id,
name=provider_1_name,
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -480,9 +512,6 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -495,7 +524,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, provider_2_default_model, db_session)
update_default_provider(provider_2.id, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -567,6 +596,7 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -575,7 +605,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
update_default_provider(provider.id, db_session)
# Test should fail
with patch(

View File

@@ -49,6 +49,7 @@ def _create_test_provider(
api_key_changed=True,
api_base=api_base,
custom_config=custom_config,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -90,14 +91,14 @@ class TestLLMProviderChanges:
the API key should be blocked.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://attacker.example.com",
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -124,16 +125,16 @@ class TestLLMProviderChanges:
Changing api_base IS allowed when the API key is also being changed.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom-endpoint.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -158,16 +159,14 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
_create_test_provider(db_session, provider_name, api_base=original_api_base)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=original_api_base,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -191,14 +190,14 @@ class TestLLMProviderChanges:
changes. This allows model-only updates when provider has no custom base URL.
"""
try:
view = _create_test_provider(db_session, provider_name, api_base=None)
_create_test_provider(db_session, provider_name, api_base=None)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=view.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -224,16 +223,14 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
_create_test_provider(db_session, provider_name, api_base=original_api_base)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=None,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -262,14 +259,14 @@ class TestLLMProviderChanges:
users have full control over their deployment.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -300,6 +297,7 @@ class TestLLMProviderChanges:
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -324,7 +322,7 @@ class TestLLMProviderChanges:
redirect LLM API requests).
"""
try:
provider = _create_test_provider(
_create_test_provider(
db_session,
provider_name,
custom_config={"SOME_CONFIG": "original_value"},
@@ -332,11 +330,11 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -364,15 +362,15 @@ class TestLLMProviderChanges:
without changing the API key.
"""
try:
provider = _create_test_provider(db_session, provider_name)
_create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -401,7 +399,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "us-west-2"}
try:
provider = _create_test_provider(
_create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -409,13 +407,13 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=True,
custom_config=new_config,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -440,17 +438,17 @@ class TestLLMProviderChanges:
original_config = {"AWS_REGION_NAME": "us-east-1"}
try:
provider = _create_test_provider(
_create_test_provider(
db_session, provider_name, custom_config=original_config
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=original_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -476,7 +474,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "eu-west-1"}
try:
provider = _create_test_provider(
_create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -484,10 +482,10 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=new_config,
default_model_name="gpt-4o-mini",
custom_config_changed=True,
)
@@ -532,8 +530,14 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -542,10 +546,11 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
provider = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
custom_config=custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -564,9 +569,14 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
id=provider.id,
name=name,
provider=provider_name,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=False,
),
@@ -576,9 +586,9 @@ def test_upload_with_custom_config_then_change(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -606,13 +616,13 @@ def test_upload_with_custom_config_then_change(
)
# Check the database to confirm custom_config matches the original
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
assert False, "Provider not found in the database"
assert db_provider.custom_config == custom_config, (
assert provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {db_provider.custom_config}"
f"but got {provider.custom_config}"
)
finally:
db_session.rollback()
@@ -632,10 +642,11 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
}
try:
view = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -654,9 +665,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=view.id,
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config={
"vertex_credentials": _mask_string(
original_custom_config["vertex_credentials"]
@@ -695,7 +706,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider_name = LlmProviderNames.VERTEX_AI.value
provider = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -708,10 +719,11 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
provider = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -730,9 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
id=provider.id,
provider=provider_name,
model=default_model_name,
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -15,11 +15,9 @@ import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
@@ -137,6 +135,7 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -164,8 +163,13 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, expected_default_model, db_session)
update_default_provider(provider.id, db_session)
# Step 5: Fetch the default provider and verify
default_model = fetch_default_llm_model(db_session)
@@ -234,6 +238,7 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -312,6 +317,7 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -320,13 +326,13 @@ class TestAutoModeSyncFeature:
)
# Verify initial state: all models are visible
initial_provider = fetch_existing_llm_provider(
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert initial_provider is not None
assert initial_provider.is_auto_mode is False
assert provider is not None
assert provider.is_auto_mode is False
for mc in initial_provider.model_configurations:
for mc in provider.model_configurations:
assert (
mc.is_visible is True
), f"Initial model '{mc.name}' should be visible"
@@ -338,12 +344,12 @@ class TestAutoModeSyncFeature:
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=initial_provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -354,15 +360,15 @@ class TestAutoModeSyncFeature:
# Step 3: Verify model visibility after auto mode transition
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
db_session.expire_all()
provider_view = fetch_llm_provider_view(
provider_name=provider_name, db_session=db_session
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider_view is not None
assert provider_view.is_auto_mode is True
assert provider is not None
assert provider.is_auto_mode is True
# Build a map of model name -> visibility
model_visibility = {
mc.name: mc.is_visible for mc in provider_view.model_configurations
mc.name: mc.is_visible for mc in provider.model_configurations
}
# Models in auto mode config should be visible
@@ -382,6 +388,9 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -423,12 +432,8 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
)
],
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
@@ -530,6 +535,7 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -543,7 +549,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, provider_1_default_model, db_session)
update_default_provider(provider_1.id, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -557,6 +563,7 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -577,7 +584,7 @@ class TestAutoModeSyncFeature:
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, provider_2_default_model, db_session)
update_default_provider(provider_2.id, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()
@@ -637,6 +644,7 @@ class TestAutoModeMissingFlows:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -693,364 +701,3 @@ class TestAutoModeMissingFlows:
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
class TestAutoModeTransitionsAndResync:
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Disabling auto mode should preserve the current model list and
prevent future syncs from altering visibility.
Steps:
1. Create provider in auto mode — models synced from config.
2. Update provider to manual mode (is_auto_mode=False).
3. Verify all models remain with unchanged visibility.
4. Call sync_auto_mode_models with a *different* config.
5. Verify fetch_auto_mode_providers excludes this provider, so the
periodic task would never call sync on it.
"""
initial_config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create in auto mode
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=initial_config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility_before = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
# Step 2: Switch to manual mode
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
is_auto_mode=False,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
),
],
),
is_creation=False,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 3: Models unchanged
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
visibility_after = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_after == visibility_before
# Step 4-5: Provider excluded from auto mode queries
auto_providers = fetch_auto_mode_providers(db_session)
auto_provider_ids = {p.id for p in auto_providers}
assert provider.id not in auto_provider_ids
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_resync_adds_new_and_hides_removed_models(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the GitHub config changes between syncs, a subsequent sync
should add newly listed models and hide models that were removed.
Steps:
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
gpt-4-turbo added).
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
and visible.
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4-turbo"],
)
try:
# Step 1: Create with config v1
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Re-sync with config v2
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 3: Verify
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
# gpt-4o: still in config -> visible
assert visibility["gpt-4o"] is True
# gpt-4o-mini: removed from config -> hidden (not deleted)
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
assert visibility["gpt-4o-mini"] is False
# gpt-4-turbo: newly added -> visible
assert visibility["gpt-4-turbo"] is True
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_sync_is_idempotent(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Running sync twice with the same config should produce zero
changes on the second call."""
config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
# First explicit sync (may report changes if creation already synced)
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
# Snapshot state after first sync
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
snapshot = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
# Second sync — should be a no-op
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
assert (
changes == 0
), f"Expected 0 changes on idempotent re-sync, got {changes}"
# State should be identical
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
current = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
assert current == snapshot
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_default_model_hidden_when_removed_from_config(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the current default model is removed from the config, sync
should hide it. The default model flow row should still exist (it
points at the ModelConfiguration), but the model is no longer visible.
Steps:
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
2. Set gpt-4o as the global default.
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
fetch_default_llm_model still returns a result (the flow row persists).
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o-mini",
additional_models=[],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Set gpt-4o as global default
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
update_default_provider(provider.id, "gpt-4o", db_session)
default_before = fetch_default_llm_model(db_session)
assert default_before is not None
assert default_before.name == "gpt-4o"
# Step 3: Re-sync with config v2 (gpt-4o removed)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 4: Verify visibility
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
# but the model is hidden. fetch_default_llm_model filters on
# is_visible=True, so it should NOT return gpt-4o.
db_session.expire_all()
default_after = fetch_default_llm_model(db_session)
assert (
default_after is None or default_after.name != "gpt-4o"
), "Hidden model should not be returned as the default"
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
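The class above exercises three properties of the auto-mode sync: newly listed models become visible, removed models are hidden rather than deleted, and re-running with the same config is a no-op. A self-contained sketch of that reconciliation over a plain name-to-visibility map; the real `sync_auto_mode_models` operates on ORM rows, so this is illustrative only:

```python
def sync_models(existing: dict[str, bool], config_models: set[str]) -> int:
    """Reconcile model visibility against the recommended config.

    `existing` maps model name -> is_visible and is mutated in place.
    Returns the number of changes, so an idempotent re-sync returns 0.
    """
    changes = 0
    for name in config_models:
        if existing.get(name) is not True:  # add new models, or re-show hidden ones
            existing[name] = True
            changes += 1
    for name in existing:
        if name not in config_models and existing[name]:  # hide removed models
            existing[name] = False
            changes += 1
    return changes
```

Hiding instead of deleting preserves rows that other tables (such as the default-model flow) still point at, which is why the tests assert a removed default still exists but is filtered out of visible-model lookups.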

View File

@@ -64,6 +64,7 @@ def _create_provider(
name=name,
provider=provider,
api_key="sk-ant-api03-...",
default_model_name="claude-3-5-sonnet-20240620",
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -153,9 +154,7 @@ def test_user_sends_message_to_private_provider(
)
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
update_default_provider(
public_provider_id, "claude-3-5-sonnet-20240620", db_session
)
update_default_provider(public_provider_id, db_session)
try:
# Create chat session

View File

@@ -42,6 +42,7 @@ def _create_llm_provider_and_model(
name=provider_name,
provider="openai",
api_key="test-api-key",
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,

View File

@@ -434,6 +434,7 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -447,7 +448,7 @@ class TestSlackBotFederatedSearch:
db_session=db_session,
)
update_default_provider(provider_view.id, "gpt-4o", db_session)
update_default_provider(provider_view.id, db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""

View File

@@ -4,12 +4,10 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -34,6 +32,7 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -66,7 +65,7 @@ class LLMProviderManager:
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=default_model_name or "gpt-4o-mini",
default_model_name=response_data["default_model_name"],
is_public=response_data["is_public"],
is_auto_mode=response_data.get("is_auto_mode", False),
groups=response_data["groups"],
@@ -76,19 +75,9 @@ class LLMProviderManager:
)
if set_as_default:
if default_model_name is None:
default_model_name = "gpt-4o-mini"
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/default",
json={
"provider_id": response_data["id"],
"model_name": default_model_name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
)
set_default_response.raise_for_status()
@@ -115,7 +104,7 @@ class LLMProviderManager:
headers=user_performing_action.headers,
)
response.raise_for_status()
return [LLMProviderView(**p) for p in response.json()["providers"]]
return [LLMProviderView(**ug) for ug in response.json()]
@staticmethod
def verify(
@@ -124,11 +113,7 @@ class LLMProviderManager:
verify_deleted: bool = False,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -141,30 +126,11 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and (
default_model is None or default_model.model_name in model_names
)
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel | None:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
default_text = response.json().get("default_text")
if default_text is None:
return None
return DefaultModel(**default_text)

View File

@@ -128,7 +128,7 @@ class DATestLLMProvider(BaseModel):
name: str
provider: str
api_key: str
default_model_name: str | None = None
default_model_name: str
is_public: bool
is_auto_mode: bool = False
groups: list[int]

View File

@@ -42,10 +42,12 @@ def _create_provider_with_api(
llm_provider_data = {
"name": name,
"provider": provider_type,
"default_model_name": default_model,
"api_key": "test-api-key-for-auto-mode-testing",
"api_base": None,
"api_version": None,
"custom_config": None,
"fast_default_model_name": default_model,
"is_public": True,
"is_auto_mode": is_auto_mode,
"groups": [],
@@ -70,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json()["providers"]:
for provider in response.json():
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")
@@ -217,6 +219,15 @@ def test_auto_mode_provider_gets_synced_from_github_config(
"is_visible"
], "Outdated model should not be visible after sync"
# Verify default model was set from GitHub config
expected_default = (
default_model["name"] if isinstance(default_model, dict) else default_model
)
assert synced_provider["default_model_name"] == expected_default, (
f"Default model should be {expected_default}, "
f"got {synced_provider['default_model_name']}"
)
def test_manual_mode_provider_not_affected_by_auto_sync(
reset: None, # noqa: ARG001
@@ -262,3 +273,7 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
f"Manual mode provider models should not change. "
f"Initial: {initial_models}, Current: {current_models}"
)
assert (
updated_provider["default_model_name"] == custom_model
), f"Manual mode default model should remain {custom_model}"

View File

@@ -6,21 +6,20 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_user_group_ids
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import LLMModelFlow
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_llm_for_persona
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.persona import PersonaManager
@@ -42,30 +41,24 @@ def _create_llm_provider(
is_public: bool,
is_default: bool,
) -> LLMProviderModel:
_provider = upsert_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name,
is_visible=True,
)
],
),
db_session=db_session,
provider = LLMProviderModel(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=default_model_name,
deployment_name=None,
is_public=is_public,
            # Use None instead of False to avoid a unique constraint violation:
            # is_default_provider has unique=True, so at most one True (and one
            # False) row may exist, while NULL values never collide.
is_default_provider=is_default if is_default else None,
is_default_vision_provider=False,
default_vision_model=None,
)
if is_default:
update_default_provider(_provider.id, default_model_name, db_session)
provider = db_session.get(LLMProviderModel, _provider.id)
if not provider:
raise ValueError(f"Provider {name} not found")
db_session.add(provider)
db_session.flush()
return provider
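The `is_default_provider` comment above relies on the SQL rule that NULLs do not collide under a unique constraint, so a nullable boolean column with `unique=True` can mark at most one default row while every other row stores NULL. A minimal SQLAlchemy sketch of that pattern; the table and column names are illustrative:

```python
from sqlalchemy import Boolean, Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Provider(Base):
    __tablename__ = "provider"
    id = Column(Integer, primary_key=True)
    # None for non-default rows; True for the single default row. SQL treats
    # NULLs as distinct, so many NULL rows coexist under unique=True.
    is_default = Column(Boolean, unique=True, nullable=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all(
        [Provider(is_default=None), Provider(is_default=None), Provider(is_default=True)]
    )
    session.commit()  # succeeds: one True row, NULLs do not conflict
```

Inserting a second `True` row would violate the constraint, which is exactly the invariant the helper preserves by writing `None` instead of `False`.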
@@ -277,6 +270,24 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
provider_name=restricted_provider.name,
)
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
is_visible=True,
)
db_session.add(default_model_config)
db_session.flush()
db_session.add(
LLMModelFlow(
model_configuration_id=default_model_config.id,
llm_model_flow_type=LLMModelFlowType.CHAT,
is_default=True,
)
)
db_session.flush()
access_group = UserGroup(name="persona-group")
db_session.add(access_group)
db_session.flush()
@@ -310,19 +321,13 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
persona=persona,
user=admin_model,
)
assert (
allowed_llm.config.model_name
== restricted_provider.model_configurations[0].name
)
assert allowed_llm.config.model_name == restricted_provider.default_model_name
fallback_llm = get_llm_for_persona(
persona=persona,
user=basic_model,
)
assert (
fallback_llm.config.model_name
== default_provider.model_configurations[0].name
)
assert fallback_llm.config.model_name == default_provider.default_model_name
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -341,7 +346,6 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
name="public-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)
@@ -361,7 +365,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -376,7 +380,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()["providers"]
admin_providers = admin_response.json()
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names
@@ -392,7 +396,6 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
name="default-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)


@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
# Should succeed
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
# Should include the restricted provider since basic_user can access the persona
provider_names = [p["name"] for p in providers]
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
# Should succeed (persona_id=0 refers to default persona, which is public)
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
# Should NOT include the restricted provider since basic_user is not in group2
provider_names = [p["name"] for p in providers]
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
# Should succeed - admins can access any persona
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
# Should include the restricted provider
provider_names = [p["name"] for p in providers]
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
# Should succeed
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
# Should return the public provider
assert len(providers) > 0


@@ -9,19 +9,6 @@ from redis.exceptions import RedisError
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
# Fields we assert on across all tests
_ASSERT_FIELDS = {
"application_status",
"ee_features_enabled",
"seat_count",
"used_seats",
}
def _pick(settings: Settings) -> dict:
"""Extract only the fields under test from a Settings object."""
return settings.model_dump(include=_ASSERT_FIELDS)
@pytest.fixture
def base_settings() -> Settings:
@@ -40,17 +27,17 @@ class TestApplyLicenseStatusToSettings:
def test_enforcement_disabled_enables_ee_features(
self, base_settings: Settings
) -> None:
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled."""
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled.
If we're running the EE apply function, EE code was loaded via
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so features should be on.
"""
from ee.onyx.server.settings.api import apply_license_status_to_settings
assert base_settings.ee_features_enabled is False
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is True
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", True)
@@ -59,60 +46,13 @@ class TestApplyLicenseStatusToSettings:
from ee.onyx.server.settings.api import apply_license_status_to_settings
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
assert result.ee_features_enabled is True
@pytest.mark.parametrize(
"license_status,used_seats,seats,expected",
"license_status,expected_app_status,expected_ee_enabled",
[
(
ApplicationStatus.GATED_ACCESS,
3,
10,
{
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
10,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.GRACE_PERIOD,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
],
)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -123,80 +63,25 @@ class TestApplyLicenseStatusToSettings:
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
license_status: ApplicationStatus,
used_seats: int,
seats: int,
expected: dict,
license_status: ApplicationStatus | None,
expected_app_status: ApplicationStatus,
expected_ee_enabled: bool,
base_settings: Settings,
) -> None:
"""Self-hosted: license status controls both application_status and ee_features_enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_metadata.used_seats = used_seats
mock_metadata.seats = seats
mock_get_metadata.return_value = mock_metadata
if license_status is None:
mock_get_metadata.return_value = None
else:
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == expected
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_seat_limit_exceeded_sets_status_and_counts(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
base_settings: Settings,
) -> None:
"""Seat limit exceeded sets SEAT_LIMIT_EXCEEDED with counts, keeps EE enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = ApplicationStatus.ACTIVE
mock_metadata.used_seats = 15
mock_metadata.seats = 10
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.SEAT_LIMIT_EXCEEDED,
"ee_features_enabled": True,
"seat_count": 10,
"used_seats": 15,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_expired_license_takes_precedence_over_seat_limit(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
base_settings: Settings,
) -> None:
"""Expired license (GATED_ACCESS) takes precedence over seat limit exceeded."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = ApplicationStatus.GATED_ACCESS
mock_metadata.used_seats = 15
mock_metadata.seats = 10
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
assert result.application_status == expected_app_status
assert result.ee_features_enabled is expected_ee_enabled
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -220,12 +105,8 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
assert result.application_status == ApplicationStatus.GATED_ACCESS
assert result.ee_features_enabled is False
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -249,12 +130,8 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@@ -273,12 +150,8 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.side_effect = RedisError("Connection failed")
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
class TestSettingsDefaultEEDisabled:


@@ -44,6 +44,7 @@ def _build_provider_view(
id=1,
name="test-provider",
provider=provider,
default_model_name="test-model",
model_configurations=[
ModelConfigurationView(
name="test-model",
@@ -61,6 +62,7 @@ def _build_provider_view(
groups=[],
personas=[],
deployment_name=None,
default_vision_model=None,
)

cli/.gitignore vendored Normal file

@@ -0,0 +1,3 @@
onyx-cli
cli
onyx.cli

cli/README.md Normal file

@@ -0,0 +1,118 @@
# Onyx CLI
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built in Go with [Bubble Tea](https://github.com/charmbracelet/bubbletea) as the TUI framework.
## Installation
```shell
pip install onyx-cli
```
Or with uv:
```shell
uv pip install onyx-cli
```
## Setup
Run the interactive setup:
```shell
onyx-cli configure
```
This prompts for your Onyx server URL and API key, tests the connection, and saves config to `~/.config/onyx-cli/config.json`.
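The saved file is plain JSON. A minimal sketch of its shape — `api_key` is referenced by this PR's SKILL.md config check, but the other key names and values here are assumptions, not taken from the config package:

```json
{
  "server_url": "https://cloud.onyx.app",
  "api_key": "on_xxxxxxxxxxxx",
  "persona_id": 0
}
```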
Environment variables override config file values:
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
## Usage
### Interactive chat (default)
```shell
onyx-cli
```
### One-shot question
```shell
onyx-cli ask "What is our company's PTO policy?"
onyx-cli ask --agent-id 5 "Summarize this topic"
onyx-cli ask --json "Hello"
```
| Flag | Description |
|------|-------------|
| `--agent-id <int>` | Agent ID to use (overrides default) |
| `--json` | Output raw NDJSON events instead of plain text |
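With `--json`, every line is one event wrapped in a type discriminator, matching the wrapping struct in `ask.go` (the `"type"` and `"event"` keys come from that struct's JSON tags; the event type names and payload fields shown here are illustrative):

```json
{"type": "message_delta", "event": {"content": "Hello"}}
{"type": "stop", "event": {}}
```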
### List agents
```shell
onyx-cli agents
onyx-cli agents --json
```
## Commands
| Command | Description |
|---------|-------------|
| `chat` | Launch the interactive chat TUI (default) |
| `ask` | Ask a one-shot question (non-interactive) |
| `agents` | List available agents |
| `configure` | Configure server URL and API key |
## Slash Commands (in TUI)
| Command | Description |
|---------|-------------|
| `/help` | Show help message |
| `/new` | Start a new chat session |
| `/agent` | List and switch agents |
| `/attach <path>` | Attach a file to next message |
| `/sessions` | List recent chat sessions |
| `/clear` | Clear the chat display |
| `/configure` | Re-run connection setup |
| `/connectors` | Open connectors in browser |
| `/settings` | Open settings in browser |
| `/quit` | Exit Onyx CLI |
## Keyboard Shortcuts
| Key | Action |
|-----|--------|
| `Enter` | Send message |
| `Escape` | Cancel current generation |
| `Ctrl+O` | Toggle source citations |
| `Ctrl+D` | Quit (press twice) |
| `Scroll` / `Shift+Up/Down` | Scroll chat history |
| `Page Up` / `Page Down` | Scroll half page |
## Building from Source
Requires [Go 1.26+](https://go.dev/dl/) (per `go.mod`).
```shell
cd cli
go build -o onyx-cli .
```
## Development
```shell
# Run tests
go test ./...
# Build
go build -o onyx-cli .
# Lint
staticcheck ./...
```

cli/cmd/agents.go Normal file

@@ -0,0 +1,63 @@
package cmd
import (
"encoding/json"
"fmt"
"text/tabwriter"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/spf13/cobra"
)
func newAgentsCmd() *cobra.Command {
var agentsJSON bool
cmd := &cobra.Command{
Use: "agents",
Short: "List available agents",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
if !cfg.IsConfigured() {
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
}
client := api.NewClient(cfg)
agents, err := client.ListAgents()
if err != nil {
return fmt.Errorf("failed to list agents: %w", err)
}
if agentsJSON {
data, err := json.MarshalIndent(agents, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal agents: %w", err)
}
fmt.Println(string(data))
return nil
}
if len(agents) == 0 {
fmt.Println("No agents available.")
return nil
}
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 4, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "ID\tNAME\tDESCRIPTION")
for _, a := range agents {
desc := a.Description
if len(desc) > 60 {
desc = desc[:57] + "..."
}
_, _ = fmt.Fprintf(w, "%d\t%s\t%s\n", a.ID, a.Name, desc)
}
_ = w.Flush()
return nil
},
}
cmd.Flags().BoolVar(&agentsJSON, "json", false, "Output agents as JSON")
return cmd
}
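One caveat with the description truncation above: `desc[:57]` slices bytes, so a multi-byte character that straddles the boundary gets cut in half. A rune-safe sketch of an alternative (not part of the PR):

```go
package main

import "fmt"

// truncate shortens s to at most max runes, appending "..." when cut.
// Unlike byte slicing (desc[:57]), it never splits a multi-byte character.
func truncate(s string, max int) string {
	r := []rune(s)
	if len(r) <= max {
		return s
	}
	return string(r[:max-3]) + "..."
}

func main() {
	fmt.Println(truncate("héllo wörld, this is a long description", 10)) // héllo w...
}
```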

cli/cmd/ask.go Normal file

@@ -0,0 +1,103 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/spf13/cobra"
)
func newAskCmd() *cobra.Command {
var (
askAgentID int
askJSON bool
)
cmd := &cobra.Command{
Use: "ask [question]",
Short: "Ask a one-shot question (non-interactive)",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
if !cfg.IsConfigured() {
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
}
question := args[0]
agentID := cfg.DefaultAgentID
if cmd.Flags().Changed("agent-id") {
agentID = askAgentID
}
client := api.NewClient(cfg)
parentID := -1
ch := client.SendMessageStream(
context.Background(),
question,
nil,
agentID,
&parentID,
nil,
)
var lastErr error
gotStop := false
for event := range ch {
if askJSON {
wrapped := struct {
Type string `json:"type"`
Event models.StreamEvent `json:"event"`
}{
Type: event.EventType(),
Event: event,
}
data, err := json.Marshal(wrapped)
if err != nil {
return fmt.Errorf("error marshaling event: %w", err)
}
fmt.Println(string(data))
if ee, ok := event.(models.ErrorEvent); ok {
lastErr = fmt.Errorf("%s", ee.Error)
}
if _, ok := event.(models.StopEvent); ok {
gotStop = true
}
continue
}
switch e := event.(type) {
case models.MessageDeltaEvent:
fmt.Print(e.Content)
case models.ErrorEvent:
return fmt.Errorf("%s", e.Error)
case models.StopEvent:
fmt.Println()
return nil
}
}
if lastErr != nil {
return lastErr
}
if !gotStop {
if !askJSON {
fmt.Println()
}
return fmt.Errorf("stream ended unexpectedly")
}
if !askJSON {
fmt.Println()
}
return nil
},
}
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
// Don't dump usage help when RunE returns an error
cmd.SilenceUsage = true
return cmd
}
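The `--json` wrapping inside the loop above boils down to a small helper. A self-contained sketch (the event type string and payload shape are illustrative, not the CLI's real event set):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// wrapEvent pairs a type discriminator with the raw event payload,
// producing one JSON object per NDJSON line.
func wrapEvent(eventType string, event any) (string, error) {
	wrapped := struct {
		Type  string `json:"type"`
		Event any    `json:"event"`
	}{Type: eventType, Event: event}
	data, err := json.Marshal(wrapped)
	if err != nil {
		return "", err
	}
	return string(data), nil
}

func main() {
	line, err := wrapEvent("message_delta", map[string]string{"content": "Hello"})
	if err != nil {
		panic(err)
	}
	fmt.Println(line) // {"type":"message_delta","event":{"content":"Hello"}}
}
```

Consumers can then switch on `type` alone instead of sniffing payload fields, which is the point of the discriminator added in this PR.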

cli/cmd/chat.go Normal file

@@ -0,0 +1,33 @@
package cmd
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/spf13/cobra"
)
func newChatCmd() *cobra.Command {
return &cobra.Command{
Use: "chat",
Short: "Launch the interactive chat TUI (default)",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
// First-run: onboarding
if !config.ConfigExists() || !cfg.IsConfigured() {
result := onboarding.Run(&cfg)
if result == nil {
return nil
}
cfg = *result
}
m := tui.NewModel(cfg)
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
_, err := p.Run()
return err
},
}
}

cli/cmd/configure.go Normal file

@@ -0,0 +1,19 @@
package cmd
import (
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
"github.com/spf13/cobra"
)
func newConfigureCmd() *cobra.Command {
return &cobra.Command{
Use: "configure",
Short: "Configure server URL and API key",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
onboarding.Run(&cfg)
return nil
},
}
}

cli/cmd/root.go Normal file

@@ -0,0 +1,40 @@
// Package cmd implements Cobra CLI commands for the Onyx CLI.
package cmd
import "github.com/spf13/cobra"
// Version and Commit are set via ldflags at build time.
var (
Version string
Commit string
)
func fullVersion() string {
if Commit != "" && Commit != "none" && len(Commit) >= 7 {
return Version + " (" + Commit[:7] + ")"
}
return Version
}
// Execute creates and runs the root command.
func Execute() error {
rootCmd := &cobra.Command{
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
Version: fullVersion(),
}
// Register subcommands
chatCmd := newChatCmd()
rootCmd.AddCommand(chatCmd)
rootCmd.AddCommand(newAskCmd())
rootCmd.AddCommand(newAgentsCmd())
rootCmd.AddCommand(newConfigureCmd())
rootCmd.AddCommand(newValidateConfigCmd())
// Default command is chat
rootCmd.RunE = chatCmd.RunE
return rootCmd.Execute()
}

cli/cmd/validate.go Normal file

@@ -0,0 +1,41 @@
package cmd
import (
"fmt"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/spf13/cobra"
)
func newValidateConfigCmd() *cobra.Command {
return &cobra.Command{
Use: "validate-config",
Short: "Validate configuration and test server connection",
RunE: func(cmd *cobra.Command, args []string) error {
// Check config file
if !config.ConfigExists() {
return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
}
cfg := config.Load()
// Check API key
if !cfg.IsConfigured() {
return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server: %s\n", cfg.ServerURL)
// Test connection
client := api.NewClient(cfg)
if err := client.TestConnection(); err != nil {
return fmt.Errorf("connection failed: %w", err)
}
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
return nil
},
}
}

cli/go.mod Normal file

@@ -0,0 +1,45 @@
module github.com/onyx-dot-app/onyx/cli
go 1.26.0
require (
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.4
github.com/charmbracelet/glamour v0.8.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/spf13/cobra v1.9.1
golang.org/x/term v0.22.0
golang.org/x/text v0.34.0
)
require (
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yuin/goldmark v1.7.4 // indirect
github.com/yuin/goldmark-emoji v1.0.3 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.30.0 // indirect
)

cli/go.sum Normal file

@@ -0,0 +1,94 @@
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

cli/internal/api/client.go Normal file

@@ -0,0 +1,279 @@
// Package api provides the HTTP client for communicating with the Onyx server.
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
// Client is the Onyx API client.
type Client struct {
baseURL string
apiKey string
httpClient *http.Client // default 30s timeout for quick requests
longHTTPClient *http.Client // 5min timeout for streaming/uploads
}
// NewClient creates a new API client from config.
func NewClient(cfg config.OnyxCliConfig) *Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
return &Client{
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
apiKey: cfg.APIKey,
httpClient: &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
},
longHTTPClient: &http.Client{
Timeout: 5 * time.Minute,
Transport: transport,
},
}
}
// UpdateConfig replaces the client's config.
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
c.apiKey = cfg.APIKey
}
func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(context.Background(), method, c.baseURL+path, body)
if err != nil {
return nil, err
}
if c.apiKey != "" {
bearer := "Bearer " + c.apiKey
req.Header.Set("Authorization", bearer)
req.Header.Set("X-Onyx-Authorization", bearer)
}
return req, nil
}
func (c *Client) doJSON(method, path string, reqBody any, result any) error {
var body io.Reader
if reqBody != nil {
data, err := json.Marshal(reqBody)
if err != nil {
return err
}
body = bytes.NewReader(data)
}
req, err := c.newRequest(method, path, body)
if err != nil {
return err
}
if reqBody != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
}
if result != nil {
return json.NewDecoder(resp.Body).Decode(result)
}
return nil
}
// TestConnection checks if the server is reachable and credentials are valid.
// Returns nil on success, or an error with a descriptive message on failure.
func (c *Client) TestConnection() error {
// Step 1: Basic reachability
req, err := c.newRequest("GET", "/", nil)
if err != nil {
return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("cannot connect to %s — is the server running? (%v)", c.baseURL, err)
}
_ = resp.Body.Close()
serverHeader := strings.ToLower(resp.Header.Get("Server"))
if resp.StatusCode == 403 {
if strings.Contains(serverHeader, "awselb") || strings.Contains(serverHeader, "amazons3") {
return fmt.Errorf("blocked by AWS load balancer (HTTP 403 on all requests).\n Your IP address may not be in the ALB's security group or WAF allowlist")
}
return fmt.Errorf("HTTP 403 on base URL — the server is blocking all traffic.\n This is likely a firewall, WAF, or IP allowlist restriction")
}
// Step 2: Authenticated check
req2, err := c.newRequest("GET", "/api/me", nil)
if err != nil {
return fmt.Errorf("server reachable but API error: %w", err)
}
resp2, err := c.longHTTPClient.Do(req2)
if err != nil {
return fmt.Errorf("server reachable but API error: %w", err)
}
defer func() { _ = resp2.Body.Close() }()
if resp2.StatusCode == 200 {
return nil
}
bodyBytes, _ := io.ReadAll(io.LimitReader(resp2.Body, 300))
body := string(bodyBytes)
isHTML := strings.HasPrefix(strings.TrimSpace(body), "<")
respServer := strings.ToLower(resp2.Header.Get("Server"))
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
if isHTML || strings.Contains(respServer, "awselb") {
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
}
if resp2.StatusCode == 401 {
return fmt.Errorf("invalid API key or token.\n %s", body)
}
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
}
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
if body != "" {
detail += fmt.Sprintf("\n Response: %s", body)
}
return fmt.Errorf("%s", detail)
}
// ListAgents returns visible agents.
func (c *Client) ListAgents() ([]models.AgentSummary, error) {
var raw []models.AgentSummary
if err := c.doJSON("GET", "/api/persona", nil, &raw); err != nil {
return nil, err
}
var result []models.AgentSummary
for _, p := range raw {
if p.IsVisible {
result = append(result, p)
}
}
return result, nil
}
// ListChatSessions returns recent chat sessions.
func (c *Client) ListChatSessions() ([]models.ChatSessionDetails, error) {
var resp struct {
Sessions []models.ChatSessionDetails `json:"sessions"`
}
if err := c.doJSON("GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
return nil, err
}
return resp.Sessions, nil
}
// GetChatSession returns full details for a session.
func (c *Client) GetChatSession(sessionID string) (*models.ChatSessionDetailResponse, error) {
var resp models.ChatSessionDetailResponse
if err := c.doJSON("GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// RenameChatSession renames a session. If name is empty, the backend auto-generates one.
func (c *Client) RenameChatSession(sessionID string, name *string) (string, error) {
payload := map[string]any{
"chat_session_id": sessionID,
}
if name != nil {
payload["name"] = *name
}
var resp struct {
NewName string `json:"new_name"`
}
if err := c.doJSON("PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
return "", err
}
return resp.NewName, nil
}
// UploadFile uploads a file and returns a file descriptor.
func (c *Client) UploadFile(filePath string) (*models.FileDescriptorPayload, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer func() { _ = file.Close() }()
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("files", filepath.Base(filePath))
if err != nil {
return nil, err
}
if _, err := io.Copy(part, file); err != nil {
return nil, err
}
_ = writer.Close()
req, err := c.newRequest("POST", "/api/user/projects/file/upload", &buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := c.longHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return nil, &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(body)}
}
var snapshot models.CategorizedFilesSnapshot
if err := json.NewDecoder(resp.Body).Decode(&snapshot); err != nil {
return nil, err
}
if len(snapshot.UserFiles) == 0 {
return nil, &OnyxAPIError{StatusCode: 400, Detail: "File upload returned no files"}
}
uf := snapshot.UserFiles[0]
return &models.FileDescriptorPayload{
ID: uf.FileID,
Type: uf.ChatFileType,
Name: filepath.Base(filePath),
}, nil
}
// StopChatSession sends a stop signal for a streaming session (best-effort).
func (c *Client) StopChatSession(sessionID string) {
req, err := c.newRequest("POST", "/api/chat/stop-chat-session/"+sessionID, nil)
if err != nil {
return
}
resp, err := c.httpClient.Do(req)
if err != nil {
return
}
_ = resp.Body.Close()
}


@@ -0,0 +1,13 @@
package api
import "fmt"
// OnyxAPIError is returned when an Onyx API call fails.
type OnyxAPIError struct {
StatusCode int
Detail string
}
func (e *OnyxAPIError) Error() string {
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
}

cli/internal/api/stream.go

@@ -0,0 +1,136 @@
package api
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/onyx-dot-app/onyx/cli/internal/parser"
)
// StreamEventMsg wraps a StreamEvent for Bubble Tea.
type StreamEventMsg struct {
Event models.StreamEvent
}
// StreamDoneMsg signals the stream has ended.
type StreamDoneMsg struct {
Err error
}
// SendMessageStream starts streaming a chat message response.
// It reads NDJSON lines, parses them, and sends events on the returned channel.
// The goroutine stops when ctx is cancelled or the stream ends.
func (c *Client) SendMessageStream(
ctx context.Context,
message string,
chatSessionID *string,
agentID int,
parentMessageID *int,
fileDescriptors []models.FileDescriptorPayload,
) <-chan models.StreamEvent {
ch := make(chan models.StreamEvent, 64)
go func() {
defer close(ch)
payload := models.SendMessagePayload{
Message: message,
ParentMessageID: parentMessageID,
FileDescriptors: fileDescriptors,
Origin: "api",
IncludeCitations: true,
Stream: true,
}
if payload.FileDescriptors == nil {
payload.FileDescriptors = []models.FileDescriptorPayload{}
}
if chatSessionID != nil {
payload.ChatSessionID = chatSessionID
} else {
payload.ChatSessionInfo = &models.ChatSessionCreationInfo{AgentID: agentID}
}
body, err := json.Marshal(payload)
if err != nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("marshal error: %v", err), IsRetryable: false}
return
}
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", bytes.NewReader(body))
if err != nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
return
}
req.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
bearer := "Bearer " + c.apiKey
req.Header.Set("Authorization", bearer)
req.Header.Set("X-Onyx-Authorization", bearer)
}
resp, err := c.longHTTPClient.Do(req)
if err != nil {
if ctx.Err() != nil {
return // cancelled
}
ch <- models.ErrorEvent{Error: fmt.Sprintf("connection error: %v", err), IsRetryable: true}
return
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
ch <- models.ErrorEvent{
Error: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody)),
IsRetryable: resp.StatusCode >= 500,
}
return
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024)
for scanner.Scan() {
if ctx.Err() != nil {
return
}
event := parser.ParseStreamLine(scanner.Text())
if event != nil {
select {
case ch <- event:
case <-ctx.Done():
return
}
}
}
if err := scanner.Err(); err != nil && ctx.Err() == nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("stream read error: %v", err), IsRetryable: true}
}
}()
return ch
}
// WaitForStreamEvent returns a tea.Cmd that reads one event from the channel.
// On channel close, it returns StreamDoneMsg.
func WaitForStreamEvent(ch <-chan models.StreamEvent) tea.Cmd {
return func() tea.Msg {
event, ok := <-ch
if !ok {
return StreamDoneMsg{}
}
return StreamEventMsg{Event: event}
}
}


@@ -0,0 +1,101 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
)
const (
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvAgentID = "ONYX_PERSONA_ID"
)
// OnyxCliConfig holds the CLI configuration.
type OnyxCliConfig struct {
ServerURL string `json:"server_url"`
APIKey string `json:"api_key"`
DefaultAgentID int `json:"default_persona_id"`
}
// DefaultConfig returns a config with default values.
func DefaultConfig() OnyxCliConfig {
return OnyxCliConfig{
ServerURL: "https://cloud.onyx.app",
APIKey: "",
DefaultAgentID: 0,
}
}
// IsConfigured returns true if the config has an API key.
func (c OnyxCliConfig) IsConfigured() bool {
return c.APIKey != ""
}
// configDir returns the config directory: $XDG_CONFIG_HOME/onyx-cli if set, else ~/.config/onyx-cli.
func configDir() string {
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
return filepath.Join(xdg, "onyx-cli")
}
home, err := os.UserHomeDir()
if err != nil {
return filepath.Join(".", ".config", "onyx-cli")
}
return filepath.Join(home, ".config", "onyx-cli")
}
// ConfigFilePath returns the full path to the config file.
func ConfigFilePath() string {
return filepath.Join(configDir(), "config.json")
}
// ConfigExists checks if the config file exists on disk.
func ConfigExists() bool {
_, err := os.Stat(ConfigFilePath())
return err == nil
}
// Load reads config from file and applies environment variable overrides.
func Load() OnyxCliConfig {
cfg := DefaultConfig()
data, err := os.ReadFile(ConfigFilePath())
if err == nil {
if jsonErr := json.Unmarshal(data, &cfg); jsonErr != nil {
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", ConfigFilePath(), jsonErr)
cfg = DefaultConfig() // Unmarshal may have partially populated cfg before failing
}
}
// Environment overrides
if v := os.Getenv(EnvServerURL); v != "" {
cfg.ServerURL = v
}
if v := os.Getenv(EnvAPIKey); v != "" {
cfg.APIKey = v
}
if v := os.Getenv(EnvAgentID); v != "" {
if id, err := strconv.Atoi(v); err == nil {
cfg.DefaultAgentID = id
}
}
return cfg
}
// Save writes the config to disk, creating parent directories if needed.
func Save(cfg OnyxCliConfig) error {
dir := configDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(ConfigFilePath(), data, 0o600)
}


@@ -0,0 +1,215 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func clearEnvVars(t *testing.T) {
t.Helper()
for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAgentID} {
// t.Setenv registers a cleanup that restores the original value after the
// test; Unsetenv then removes the variable for the test's duration.
t.Setenv(key, "")
if err := os.Unsetenv(key); err != nil {
t.Fatal(err)
}
}
}
func writeConfig(t *testing.T, dir string, data []byte) {
t.Helper()
onyxDir := filepath.Join(dir, "onyx-cli")
if err := os.MkdirAll(onyxDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(onyxDir, "config.json"), data, 0o644); err != nil {
t.Fatal(err)
}
}
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.ServerURL != "https://cloud.onyx.app" {
t.Errorf("expected default server URL, got %s", cfg.ServerURL)
}
if cfg.APIKey != "" {
t.Errorf("expected empty API key, got %s", cfg.APIKey)
}
if cfg.DefaultAgentID != 0 {
t.Errorf("expected default agent ID 0, got %d", cfg.DefaultAgentID)
}
}
func TestIsConfigured(t *testing.T) {
cfg := DefaultConfig()
if cfg.IsConfigured() {
t.Error("empty config should not be configured")
}
cfg.APIKey = "some-key"
if !cfg.IsConfigured() {
t.Error("config with API key should be configured")
}
}
func TestLoadDefaults(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
cfg := Load()
if cfg.ServerURL != "https://cloud.onyx.app" {
t.Errorf("expected default URL, got %s", cfg.ServerURL)
}
if cfg.APIKey != "" {
t.Errorf("expected empty key, got %s", cfg.APIKey)
}
}
func TestLoadFromFile(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
data, _ := json.Marshal(map[string]interface{}{
"server_url": "https://my-onyx.example.com",
"api_key": "test-key-123",
"default_persona_id": 5,
})
writeConfig(t, dir, data)
cfg := Load()
if cfg.ServerURL != "https://my-onyx.example.com" {
t.Errorf("got %s", cfg.ServerURL)
}
if cfg.APIKey != "test-key-123" {
t.Errorf("got %s", cfg.APIKey)
}
if cfg.DefaultAgentID != 5 {
t.Errorf("got %d", cfg.DefaultAgentID)
}
}
func TestLoadCorruptFile(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
writeConfig(t, dir, []byte("not valid json {{{"))
cfg := Load()
if cfg.ServerURL != "https://cloud.onyx.app" {
t.Errorf("expected default URL on corrupt file, got %s", cfg.ServerURL)
}
}
func TestEnvOverrideServerURL(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvServerURL, "https://env-override.com")
cfg := Load()
if cfg.ServerURL != "https://env-override.com" {
t.Errorf("got %s", cfg.ServerURL)
}
}
func TestEnvOverrideAPIKey(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvAPIKey, "env-key")
cfg := Load()
if cfg.APIKey != "env-key" {
t.Errorf("got %s", cfg.APIKey)
}
}
func TestEnvOverrideAgentID(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvAgentID, "42")
cfg := Load()
if cfg.DefaultAgentID != 42 {
t.Errorf("got %d", cfg.DefaultAgentID)
}
}
func TestEnvOverrideInvalidAgentID(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
t.Setenv(EnvAgentID, "not-a-number")
cfg := Load()
if cfg.DefaultAgentID != 0 {
t.Errorf("got %d", cfg.DefaultAgentID)
}
}
func TestEnvOverridesFileValues(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
data, _ := json.Marshal(map[string]interface{}{
"server_url": "https://file-url.com",
"api_key": "file-key",
})
writeConfig(t, dir, data)
t.Setenv(EnvServerURL, "https://env-url.com")
cfg := Load()
if cfg.ServerURL != "https://env-url.com" {
t.Errorf("env should override file, got %s", cfg.ServerURL)
}
if cfg.APIKey != "file-key" {
t.Errorf("file value should be kept, got %s", cfg.APIKey)
}
}
func TestSaveAndReload(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)
cfg := OnyxCliConfig{
ServerURL: "https://saved.example.com",
APIKey: "saved-key",
DefaultAgentID: 10,
}
if err := Save(cfg); err != nil {
t.Fatal(err)
}
loaded := Load()
if loaded.ServerURL != "https://saved.example.com" {
t.Errorf("got %s", loaded.ServerURL)
}
if loaded.APIKey != "saved-key" {
t.Errorf("got %s", loaded.APIKey)
}
if loaded.DefaultAgentID != 10 {
t.Errorf("got %d", loaded.DefaultAgentID)
}
}
func TestSaveCreatesParentDirs(t *testing.T) {
clearEnvVars(t)
dir := t.TempDir()
nested := filepath.Join(dir, "deep", "nested")
t.Setenv("XDG_CONFIG_HOME", nested)
if err := Save(OnyxCliConfig{APIKey: "test"}); err != nil {
t.Fatal(err)
}
if !ConfigExists() {
t.Error("config file should exist after save")
}
}


@@ -0,0 +1,193 @@
package models
// StreamEvent is the interface for all parsed stream events.
type StreamEvent interface {
EventType() string
}
// Event type constants matching the Python StreamEventType enum.
const (
EventSessionCreated = "session_created"
EventMessageIDInfo = "message_id_info"
EventStop = "stop"
EventError = "error"
EventMessageStart = "message_start"
EventMessageDelta = "message_delta"
EventSearchStart = "search_tool_start"
EventSearchQueries = "search_tool_queries_delta"
EventSearchDocuments = "search_tool_documents_delta"
EventReasoningStart = "reasoning_start"
EventReasoningDelta = "reasoning_delta"
EventReasoningDone = "reasoning_done"
EventCitationInfo = "citation_info"
EventOpenURLStart = "open_url_start"
EventImageGenStart = "image_generation_start"
EventPythonToolStart = "python_tool_start"
EventCustomToolStart = "custom_tool_start"
EventFileReaderStart = "file_reader_start"
EventDeepResearchPlan = "deep_research_plan_start"
EventDeepResearchDelta = "deep_research_plan_delta"
EventResearchAgentStart = "research_agent_start"
EventIntermediateReport = "intermediate_report_start"
EventIntermediateReportDt = "intermediate_report_delta"
EventUnknown = "unknown"
)
// SessionCreatedEvent is emitted when a new chat session is created.
type SessionCreatedEvent struct {
ChatSessionID string
}
func (e SessionCreatedEvent) EventType() string { return EventSessionCreated }
// MessageIDEvent carries the user and agent message IDs.
type MessageIDEvent struct {
UserMessageID *int
ReservedAgentMessageID int
}
func (e MessageIDEvent) EventType() string { return EventMessageIDInfo }
// StopEvent signals the end of a stream.
type StopEvent struct {
Placement *Placement
StopReason *string
}
func (e StopEvent) EventType() string { return EventStop }
// ErrorEvent signals an error.
type ErrorEvent struct {
Placement *Placement
Error string
StackTrace *string
IsRetryable bool
}
func (e ErrorEvent) EventType() string { return EventError }
// MessageStartEvent signals the beginning of an agent message.
type MessageStartEvent struct {
Placement *Placement
Documents []SearchDoc
}
func (e MessageStartEvent) EventType() string { return EventMessageStart }
// MessageDeltaEvent carries a token of agent content.
type MessageDeltaEvent struct {
Placement *Placement
Content string
}
func (e MessageDeltaEvent) EventType() string { return EventMessageDelta }
// SearchStartEvent signals the beginning of a search.
type SearchStartEvent struct {
Placement *Placement
IsInternetSearch bool
}
func (e SearchStartEvent) EventType() string { return EventSearchStart }
// SearchQueriesEvent carries search queries.
type SearchQueriesEvent struct {
Placement *Placement
Queries []string
}
func (e SearchQueriesEvent) EventType() string { return EventSearchQueries }
// SearchDocumentsEvent carries found documents.
type SearchDocumentsEvent struct {
Placement *Placement
Documents []SearchDoc
}
func (e SearchDocumentsEvent) EventType() string { return EventSearchDocuments }
// ReasoningStartEvent signals the beginning of a reasoning block.
type ReasoningStartEvent struct {
Placement *Placement
}
func (e ReasoningStartEvent) EventType() string { return EventReasoningStart }
// ReasoningDeltaEvent carries reasoning text.
type ReasoningDeltaEvent struct {
Placement *Placement
Reasoning string
}
func (e ReasoningDeltaEvent) EventType() string { return EventReasoningDelta }
// ReasoningDoneEvent signals the end of reasoning.
type ReasoningDoneEvent struct {
Placement *Placement
}
func (e ReasoningDoneEvent) EventType() string { return EventReasoningDone }
// CitationEvent carries citation info.
type CitationEvent struct {
Placement *Placement
CitationNumber int
DocumentID string
}
func (e CitationEvent) EventType() string { return EventCitationInfo }
// ToolStartEvent signals the start of a tool usage.
type ToolStartEvent struct {
Placement *Placement
Type string // The specific event type (e.g. "open_url_start")
ToolName string
}
func (e ToolStartEvent) EventType() string { return e.Type }
// DeepResearchPlanStartEvent signals the start of a deep research plan.
type DeepResearchPlanStartEvent struct {
Placement *Placement
}
func (e DeepResearchPlanStartEvent) EventType() string { return EventDeepResearchPlan }
// DeepResearchPlanDeltaEvent carries deep research plan content.
type DeepResearchPlanDeltaEvent struct {
Placement *Placement
Content string
}
func (e DeepResearchPlanDeltaEvent) EventType() string { return EventDeepResearchDelta }
// ResearchAgentStartEvent signals a research sub-task.
type ResearchAgentStartEvent struct {
Placement *Placement
ResearchTask string
}
func (e ResearchAgentStartEvent) EventType() string { return EventResearchAgentStart }
// IntermediateReportStartEvent signals the start of an intermediate report.
type IntermediateReportStartEvent struct {
Placement *Placement
}
func (e IntermediateReportStartEvent) EventType() string { return EventIntermediateReport }
// IntermediateReportDeltaEvent carries intermediate report content.
type IntermediateReportDeltaEvent struct {
Placement *Placement
Content string
}
func (e IntermediateReportDeltaEvent) EventType() string { return EventIntermediateReportDt }
// UnknownEvent is a catch-all for unrecognized stream data.
type UnknownEvent struct {
Placement *Placement
RawData map[string]any
}
func (e UnknownEvent) EventType() string { return EventUnknown }
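Consumers dispatch on the concrete event type with a type switch. A trimmed, self-contained replica of the StreamEvent shape (only two event types reproduced here for illustration):

```go
package main

import "fmt"

// StreamEvent replicates the interface shape above.
type StreamEvent interface{ EventType() string }

type MessageDeltaEvent struct{ Content string }

func (e MessageDeltaEvent) EventType() string { return "message_delta" }

type StopEvent struct{}

func (e StopEvent) EventType() string { return "stop" }

// render accumulates streamed tokens until a stop event arrives.
func render(events []StreamEvent) string {
	var out string
	for _, ev := range events {
		switch e := ev.(type) {
		case MessageDeltaEvent:
			out += e.Content // accumulate streamed tokens
		case StopEvent:
			return out // stream finished
		}
	}
	return out
}

func main() {
	events := []StreamEvent{
		MessageDeltaEvent{Content: "Hel"},
		MessageDeltaEvent{Content: "lo"},
		StopEvent{},
	}
	fmt.Println(render(events)) // Hello
}
```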


@@ -0,0 +1,112 @@
// Package models defines API request/response types for the Onyx CLI.
package models
import "time"
// AgentSummary represents an agent from the API.
type AgentSummary struct {
ID int `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
IsDefaultPersona bool `json:"is_default_persona"`
IsVisible bool `json:"is_visible"`
}
// ChatSessionSummary is a brief session listing.
type ChatSessionSummary struct {
ID string `json:"id"`
Name *string `json:"name"`
AgentID *int `json:"persona_id"`
Created time.Time `json:"time_created"`
}
// ChatSessionDetails is a session with timestamps as strings.
type ChatSessionDetails struct {
ID string `json:"id"`
Name *string `json:"name"`
AgentID *int `json:"persona_id"`
Created string `json:"time_created"`
Updated string `json:"time_updated"`
}
// ChatMessageDetail is a single message in a session.
type ChatMessageDetail struct {
MessageID int `json:"message_id"`
ParentMessage *int `json:"parent_message"`
LatestChildMessage *int `json:"latest_child_message"`
Message string `json:"message"`
MessageType string `json:"message_type"`
TimeSent string `json:"time_sent"`
Error *string `json:"error"`
}
// ChatSessionDetailResponse is the full session detail from the API.
type ChatSessionDetailResponse struct {
ChatSessionID string `json:"chat_session_id"`
Description *string `json:"description"`
AgentID *int `json:"persona_id"`
AgentName *string `json:"persona_name"`
Messages []ChatMessageDetail `json:"messages"`
}
// ChatFileType represents a file type for uploads.
type ChatFileType string
const (
ChatFileImage ChatFileType = "image"
ChatFileDoc ChatFileType = "document"
ChatFilePlainText ChatFileType = "plain_text"
ChatFileCSV ChatFileType = "csv"
)
// FileDescriptorPayload is a file descriptor for send-message requests.
type FileDescriptorPayload struct {
ID string `json:"id"`
Type ChatFileType `json:"type"`
Name string `json:"name,omitempty"`
}
// UserFileSnapshot represents an uploaded file.
type UserFileSnapshot struct {
ID string `json:"id"`
Name string `json:"name"`
FileID string `json:"file_id"`
ChatFileType ChatFileType `json:"chat_file_type"`
}
// CategorizedFilesSnapshot is the response from file upload.
type CategorizedFilesSnapshot struct {
UserFiles []UserFileSnapshot `json:"user_files"`
}
// ChatSessionCreationInfo is included when creating a new session inline.
type ChatSessionCreationInfo struct {
AgentID int `json:"persona_id"`
}
// SendMessagePayload is the request body for POST /api/chat/send-chat-message.
type SendMessagePayload struct {
Message string `json:"message"`
ChatSessionID *string `json:"chat_session_id,omitempty"`
ChatSessionInfo *ChatSessionCreationInfo `json:"chat_session_info,omitempty"`
ParentMessageID *int `json:"parent_message_id"`
FileDescriptors []FileDescriptorPayload `json:"file_descriptors"`
Origin string `json:"origin"`
IncludeCitations bool `json:"include_citations"`
Stream bool `json:"stream"`
}
// SearchDoc represents a document found during search.
type SearchDoc struct {
DocumentID string `json:"document_id"`
SemanticIdentifier string `json:"semantic_identifier"`
Link *string `json:"link"`
SourceType string `json:"source_type"`
}
// Placement indicates where a stream event belongs in the conversation.
type Placement struct {
TurnIndex int `json:"turn_index"`
TabIndex int `json:"tab_index"`
SubTurnIndex *int `json:"sub_turn_index"`
}


@@ -0,0 +1,169 @@
// Package onboarding handles the first-run setup flow for Onyx CLI.
package onboarding
import (
"bufio"
"fmt"
"os"
"strings"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/onyx-dot-app/onyx/cli/internal/util"
"golang.org/x/term"
)
// Aliases for shared styles.
var (
boldStyle = util.BoldStyle
dimStyle = util.DimStyle
greenStyle = util.GreenStyle
redStyle = util.RedStyle
yellowStyle = util.YellowStyle
)
func getTermSize() (int, int) {
w, h, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
return 80, 24
}
return w, h
}
// Run executes the interactive onboarding flow.
// Returns the validated config, or nil if the user cancels.
func Run(existing *config.OnyxCliConfig) *config.OnyxCliConfig {
cfg := config.DefaultConfig()
if existing != nil {
cfg = *existing
}
w, h := getTermSize()
fmt.Print(tui.RenderSplashOnboarding(w, h))
fmt.Println()
fmt.Println(" Welcome to " + boldStyle.Render("Onyx CLI") + ".")
fmt.Println()
reader := bufio.NewReader(os.Stdin)
// Server URL
serverURL := prompt(reader, " Onyx server URL", cfg.ServerURL)
if serverURL == "" {
return nil
}
if !strings.HasPrefix(serverURL, "http://") && !strings.HasPrefix(serverURL, "https://") {
fmt.Println(" " + redStyle.Render("Server URL must start with http:// or https://"))
return nil
}
// API Key
fmt.Println()
fmt.Println(" " + dimStyle.Render("Need an API key? Press Enter to open the admin panel in your browser,"))
fmt.Println(" " + dimStyle.Render("or paste your key below."))
fmt.Println()
apiKey := promptSecret(" API key", cfg.APIKey)
if apiKey == "" {
// Open browser to API key page
url := strings.TrimRight(serverURL, "/") + "/app/settings/accounts-access"
fmt.Printf("\n Opening %s ...\n", url)
util.OpenBrowser(url)
fmt.Println(" " + dimStyle.Render("Copy your API key, then paste it here."))
fmt.Println()
apiKey = promptSecret(" API key", "")
if apiKey == "" {
fmt.Println("\n " + redStyle.Render("No API key provided. Exiting."))
return nil
}
}
// Test connection
cfg = config.OnyxCliConfig{
ServerURL: serverURL,
APIKey: apiKey,
DefaultAgentID: cfg.DefaultAgentID,
}
fmt.Println("\n " + yellowStyle.Render("Testing connection..."))
client := api.NewClient(cfg)
if err := client.TestConnection(); err != nil {
fmt.Println(" " + redStyle.Render("Connection failed.") + " " + err.Error())
fmt.Println()
fmt.Println(" " + dimStyle.Render("Run ") + boldStyle.Render("onyx-cli configure") + dimStyle.Render(" to try again."))
return nil
}
if err := config.Save(cfg); err != nil {
fmt.Println(" " + redStyle.Render("Could not save config: "+err.Error()))
return nil
}
fmt.Println(" " + greenStyle.Render("Connected and authenticated."))
fmt.Println()
printQuickStart()
return &cfg
}
func promptSecret(label, defaultVal string) string {
if defaultVal != "" {
fmt.Printf("%s %s: ", label, dimStyle.Render("[hidden]"))
} else {
fmt.Printf("%s: ", label)
}
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println() // ReadPassword doesn't echo a newline
if err != nil {
return defaultVal
}
line := strings.TrimSpace(string(password))
if line == "" {
return defaultVal
}
return line
}
func prompt(reader *bufio.Reader, label, defaultVal string) string {
if defaultVal != "" {
fmt.Printf("%s %s: ", label, dimStyle.Render("["+defaultVal+"]"))
} else {
fmt.Printf("%s: ", label)
}
line, _ := reader.ReadString('\n')
// ReadString may return partial data along with an error (e.g. EOF without
// a trailing newline), so use whatever was read and fall back to the default.
if line = strings.TrimSpace(line); line != "" {
return line
}
return defaultVal
}
func printQuickStart() {
fmt.Println(" " + boldStyle.Render("Quick start"))
fmt.Println()
fmt.Println(" Just type to chat with your Onyx agent.")
fmt.Println()
rows := [][2]string{
{"/help", "Show all commands"},
{"/attach", "Attach a file"},
{"/agent", "Switch agent"},
{"/new", "New conversation"},
{"/sessions", "Browse previous chats"},
{"Esc", "Cancel generation"},
{"Ctrl+D", "Quit"},
}
for _, r := range rows {
fmt.Printf(" %-12s %s\n", boldStyle.Render(r[0]), dimStyle.Render(r[1]))
}
fmt.Println()
}


@@ -0,0 +1,248 @@
// Package parser handles NDJSON stream parsing for Onyx chat responses.
package parser
import (
"encoding/json"
"fmt"
"strings"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// ParseStreamLine parses a single NDJSON line into a typed StreamEvent.
// It returns nil for empty lines and an ErrorEvent for malformed JSON.
func ParseStreamLine(line string) models.StreamEvent {
line = strings.TrimSpace(line)
if line == "" {
return nil
}
var data map[string]any
if err := json.Unmarshal([]byte(line), &data); err != nil {
return models.ErrorEvent{Error: fmt.Sprintf("malformed stream data: %v", err), IsRetryable: false}
}
// Case 1: CreateChatSessionID
if _, ok := data["chat_session_id"]; ok {
if _, hasPlacement := data["placement"]; !hasPlacement {
sid, _ := data["chat_session_id"].(string)
return models.SessionCreatedEvent{ChatSessionID: sid}
}
}
// Case 2: MessageResponseIDInfo
if _, ok := data["reserved_assistant_message_id"]; ok {
reservedID := jsonInt(data["reserved_assistant_message_id"])
var userMsgID *int
if v, ok := data["user_message_id"]; ok && v != nil {
id := jsonInt(v)
userMsgID = &id
}
return models.MessageIDEvent{
UserMessageID: userMsgID,
ReservedAgentMessageID: reservedID,
}
}
// Case 3: StreamingError (top-level error without placement)
if _, ok := data["error"]; ok {
if _, hasPlacement := data["placement"]; !hasPlacement {
errStr, _ := data["error"].(string)
var stackTrace *string
if st, ok := data["stack_trace"].(string); ok {
stackTrace = &st
}
isRetryable := true
if v, ok := data["is_retryable"].(bool); ok {
isRetryable = v
}
return models.ErrorEvent{
Error: errStr,
StackTrace: stackTrace,
IsRetryable: isRetryable,
}
}
}
// Case 4: Packet with placement + obj
if rawPlacement, ok := data["placement"]; ok {
if rawObj, ok := data["obj"]; ok {
placement := parsePlacement(rawPlacement)
obj, _ := rawObj.(map[string]any)
if obj == nil {
return models.UnknownEvent{Placement: placement, RawData: data}
}
return parsePacketObj(obj, placement)
}
}
// Fallback
return models.UnknownEvent{RawData: data}
}
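The shape checks above discriminate line types by which keys are present: session-creation and top-level error lines lack `placement`, while packets carry both `placement` and `obj`. A trimmed sketch of that classification (classifyLine is illustrative, not the CLI's code):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// classifyLine mirrors the key-presence checks in ParseStreamLine.
func classifyLine(line string) string {
	var data map[string]any
	if err := json.Unmarshal([]byte(line), &data); err != nil {
		return "error"
	}
	_, hasPlacement := data["placement"]
	switch {
	case !hasPlacement && data["chat_session_id"] != nil:
		return "session_created"
	case !hasPlacement && data["error"] != nil:
		return "error"
	case hasPlacement && data["obj"] != nil:
		return "packet"
	default:
		return "unknown"
	}
}

func main() {
	fmt.Println(classifyLine(`{"chat_session_id":"abc"}`))                            // session_created
	fmt.Println(classifyLine(`{"placement":{"turn_index":0},"obj":{"type":"stop"}}`)) // packet
	fmt.Println(classifyLine(`not json`))                                             // error
}
```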
func parsePlacement(raw interface{}) *models.Placement {
m, ok := raw.(map[string]any)
if !ok {
return nil
}
p := &models.Placement{
TurnIndex: jsonInt(m["turn_index"]),
TabIndex: jsonInt(m["tab_index"]),
}
if v, ok := m["sub_turn_index"]; ok && v != nil {
st := jsonInt(v)
p.SubTurnIndex = &st
}
return p
}
func parsePacketObj(obj map[string]any, placement *models.Placement) models.StreamEvent {
objType, _ := obj["type"].(string)
switch objType {
case "stop":
var reason *string
if r, ok := obj["stop_reason"].(string); ok {
reason = &r
}
return models.StopEvent{Placement: placement, StopReason: reason}
case "error":
errMsg := "Unknown error"
if e, ok := obj["exception"]; ok {
errMsg = toString(e)
}
return models.ErrorEvent{Placement: placement, Error: errMsg, IsRetryable: true}
case "message_start":
var docs []models.SearchDoc
if rawDocs, ok := obj["final_documents"].([]any); ok {
docs = parseSearchDocs(rawDocs)
}
return models.MessageStartEvent{Placement: placement, Documents: docs}
case "message_delta":
content, _ := obj["content"].(string)
return models.MessageDeltaEvent{Placement: placement, Content: content}
case "search_tool_start":
isInternet, _ := obj["is_internet_search"].(bool)
return models.SearchStartEvent{Placement: placement, IsInternetSearch: isInternet}
case "search_tool_queries_delta":
var queries []string
if raw, ok := obj["queries"].([]any); ok {
for _, q := range raw {
if s, ok := q.(string); ok {
queries = append(queries, s)
}
}
}
return models.SearchQueriesEvent{Placement: placement, Queries: queries}
case "search_tool_documents_delta":
var docs []models.SearchDoc
if rawDocs, ok := obj["documents"].([]any); ok {
docs = parseSearchDocs(rawDocs)
}
return models.SearchDocumentsEvent{Placement: placement, Documents: docs}
case "reasoning_start":
return models.ReasoningStartEvent{Placement: placement}
case "reasoning_delta":
reasoning, _ := obj["reasoning"].(string)
return models.ReasoningDeltaEvent{Placement: placement, Reasoning: reasoning}
case "reasoning_done":
return models.ReasoningDoneEvent{Placement: placement}
case "citation_info":
return models.CitationEvent{
Placement: placement,
CitationNumber: jsonInt(obj["citation_number"]),
DocumentID: jsonString(obj["document_id"]),
}
case "open_url_start", "image_generation_start", "python_tool_start", "file_reader_start":
toolName := strings.ReplaceAll(strings.TrimSuffix(objType, "_start"), "_", " ")
toolName = cases.Title(language.English).String(toolName)
return models.ToolStartEvent{Placement: placement, Type: objType, ToolName: toolName}
case "custom_tool_start":
toolName := jsonString(obj["tool_name"])
if toolName == "" {
toolName = "Custom Tool"
}
return models.ToolStartEvent{Placement: placement, Type: models.EventCustomToolStart, ToolName: toolName}
case "deep_research_plan_start":
return models.DeepResearchPlanStartEvent{Placement: placement}
case "deep_research_plan_delta":
content, _ := obj["content"].(string)
return models.DeepResearchPlanDeltaEvent{Placement: placement, Content: content}
case "research_agent_start":
task, _ := obj["research_task"].(string)
return models.ResearchAgentStartEvent{Placement: placement, ResearchTask: task}
case "intermediate_report_start":
return models.IntermediateReportStartEvent{Placement: placement}
case "intermediate_report_delta":
content, _ := obj["content"].(string)
return models.IntermediateReportDeltaEvent{Placement: placement, Content: content}
default:
return models.UnknownEvent{Placement: placement, RawData: obj}
}
}
func parseSearchDocs(raw []any) []models.SearchDoc {
var docs []models.SearchDoc
for _, item := range raw {
m, ok := item.(map[string]any)
if !ok {
continue
}
doc := models.SearchDoc{
DocumentID: jsonString(m["document_id"]),
SemanticIdentifier: jsonString(m["semantic_identifier"]),
SourceType: jsonString(m["source_type"]),
}
if link, ok := m["link"].(string); ok {
doc.Link = &link
}
docs = append(docs, doc)
}
return docs
}
func jsonInt(v any) int {
switch n := v.(type) {
case float64:
return int(n)
case int:
return n
default:
return 0
}
}
func jsonString(v any) string {
s, _ := v.(string)
return s
}
func toString(v any) string {
switch s := v.(type) {
case string:
return s
default:
b, _ := json.Marshal(v)
return string(b)
}
}
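The `jsonInt` helper above exists because `encoding/json` decodes every JSON number into an `interface{}` as `float64`, never `int`, so a plain `.(int)` assertion on decoded stream data would always fail. A minimal standalone sketch (the field name is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// jsonInt mirrors the parser helper: accept float64 (the type
// encoding/json actually produces), int (for values built in Go),
// and fall back to 0 for anything else.
func jsonInt(v any) int {
	switch n := v.(type) {
	case float64:
		return int(n)
	case int:
		return n
	default:
		return 0
	}
}

func main() {
	var data map[string]any
	_ = json.Unmarshal([]byte(`{"turn_index": 3}`), &data)
	_, isInt := data["turn_index"].(int) // false: decoded as float64
	fmt.Println(isInt, jsonInt(data["turn_index"]))
}
```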


@@ -0,0 +1,419 @@
package parser
import (
"encoding/json"
"testing"
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
func TestEmptyLineReturnsNil(t *testing.T) {
for _, line := range []string{"", " ", "\n"} {
if ParseStreamLine(line) != nil {
t.Errorf("expected nil for %q", line)
}
}
}
func TestInvalidJSONReturnsErrorEvent(t *testing.T) {
for _, line := range []string{"not json", "{broken"} {
event := ParseStreamLine(line)
if event == nil {
t.Errorf("expected ErrorEvent for %q, got nil", line)
continue
}
if _, ok := event.(models.ErrorEvent); !ok {
t.Errorf("expected ErrorEvent for %q, got %T", line, event)
}
}
}
func TestSessionCreated(t *testing.T) {
line := mustJSON(map[string]interface{}{
"chat_session_id": "550e8400-e29b-41d4-a716-446655440000",
})
event := ParseStreamLine(line)
e, ok := event.(models.SessionCreatedEvent)
if !ok {
t.Fatalf("expected SessionCreatedEvent, got %T", event)
}
if e.ChatSessionID != "550e8400-e29b-41d4-a716-446655440000" {
t.Errorf("got %s", e.ChatSessionID)
}
}
func TestMessageIDInfo(t *testing.T) {
line := mustJSON(map[string]interface{}{
"user_message_id": 1,
"reserved_assistant_message_id": 2,
})
event := ParseStreamLine(line)
e, ok := event.(models.MessageIDEvent)
if !ok {
t.Fatalf("expected MessageIDEvent, got %T", event)
}
if e.UserMessageID == nil || *e.UserMessageID != 1 {
t.Errorf("expected user_message_id=1")
}
if e.ReservedAgentMessageID != 2 {
t.Errorf("got %d", e.ReservedAgentMessageID)
}
}
func TestMessageIDInfoNullUserID(t *testing.T) {
line := mustJSON(map[string]interface{}{
"user_message_id": nil,
"reserved_assistant_message_id": 5,
})
event := ParseStreamLine(line)
e, ok := event.(models.MessageIDEvent)
if !ok {
t.Fatalf("expected MessageIDEvent, got %T", event)
}
if e.UserMessageID != nil {
t.Error("expected nil user_message_id")
}
if e.ReservedAgentMessageID != 5 {
t.Errorf("got %d", e.ReservedAgentMessageID)
}
}
func TestTopLevelError(t *testing.T) {
line := mustJSON(map[string]interface{}{
"error": "Rate limit exceeded",
"stack_trace": "...",
"is_retryable": true,
})
event := ParseStreamLine(line)
e, ok := event.(models.ErrorEvent)
if !ok {
t.Fatalf("expected ErrorEvent, got %T", event)
}
if e.Error != "Rate limit exceeded" {
t.Errorf("got %s", e.Error)
}
if e.StackTrace == nil || *e.StackTrace != "..." {
t.Error("expected stack_trace")
}
if !e.IsRetryable {
t.Error("expected retryable")
}
}
func TestTopLevelErrorMinimal(t *testing.T) {
line := mustJSON(map[string]interface{}{
"error": "Something broke",
})
event := ParseStreamLine(line)
e, ok := event.(models.ErrorEvent)
if !ok {
t.Fatalf("expected ErrorEvent, got %T", event)
}
if e.Error != "Something broke" {
t.Errorf("got %s", e.Error)
}
if !e.IsRetryable {
t.Error("expected default retryable=true")
}
}
func makePacket(obj map[string]interface{}, turnIndex, tabIndex int) string {
return mustJSON(map[string]interface{}{
"placement": map[string]interface{}{"turn_index": turnIndex, "tab_index": tabIndex},
"obj": obj,
})
}
func TestStopPacket(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "stop", "stop_reason": "completed"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.StopEvent)
if !ok {
t.Fatalf("expected StopEvent, got %T", event)
}
if e.StopReason == nil || *e.StopReason != "completed" {
t.Error("expected stop_reason=completed")
}
if e.Placement == nil || e.Placement.TurnIndex != 0 {
t.Error("expected placement")
}
}
func TestStopPacketNoReason(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "stop"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.StopEvent)
if !ok {
t.Fatalf("expected StopEvent, got %T", event)
}
if e.StopReason != nil {
t.Error("expected nil stop_reason")
}
}
func TestMessageStart(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "message_start"}, 0, 0)
event := ParseStreamLine(line)
_, ok := event.(models.MessageStartEvent)
if !ok {
t.Fatalf("expected MessageStartEvent, got %T", event)
}
}
func TestMessageStartWithDocuments(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "message_start",
"final_documents": []interface{}{
map[string]interface{}{"document_id": "doc1", "semantic_identifier": "Doc 1"},
},
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.MessageStartEvent)
if !ok {
t.Fatalf("expected MessageStartEvent, got %T", event)
}
if len(e.Documents) != 1 || e.Documents[0].DocumentID != "doc1" {
t.Error("expected 1 document with id doc1")
}
}
func TestMessageDelta(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "message_delta", "content": "Hello"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.MessageDeltaEvent)
if !ok {
t.Fatalf("expected MessageDeltaEvent, got %T", event)
}
if e.Content != "Hello" {
t.Errorf("got %s", e.Content)
}
}
func TestMessageDeltaEmpty(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "message_delta", "content": ""}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.MessageDeltaEvent)
if !ok {
t.Fatalf("expected MessageDeltaEvent, got %T", event)
}
if e.Content != "" {
t.Errorf("expected empty, got %s", e.Content)
}
}
func TestSearchToolStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "search_tool_start", "is_internet_search": true,
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.SearchStartEvent)
if !ok {
t.Fatalf("expected SearchStartEvent, got %T", event)
}
if !e.IsInternetSearch {
t.Error("expected internet search")
}
}
func TestSearchToolQueries(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "search_tool_queries_delta",
"queries": []interface{}{"query 1", "query 2"},
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.SearchQueriesEvent)
if !ok {
t.Fatalf("expected SearchQueriesEvent, got %T", event)
}
if len(e.Queries) != 2 || e.Queries[0] != "query 1" {
t.Error("unexpected queries")
}
}
func TestSearchToolDocuments(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "search_tool_documents_delta",
"documents": []interface{}{
map[string]interface{}{"document_id": "d1", "semantic_identifier": "First Doc", "link": "http://example.com"},
map[string]interface{}{"document_id": "d2", "semantic_identifier": "Second Doc"},
},
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.SearchDocumentsEvent)
if !ok {
t.Fatalf("expected SearchDocumentsEvent, got %T", event)
}
if len(e.Documents) != 2 {
t.Errorf("expected 2 docs, got %d", len(e.Documents))
}
if e.Documents[0].Link == nil || *e.Documents[0].Link != "http://example.com" {
t.Error("expected link on first doc")
}
}
func TestReasoningStart(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "reasoning_start"}, 0, 0)
event := ParseStreamLine(line)
if _, ok := event.(models.ReasoningStartEvent); !ok {
t.Fatalf("expected ReasoningStartEvent, got %T", event)
}
}
func TestReasoningDelta(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "reasoning_delta", "reasoning": "Let me think...",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ReasoningDeltaEvent)
if !ok {
t.Fatalf("expected ReasoningDeltaEvent, got %T", event)
}
if e.Reasoning != "Let me think..." {
t.Errorf("got %s", e.Reasoning)
}
}
func TestReasoningDone(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "reasoning_done"}, 0, 0)
event := ParseStreamLine(line)
if _, ok := event.(models.ReasoningDoneEvent); !ok {
t.Fatalf("expected ReasoningDoneEvent, got %T", event)
}
}
func TestCitationInfo(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "citation_info", "citation_number": 1, "document_id": "doc_abc",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.CitationEvent)
if !ok {
t.Fatalf("expected CitationEvent, got %T", event)
}
if e.CitationNumber != 1 || e.DocumentID != "doc_abc" {
t.Errorf("got %d, %s", e.CitationNumber, e.DocumentID)
}
}
func TestOpenURLStart(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "open_url_start"}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ToolStartEvent)
if !ok {
t.Fatalf("expected ToolStartEvent, got %T", event)
}
if e.Type != "open_url_start" {
t.Errorf("got type %s", e.Type)
}
}
func TestPythonToolStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "python_tool_start", "code": "print('hi')",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ToolStartEvent)
if !ok {
t.Fatalf("expected ToolStartEvent, got %T", event)
}
if e.ToolName != "Python Tool" {
t.Errorf("got %s", e.ToolName)
}
}
func TestCustomToolStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "custom_tool_start", "tool_name": "MyTool",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ToolStartEvent)
if !ok {
t.Fatalf("expected ToolStartEvent, got %T", event)
}
if e.ToolName != "MyTool" {
t.Errorf("got %s", e.ToolName)
}
}
func TestDeepResearchPlanDelta(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "deep_research_plan_delta", "content": "Step 1: ...",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.DeepResearchPlanDeltaEvent)
if !ok {
t.Fatalf("expected DeepResearchPlanDeltaEvent, got %T", event)
}
if e.Content != "Step 1: ..." {
t.Errorf("got %s", e.Content)
}
}
func TestResearchAgentStart(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "research_agent_start", "research_task": "Find info about X",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.ResearchAgentStartEvent)
if !ok {
t.Fatalf("expected ResearchAgentStartEvent, got %T", event)
}
if e.ResearchTask != "Find info about X" {
t.Errorf("got %s", e.ResearchTask)
}
}
func TestIntermediateReportDelta(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "intermediate_report_delta", "content": "Report text",
}, 0, 0)
event := ParseStreamLine(line)
e, ok := event.(models.IntermediateReportDeltaEvent)
if !ok {
t.Fatalf("expected IntermediateReportDeltaEvent, got %T", event)
}
if e.Content != "Report text" {
t.Errorf("got %s", e.Content)
}
}
func TestUnknownPacketType(t *testing.T) {
line := makePacket(map[string]interface{}{"type": "section_end"}, 0, 0)
event := ParseStreamLine(line)
if _, ok := event.(models.UnknownEvent); !ok {
t.Fatalf("expected UnknownEvent, got %T", event)
}
}
func TestUnknownTopLevel(t *testing.T) {
line := mustJSON(map[string]interface{}{"some_unknown_field": "value"})
event := ParseStreamLine(line)
if _, ok := event.(models.UnknownEvent); !ok {
t.Fatalf("expected UnknownEvent, got %T", event)
}
}
func TestPlacementPreserved(t *testing.T) {
line := makePacket(map[string]interface{}{
"type": "message_delta", "content": "x",
}, 3, 1)
event := ParseStreamLine(line)
e, ok := event.(models.MessageDeltaEvent)
if !ok {
t.Fatalf("expected MessageDeltaEvent, got %T", event)
}
if e.Placement == nil {
t.Fatal("expected placement")
}
if e.Placement.TurnIndex != 3 || e.Placement.TabIndex != 1 {
t.Errorf("got turn=%d tab=%d", e.Placement.TurnIndex, e.Placement.TabIndex)
}
}
func mustJSON(v interface{}) string {
b, err := json.Marshal(v)
if err != nil {
panic(err)
}
return string(b)
}
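The tests above exercise the four top-level line shapes the parser distinguishes. A self-contained sketch of that dispatch, using only the field names visible in the tests (the ordering of the first two cases is an assumption, since they sit outside this diff):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// classify mimics ParseStreamLine's top-level dispatch on
// illustrative payloads; it returns the event kind by name.
func classify(line string) string {
	var data map[string]any
	if err := json.Unmarshal([]byte(line), &data); err != nil {
		return "ErrorEvent" // invalid JSON
	}
	switch {
	case data["chat_session_id"] != nil:
		return "SessionCreatedEvent"
	case data["reserved_assistant_message_id"] != nil:
		return "MessageIDEvent"
	case data["error"] != nil && data["placement"] == nil:
		return "ErrorEvent" // top-level error, no placement
	case data["placement"] != nil && data["obj"] != nil:
		return "Packet" // dispatched further on obj["type"]
	default:
		return "UnknownEvent"
	}
}

func main() {
	for _, l := range []string{
		`{"chat_session_id":"abc"}`,
		`{"error":"boom"}`,
		`{"placement":{"turn_index":0,"tab_index":0},"obj":{"type":"stop"}}`,
	} {
		fmt.Println(classify(l))
	}
}
```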

cli/internal/tui/app.go Normal file

@@ -0,0 +1,630 @@
// Package tui implements the Bubble Tea TUI for Onyx CLI.
package tui
import (
"context"
"fmt"
"strconv"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
// Model is the root Bubble Tea model.
type Model struct {
config config.OnyxCliConfig
client *api.Client
viewport *viewport
input inputModel
status statusBar
width int
height int
// Chat state
chatSessionID *string
agentID int
agentName string
agents []models.AgentSummary
parentMessageID *int
isStreaming bool
streamCancel context.CancelFunc
streamCh <-chan models.StreamEvent
citations map[int]string
attachedFiles []models.FileDescriptorPayload
needsRename bool
agentStarted bool
// Quit state
quitPending bool
splashShown bool
initInputReady bool // true once terminal init responses have passed
}
// NewModel creates a new TUI model.
func NewModel(cfg config.OnyxCliConfig) Model {
client := api.NewClient(cfg)
parentID := -1
return Model{
config: cfg,
client: client,
viewport: newViewport(80),
input: newInputModel(),
status: newStatusBar(),
agentID: cfg.DefaultAgentID,
agentName: "Default",
parentMessageID: &parentID,
citations: make(map[int]string),
}
}
// Init initializes the model.
func (m Model) Init() tea.Cmd {
return loadAgentsCmd(m.client)
}
// Update handles messages.
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Filter out terminal query responses (OSC 11 background color, cursor
// position reports, etc.) that arrive as key events with raw escape content.
// These arrive split across multiple key events, so we use a brief window
// after startup to swallow them all.
if _, ok := msg.(tea.KeyMsg); ok && !m.initInputReady {
// During init, drop all key events; they're terminal query responses.
return m, nil
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.viewport.setWidth(msg.Width)
m.status.setWidth(msg.Width)
m.input.textInput.Width = msg.Width - 4
if !m.splashShown {
m.splashShown = true
// bottomHeight = sep + input + sep + status = 4 (approx)
viewportHeight := msg.Height - 4
if viewportHeight < 1 {
viewportHeight = msg.Height
}
m.viewport.addSplash(viewportHeight)
// Delay input focus to let terminal query responses flush
return m, tea.Tick(100*time.Millisecond, func(time.Time) tea.Msg {
return inputReadyMsg{}
})
}
return m, nil
case tea.MouseMsg:
switch msg.Button {
case tea.MouseButtonWheelUp:
m.viewport.scrollUp(3)
return m, nil
case tea.MouseButtonWheelDown:
m.viewport.scrollDown(3)
return m, nil
}
case tea.KeyMsg:
return m.handleKey(msg)
case submitMsg:
return m.handleSubmit(msg.text)
case fileDropMsg:
return m.handleFileDrop(msg.path)
case InitDoneMsg:
return m.handleInitDone(msg)
case api.StreamEventMsg:
return m.handleStreamEvent(msg)
case api.StreamDoneMsg:
return m.handleStreamDone(msg)
case AgentsLoadedMsg:
return m.handleAgentsLoaded(msg)
case SessionsLoadedMsg:
return m.handleSessionsLoaded(msg)
case SessionResumedMsg:
return m.handleSessionResumed(msg)
case FileUploadedMsg:
return m.handleFileUploaded(msg)
case inputReadyMsg:
m.initInputReady = true
m.input.textInput.Focus()
m.input.textInput.SetValue("")
return m, m.input.textInput.Cursor.BlinkCmd()
case resetQuitMsg:
m.quitPending = false
return m, nil
}
// Only forward messages to the text input after it's been focused
if m.splashShown {
var cmd tea.Cmd
m.input, cmd = m.input.update(msg)
return m, cmd
}
return m, nil
}
// viewportHeight returns the number of visible chat rows, accounting for the
// dynamic bottom area (separator, menu, file badges, input, status bar).
func (m Model) viewportHeight() int {
menuHeight := 0
if m.input.menuVisible {
menuHeight = len(m.input.menuItems)
}
fileHeight := 0
if len(m.input.attachedFiles) > 0 {
fileHeight = 1
}
h := m.height - (1 + menuHeight + fileHeight + 1 + 1 + 1)
if h < 1 {
return 1
}
return h
}
// View renders the UI.
func (m Model) View() string {
if m.width == 0 || m.height == 0 {
return ""
}
separator := lipgloss.NewStyle().Foreground(separatorColor).Render(
strings.Repeat("─", m.width),
)
menuView := m.input.viewMenu(m.width)
viewportHeight := m.viewportHeight()
var parts []string
parts = append(parts, m.viewport.view(viewportHeight))
parts = append(parts, separator)
if menuView != "" {
parts = append(parts, menuView)
}
parts = append(parts, m.input.viewInput())
parts = append(parts, separator)
parts = append(parts, m.status.view())
return strings.Join(parts, "\n")
}
// handleKey processes keyboard input.
func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.Type {
case tea.KeyEscape:
// Cancel streaming or close menu
if m.input.menuVisible {
m.input.menuVisible = false
return m, nil
}
if m.isStreaming {
return m.cancelStream()
}
// Dismiss picker
if m.viewport.pickerActive {
m.viewport.pickerActive = false
return m, nil
}
return m, nil
case tea.KeyCtrlD:
// If streaming, cancel first; require a fresh Ctrl+D pair to quit
if m.isStreaming {
return m.cancelStream()
}
if m.quitPending {
return m, tea.Quit
}
m.quitPending = true
m.viewport.addInfo("Press Ctrl+D again to quit.")
return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
return resetQuitMsg{}
})
case tea.KeyCtrlO:
m.viewport.showSources = !m.viewport.showSources
return m, nil
case tea.KeyEnter:
// If picker is active, handle selection
if m.viewport.pickerActive && len(m.viewport.pickerItems) > 0 {
item := m.viewport.pickerItems[m.viewport.pickerIndex]
m.viewport.pickerActive = false
switch m.viewport.pickerType {
case pickerSession:
return cmdResume(m, item.id)
case pickerAgent:
return cmdSelectAgent(m, item.id)
}
}
case tea.KeyUp:
if m.viewport.pickerActive {
if m.viewport.pickerIndex > 0 {
m.viewport.pickerIndex--
}
return m, nil
}
case tea.KeyDown:
if m.viewport.pickerActive {
if m.viewport.pickerIndex < len(m.viewport.pickerItems)-1 {
m.viewport.pickerIndex++
}
return m, nil
}
case tea.KeyPgUp:
m.viewport.scrollUp(m.viewportHeight() / 2)
return m, nil
case tea.KeyPgDown:
m.viewport.scrollDown(m.viewportHeight() / 2)
return m, nil
case tea.KeyShiftUp:
m.viewport.scrollUp(3)
return m, nil
case tea.KeyShiftDown:
m.viewport.scrollDown(3)
return m, nil
}
// Pass to input
var cmd tea.Cmd
m.input, cmd = m.input.update(msg)
return m, cmd
}
func (m Model) handleSubmit(text string) (tea.Model, tea.Cmd) {
if strings.HasPrefix(text, "/") {
return handleSlashCommand(m, text)
}
return m.sendMessage(text)
}
func (m Model) handleFileDrop(path string) (tea.Model, tea.Cmd) {
return cmdAttach(m, path)
}
func (m Model) cancelStream() (Model, tea.Cmd) {
if m.streamCancel != nil {
m.streamCancel()
}
if m.chatSessionID != nil {
sid := *m.chatSessionID
// Best-effort stop request; the error is deliberately discarded.
go func() { _ = m.client.StopChatSession(sid) }()
}
m, cmd := m.finishStream(nil)
m.viewport.addInfo("Generation stopped.")
return m, cmd
}
func (m Model) sendMessage(message string) (Model, tea.Cmd) {
if m.isStreaming {
return m, nil
}
m.viewport.addUserMessage(message)
m.viewport.startAgent()
// Prepare file descriptors
fileDescs := make([]models.FileDescriptorPayload, len(m.attachedFiles))
copy(fileDescs, m.attachedFiles)
m.attachedFiles = nil
m.input.clearFiles()
m.isStreaming = true
m.agentStarted = false
m.citations = make(map[int]string)
m.status.setStreaming(true)
ctx, cancel := context.WithCancel(context.Background())
m.streamCancel = cancel
ch := m.client.SendMessageStream(
ctx,
message,
m.chatSessionID,
m.agentID,
m.parentMessageID,
fileDescs,
)
m.streamCh = ch
return m, api.WaitForStreamEvent(ch)
}
func (m Model) handleStreamEvent(msg api.StreamEventMsg) (tea.Model, tea.Cmd) {
// Ignore stale events after cancellation
if !m.isStreaming {
return m, nil
}
if msg.Event == nil {
return m, api.WaitForStreamEvent(m.streamCh)
}
switch e := msg.Event.(type) {
case models.SessionCreatedEvent:
m.chatSessionID = &e.ChatSessionID
m.needsRename = true
m.status.setSession(e.ChatSessionID)
case models.MessageIDEvent:
m.parentMessageID = &e.ReservedAgentMessageID
case models.MessageStartEvent:
m.agentStarted = true
case models.MessageDeltaEvent:
m.agentStarted = true
m.viewport.appendToken(e.Content)
case models.SearchStartEvent:
if e.IsInternetSearch {
m.viewport.addInfo("Web search…")
} else {
m.viewport.addInfo("Searching…")
}
case models.SearchQueriesEvent:
if len(e.Queries) > 0 {
queries := e.Queries
if len(queries) > 3 {
queries = queries[:3]
}
parts := make([]string, len(queries))
for i, q := range queries {
parts[i] = "\"" + q + "\""
}
m.viewport.addInfo("Searching: " + strings.Join(parts, ", "))
}
case models.SearchDocumentsEvent:
count := len(e.Documents)
suffix := "s"
if count == 1 {
suffix = ""
}
m.viewport.addInfo("Found " + strconv.Itoa(count) + " document" + suffix)
case models.ReasoningStartEvent:
m.viewport.addInfo("Thinking…")
case models.ReasoningDeltaEvent:
// We don't display reasoning text, just the indicator
case models.ReasoningDoneEvent:
// No-op
case models.CitationEvent:
m.citations[e.CitationNumber] = e.DocumentID
case models.ToolStartEvent:
m.viewport.addInfo("Using " + e.ToolName + "…")
case models.ResearchAgentStartEvent:
m.viewport.addInfo("Researching: " + e.ResearchTask)
case models.DeepResearchPlanDeltaEvent:
m.viewport.appendToken(e.Content)
case models.IntermediateReportDeltaEvent:
m.viewport.appendToken(e.Content)
case models.StopEvent:
return m.finishStream(nil)
case models.ErrorEvent:
m.viewport.addError(e.Error)
return m.finishStream(nil)
}
return m, api.WaitForStreamEvent(m.streamCh)
}
func (m Model) handleStreamDone(msg api.StreamDoneMsg) (tea.Model, tea.Cmd) {
// Ignore if already cancelled
if !m.isStreaming {
return m, nil
}
return m.finishStream(msg.Err)
}
func (m Model) finishStream(err error) (Model, tea.Cmd) {
if m.agentStarted {
m.viewport.finishAgent()
if len(m.citations) > 0 {
m.viewport.addCitations(m.citations)
}
}
m.isStreaming = false
m.agentStarted = false
m.status.setStreaming(false)
m.streamCancel = nil
m.streamCh = nil
// Auto-rename new sessions
if m.needsRename && m.chatSessionID != nil {
m.needsRename = false
sessionID := *m.chatSessionID
client := m.client
go func() {
_, _ = client.RenameChatSession(sessionID, nil)
}()
}
return m, nil
}
func (m Model) handleInitDone(msg InitDoneMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addWarning("Could not load agents. Using default.")
} else {
m.agents = msg.Agents
for _, p := range m.agents {
if p.ID == m.agentID {
m.agentName = p.Name
break
}
}
}
m.status.setServer(m.config.ServerURL)
m.status.setAgent(m.agentName)
return m, nil
}
func (m Model) handleAgentsLoaded(msg AgentsLoadedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Could not load agents: " + msg.Err.Error())
return m, nil
}
m.agents = msg.Agents
if len(m.agents) == 0 {
m.viewport.addInfo("No agents available.")
return m, nil
}
m.viewport.addInfo("Select an agent (Enter to select, Esc to cancel):")
var items []pickerItem
for _, p := range m.agents {
label := fmt.Sprintf("%d: %s", p.ID, p.Name)
if p.ID == m.agentID {
label += " *"
}
desc := p.Description
if len(desc) > 50 {
desc = desc[:50] + "..."
}
if desc != "" {
label += " - " + desc
}
items = append(items, pickerItem{
id: strconv.Itoa(p.ID),
label: label,
})
}
m.viewport.showPicker(pickerAgent, items)
return m, nil
}
func (m Model) handleSessionsLoaded(msg SessionsLoadedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Could not load sessions: " + msg.Err.Error())
return m, nil
}
if len(msg.Sessions) == 0 {
m.viewport.addInfo("No previous sessions found.")
return m, nil
}
m.viewport.addInfo("Select a session to resume (Enter to select, Esc to cancel):")
var items []pickerItem
for i, s := range msg.Sessions {
if i >= 15 {
break
}
name := "Untitled"
if s.Name != nil && *s.Name != "" {
name = *s.Name
}
sid := s.ID
if len(sid) > 8 {
sid = sid[:8]
}
items = append(items, pickerItem{
id: s.ID,
label: sid + " " + name + " (" + s.Created + ")",
})
}
m.viewport.showPicker(pickerSession, items)
return m, nil
}
func (m Model) handleSessionResumed(msg SessionResumedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Could not load session: " + msg.Err.Error())
return m, nil
}
// Cancel any in-progress stream before replacing the session
if m.isStreaming {
m, _ = m.cancelStream()
}
detail := msg.Detail
m.chatSessionID = &detail.ChatSessionID
m.viewport.clearDisplay()
m.status.setSession(detail.ChatSessionID)
if detail.AgentName != nil {
m.agentName = *detail.AgentName
m.status.setAgent(*detail.AgentName)
}
if detail.AgentID != nil {
m.agentID = *detail.AgentID
}
// Replay messages
for _, chatMsg := range detail.Messages {
switch chatMsg.MessageType {
case "user":
m.viewport.addUserMessage(chatMsg.Message)
case "assistant":
m.viewport.startAgent()
m.viewport.appendToken(chatMsg.Message)
m.viewport.finishAgent()
}
}
// Set parent to last message
if len(detail.Messages) > 0 {
lastID := detail.Messages[len(detail.Messages)-1].MessageID
m.parentMessageID = &lastID
}
desc := "Untitled"
if detail.Description != nil && *detail.Description != "" {
desc = *detail.Description
}
m.viewport.addInfo("Resumed session: " + desc)
return m, nil
}
func (m Model) handleFileUploaded(msg FileUploadedMsg) (tea.Model, tea.Cmd) {
if msg.Err != nil {
m.viewport.addError("Upload failed: " + msg.Err.Error())
return m, nil
}
m.attachedFiles = append(m.attachedFiles, *msg.Descriptor)
m.input.addFile(msg.FileName)
m.viewport.addInfo("Attached: " + msg.FileName)
return m, nil
}
type inputReadyMsg struct{}
type resetQuitMsg struct{}
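`Model` above follows Bubble Tea's value-receiver convention: `Update` receives a copy and returns the new state, so every caller must reassign the result. A dependency-free sketch of why that matters (types here are made up, not part of the CLI):

```go
package main

import "fmt"

// model imitates the value-semantics update loop: update mutates
// only its local copy, so the caller sees changes only if it
// reassigns the returned value.
type model struct{ quitPending bool }

type msg interface{}
type ctrlD struct{}

func (m model) update(message msg) model {
	if _, ok := message.(ctrlD); ok {
		m.quitPending = true // mutates the copy only
	}
	return m
}

func main() {
	m := model{}
	m.update(ctrlD{}) // result discarded: state unchanged
	fmt.Println(m.quitPending)
	m = m.update(ctrlD{}) // reassigned: state updated
	fmt.Println(m.quitPending)
}
```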


@@ -0,0 +1,202 @@
package tui
import (
"fmt"
"strconv"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/onyx-dot-app/onyx/cli/internal/util"
)
// handleSlashCommand dispatches slash commands and returns updated model + cmd.
func handleSlashCommand(m Model, text string) (Model, tea.Cmd) {
parts := strings.SplitN(text, " ", 2)
command := strings.ToLower(parts[0])
arg := ""
if len(parts) > 1 {
arg = parts[1]
}
switch command {
case "/help":
m.viewport.addInfo(helpText)
return m, nil
case "/new":
return cmdNew(m)
case "/agent":
if arg != "" {
return cmdSelectAgent(m, arg)
}
return cmdShowAgents(m)
case "/attach":
return cmdAttach(m, arg)
case "/sessions", "/resume":
if trimmed := strings.TrimSpace(arg); trimmed != "" {
return cmdResume(m, trimmed)
}
}
return cmdSessions(m)
case "/configure":
m.viewport.addInfo("Run 'onyx-cli configure' to change connection settings.")
return m, nil
case "/clear":
return cmdNew(m)
case "/connectors":
url := m.config.ServerURL + "/admin/indexing/status"
if util.OpenBrowser(url) {
m.viewport.addInfo("Opened " + url + " in browser")
} else {
m.viewport.addWarning("Failed to open browser. Visit: " + url)
}
return m, nil
case "/settings":
url := m.config.ServerURL + "/app/settings/general"
if util.OpenBrowser(url) {
m.viewport.addInfo("Opened " + url + " in browser")
} else {
m.viewport.addWarning("Failed to open browser. Visit: " + url)
}
return m, nil
case "/quit":
return m, tea.Quit
default:
m.viewport.addWarning(fmt.Sprintf("Unknown command: %s. Type /help for available commands.", command))
return m, nil
}
}
func cmdNew(m Model) (Model, tea.Cmd) {
m.chatSessionID = nil
parentID := -1
m.parentMessageID = &parentID
m.needsRename = false
m.citations = nil
m.viewport.clearAll()
// Re-add splash as a scrollable entry
viewportHeight := m.viewportHeight()
if viewportHeight < 1 {
viewportHeight = m.height
}
m.viewport.addSplash(viewportHeight)
m.status.setSession("")
return m, nil
}
func cmdShowAgents(m Model) (Model, tea.Cmd) {
m.viewport.addInfo("Loading agents...")
client := m.client
return m, func() tea.Msg {
agents, err := client.ListAgents()
return AgentsLoadedMsg{Agents: agents, Err: err}
}
}
func cmdSelectAgent(m Model, idStr string) (Model, tea.Cmd) {
pid, err := strconv.Atoi(strings.TrimSpace(idStr))
if err != nil {
m.viewport.addWarning("Invalid agent ID. Use a number.")
return m, nil
}
var target *models.AgentSummary
for i := range m.agents {
if m.agents[i].ID == pid {
target = &m.agents[i]
break
}
}
if target == nil {
m.viewport.addWarning(fmt.Sprintf("Agent %d not found. Use /agent to see available agents.", pid))
return m, nil
}
m.agentID = target.ID
m.agentName = target.Name
m.status.setAgent(target.Name)
m.viewport.addInfo("Switched to agent: " + target.Name)
// Save preference
m.config.DefaultAgentID = target.ID
_ = config.Save(m.config)
return m, nil
}
func cmdAttach(m Model, pathStr string) (Model, tea.Cmd) {
if pathStr == "" {
m.viewport.addWarning("Usage: /attach <file_path>")
return m, nil
}
m.viewport.addInfo("Uploading " + pathStr + "...")
client := m.client
return m, func() tea.Msg {
fd, err := client.UploadFile(pathStr)
if err != nil {
return FileUploadedMsg{Err: err, FileName: pathStr}
}
return FileUploadedMsg{Descriptor: fd, FileName: pathStr}
}
}
func cmdSessions(m Model) (Model, tea.Cmd) {
m.viewport.addInfo("Loading sessions...")
client := m.client
return m, func() tea.Msg {
sessions, err := client.ListChatSessions()
return SessionsLoadedMsg{Sessions: sessions, Err: err}
}
}
func cmdResume(m Model, sessionIDStr string) (Model, tea.Cmd) {
client := m.client
return m, func() tea.Msg {
// Try to find session by prefix match
sessions, err := client.ListChatSessions()
if err != nil {
return SessionResumedMsg{Err: err}
}
var targetID string
for _, s := range sessions {
if strings.HasPrefix(s.ID, sessionIDStr) {
targetID = s.ID
break
}
}
if targetID == "" {
// Try as full UUID
targetID = sessionIDStr
}
detail, err := client.GetChatSession(targetID)
if err != nil {
return SessionResumedMsg{Err: fmt.Errorf("session not found: %s: %w", sessionIDStr, err)}
}
return SessionResumedMsg{Detail: detail}
}
}
// loadAgentsCmd returns a tea.Cmd that loads agents from the API.
func loadAgentsCmd(client *api.Client) tea.Cmd {
return func() tea.Msg {
agents, err := client.ListAgents()
return InitDoneMsg{Agents: agents, Err: err}
}
}

cli/internal/tui/help.go Normal file

@@ -0,0 +1,24 @@
package tui
const helpText = `Onyx CLI Commands
/help Show this help message
/new Start a new chat session
/agent List and switch agents
/attach <path> Attach a file to next message
/sessions Browse and resume previous sessions
/clear Clear the chat display
/configure Re-run connection setup
/connectors Open connectors page in browser
/settings Open Onyx settings in browser
/quit Exit Onyx CLI
Keyboard Shortcuts
Enter Send message
Escape Cancel current generation
Ctrl+O Toggle source citations
Ctrl+D Quit (press twice)
Scroll Up/Down Mouse wheel or Shift+Up/Down
Page Up/Down Scroll half page
`

cli/internal/tui/input.go Normal file

@@ -0,0 +1,242 @@
package tui
import (
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
// slashCommand defines a slash command with its description.
type slashCommand struct {
command string
description string
}
var slashCommands = []slashCommand{
{"/help", "Show help message"},
{"/new", "Start a new chat session"},
{"/agent", "List and switch agents"},
{"/attach", "Attach a file to next message"},
{"/sessions", "Browse and resume previous sessions"},
{"/clear", "Clear the chat display"},
{"/configure", "Re-run connection setup"},
{"/connectors", "Open connectors in browser"},
{"/settings", "Open settings in browser"},
{"/quit", "Exit Onyx CLI"},
}
// Commands that take arguments (filled in with trailing space on Tab/Enter).
var argCommands = map[string]bool{
"/attach": true,
}
// inputModel manages the text input and slash command menu.
type inputModel struct {
textInput textinput.Model
menuVisible bool
menuItems []slashCommand
menuIndex int
attachedFiles []string
}
func newInputModel() inputModel {
ti := textinput.New()
ti.Prompt = "" // We render our own prompt in viewInput()
ti.Placeholder = "Send a message…"
ti.CharLimit = 10000
// Don't focus here — focus after first WindowSizeMsg to avoid
// capturing terminal init escape sequences as input.
return inputModel{
textInput: ti,
}
}
func (m inputModel) update(msg tea.Msg) (inputModel, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
return m.handleKey(msg)
}
var cmd tea.Cmd
m.textInput, cmd = m.textInput.Update(msg)
m = m.updateMenu()
return m, cmd
}
func (m inputModel) handleKey(msg tea.KeyMsg) (inputModel, tea.Cmd) {
switch msg.Type {
case tea.KeyUp:
if m.menuVisible && m.menuIndex > 0 {
m.menuIndex--
return m, nil
}
case tea.KeyDown:
if m.menuVisible && m.menuIndex < len(m.menuItems)-1 {
m.menuIndex++
return m, nil
}
case tea.KeyTab:
if m.menuVisible && len(m.menuItems) > 0 {
cmd := m.menuItems[m.menuIndex].command
if argCommands[cmd] {
m.textInput.SetValue(cmd + " ")
m.textInput.SetCursor(len(cmd) + 1)
} else {
m.textInput.SetValue(cmd)
m.textInput.SetCursor(len(cmd))
}
m.menuVisible = false
return m, nil
}
case tea.KeyEnter:
if m.menuVisible && len(m.menuItems) > 0 {
cmd := m.menuItems[m.menuIndex].command
if argCommands[cmd] {
m.textInput.SetValue(cmd + " ")
m.textInput.SetCursor(len(cmd) + 1)
m.menuVisible = false
return m, nil
}
// Execute immediately
m.textInput.SetValue("")
m.menuVisible = false
return m, func() tea.Msg { return submitMsg{text: cmd} }
}
text := strings.TrimSpace(m.textInput.Value())
if text == "" {
return m, nil
}
// Check for file path (drag-and-drop)
if dropped := detectFileDrop(text); dropped != "" {
m.textInput.SetValue("")
return m, func() tea.Msg { return fileDropMsg{path: dropped} }
}
m.textInput.SetValue("")
m.menuVisible = false
return m, func() tea.Msg { return submitMsg{text: text} }
case tea.KeyEscape:
if m.menuVisible {
m.menuVisible = false
return m, nil
}
}
var cmd tea.Cmd
m.textInput, cmd = m.textInput.Update(msg)
m = m.updateMenu()
return m, cmd
}
func (m inputModel) updateMenu() inputModel {
val := strings.TrimSpace(m.textInput.Value())
if strings.HasPrefix(val, "/") && !strings.Contains(val, " ") {
needle := strings.ToLower(val)
var filtered []slashCommand
for _, sc := range slashCommands {
if strings.HasPrefix(sc.command, needle) {
filtered = append(filtered, sc)
}
}
if len(filtered) > 0 {
m.menuVisible = true
m.menuItems = filtered
if m.menuIndex >= len(filtered) {
m.menuIndex = 0
}
} else {
m.menuVisible = false
}
} else {
m.menuVisible = false
}
return m
}
func (m *inputModel) addFile(name string) {
m.attachedFiles = append(m.attachedFiles, name)
}
func (m *inputModel) clearFiles() {
m.attachedFiles = nil
}
// submitMsg is sent when user submits text.
type submitMsg struct {
text string
}
// fileDropMsg is sent when a file path is detected.
type fileDropMsg struct {
path string
}
// detectFileDrop checks if the text looks like a file path.
func detectFileDrop(text string) string {
cleaned := strings.Trim(text, "'\"")
if cleaned == "" {
return ""
}
// Only treat as a file drop if it looks explicitly path-like
if !strings.HasPrefix(cleaned, "/") && !strings.HasPrefix(cleaned, "~") &&
!strings.HasPrefix(cleaned, "./") && !strings.HasPrefix(cleaned, "../") {
return ""
}
// Expand ~ to home dir
if strings.HasPrefix(cleaned, "~") {
home, err := os.UserHomeDir()
if err == nil {
cleaned = filepath.Join(home, cleaned[1:])
}
}
abs, err := filepath.Abs(cleaned)
if err != nil {
return ""
}
info, err := os.Stat(abs)
if err != nil {
return ""
}
if info.IsDir() {
return ""
}
return abs
}
// viewMenu renders the slash command menu.
func (m inputModel) viewMenu(width int) string {
if !m.menuVisible || len(m.menuItems) == 0 {
return ""
}
var lines []string
for i, item := range m.menuItems {
prefix := " "
if i == m.menuIndex {
prefix = "> "
}
line := prefix + item.command + " " + statusMsgStyle.Render(item.description)
lines = append(lines, line)
}
return strings.Join(lines, "\n")
}
// viewInput renders the input line with prompt and optional file badges.
func (m inputModel) viewInput() string {
var parts []string
if len(m.attachedFiles) > 0 {
badges := strings.Join(m.attachedFiles, "] [")
parts = append(parts, statusMsgStyle.Render("Attached: ["+badges+"]"))
}
parts = append(parts, inputPrompt+m.textInput.View())
return strings.Join(parts, "\n")
}
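The file-drop detection above only treats submitted text as a path when it carries an explicit prefix, so ordinary chat messages are never swallowed. A minimal standalone sketch of that prefix check (`looksPathLike` is an illustrative name, not part of the package; it omits the `~` expansion and `os.Stat` steps so it stays filesystem-free):

```go
package main

import (
	"fmt"
	"strings"
)

// looksPathLike mirrors the explicit-prefix rule in detectFileDrop:
// after stripping surrounding quotes, only strings starting with
// /, ~, ./ or ../ are considered dropped file paths.
func looksPathLike(text string) bool {
	cleaned := strings.Trim(text, "'\"")
	if cleaned == "" {
		return false
	}
	return strings.HasPrefix(cleaned, "/") ||
		strings.HasPrefix(cleaned, "~") ||
		strings.HasPrefix(cleaned, "./") ||
		strings.HasPrefix(cleaned, "../")
}

func main() {
	fmt.Println(looksPathLike("/etc/hosts"))    // true
	fmt.Println(looksPathLike("'~/notes.md'"))  // quoted drop: true
	fmt.Println(looksPathLike("what is onyx?")) // plain chat text: false
}
```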


@@ -0,0 +1,46 @@
package tui
import (
"github.com/onyx-dot-app/onyx/cli/internal/models"
)
// InitDoneMsg signals that async initialization is complete.
type InitDoneMsg struct {
Agents []models.AgentSummary
Err error
}
// SessionsLoadedMsg carries loaded chat sessions.
type SessionsLoadedMsg struct {
Sessions []models.ChatSessionDetails
Err error
}
// SessionResumedMsg carries a loaded session detail.
type SessionResumedMsg struct {
Detail *models.ChatSessionDetailResponse
Err error
}
// FileUploadedMsg carries an uploaded file descriptor.
type FileUploadedMsg struct {
Descriptor *models.FileDescriptorPayload
FileName string
Err error
}
// AgentsLoadedMsg carries freshly fetched agents from the API.
type AgentsLoadedMsg struct {
Agents []models.AgentSummary
Err error
}
// InfoMsg is a simple informational message for display.
type InfoMsg struct {
Text string
}
// ErrorMsg wraps an error for display.
type ErrorMsg struct {
Err error
}


@@ -0,0 +1,79 @@
package tui
import (
"strings"
"github.com/charmbracelet/lipgloss"
)
const onyxLogo = ` ██████╗ ███╗ ██╗██╗ ██╗██╗ ██╗
██╔═══██╗████╗ ██║╚██╗ ██╔╝╚██╗██╔╝
██║ ██║██╔██╗ ██║ ╚████╔╝ ╚███╔╝
██║ ██║██║╚██╗██║ ╚██╔╝ ██╔██╗
╚██████╔╝██║ ╚████║ ██║ ██╔╝ ██╗
╚═════╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝`
const tagline = "Your terminal interface for Onyx"
const splashHint = "Type a message to begin · /help for commands"
// renderSplash renders the splash screen centered for the given dimensions.
func renderSplash(width, height int) string {
// Render the logo as a single block (don't center individual lines)
logo := splashStyle.Render(onyxLogo)
// Center tagline and hint relative to the logo block width
logoWidth := lipgloss.Width(logo)
tag := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
taglineStyle.Render(tagline),
)
hint := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
hintStyle.Render(splashHint),
)
block := lipgloss.JoinVertical(lipgloss.Left, logo, "", tag, hint)
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, block)
}
// RenderSplashOnboarding renders splash for the terminal onboarding screen.
func RenderSplashOnboarding(width, height int) string {
// Render the logo as a styled block, then center it as a unit
styledLogo := splashStyle.Render(onyxLogo)
logoWidth := lipgloss.Width(styledLogo)
logoLines := strings.Split(styledLogo, "\n")
logoHeight := len(logoLines)
contentHeight := logoHeight + 2 // logo + blank + tagline
topPad := (height - contentHeight) / 2
if topPad < 1 {
topPad = 1
}
// Center the entire logo block horizontally
blockPad := (width - logoWidth) / 2
if blockPad < 0 {
blockPad = 0
}
var b strings.Builder
for i := 0; i < topPad; i++ {
b.WriteByte('\n')
}
for _, line := range logoLines {
b.WriteString(strings.Repeat(" ", blockPad))
b.WriteString(line)
b.WriteByte('\n')
}
b.WriteByte('\n')
tagPad := (width - len(tagline)) / 2
if tagPad < 0 {
tagPad = 0
}
b.WriteString(strings.Repeat(" ", tagPad))
b.WriteString(taglineStyle.Render(tagline))
b.WriteByte('\n')
return b.String()
}


@@ -0,0 +1,60 @@
package tui
import (
"strings"
"github.com/charmbracelet/lipgloss"
)
// statusBar manages the footer status display.
type statusBar struct {
agentName string
serverURL string
sessionID string
streaming bool
width int
}
func newStatusBar() statusBar {
return statusBar{
agentName: "Default",
}
}
func (s *statusBar) setAgent(name string) { s.agentName = name }
func (s *statusBar) setServer(url string) { s.serverURL = url }
func (s *statusBar) setSession(id string) {
if len(id) > 8 {
id = id[:8]
}
s.sessionID = id
}
func (s *statusBar) setStreaming(v bool) { s.streaming = v }
func (s *statusBar) setWidth(w int) { s.width = w }
func (s statusBar) view() string {
var leftParts []string
if s.serverURL != "" {
leftParts = append(leftParts, s.serverURL)
}
name := s.agentName
if name == "" {
name = "Default"
}
leftParts = append(leftParts, name)
left := statusBarStyle.Render(strings.Join(leftParts, " · "))
right := "Ctrl+D to quit"
if s.streaming {
right = "Esc to cancel"
}
rightRendered := statusBarStyle.Render(right)
// Fill space between left and right
gap := s.width - lipgloss.Width(left) - lipgloss.Width(rightRendered)
if gap < 1 {
gap = 1
}
return left + strings.Repeat(" ", gap) + rightRendered
}
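The status bar right-aligns its hint by padding the space between the two segments, clamping to a minimum gap of 1 so they never fuse on narrow terminals. A small sketch of that arithmetic (`fillGap` is an illustrative helper, not part of the package):

```go
package main

import (
	"fmt"
	"strings"
)

// fillGap returns the run of spaces placed between the left and right
// status-bar segments so the right side sits flush with the terminal edge.
func fillGap(width, leftWidth, rightWidth int) string {
	gap := width - leftWidth - rightWidth
	if gap < 1 {
		gap = 1 // never let the segments touch
	}
	return strings.Repeat(" ", gap)
}

func main() {
	fmt.Printf("%q\n", fillGap(20, 5, 5)) // 10 spaces
	fmt.Printf("%q\n", fillGap(8, 5, 5))  // too narrow: clamped to 1 space
}
```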


@@ -0,0 +1,29 @@
package tui
import "github.com/charmbracelet/lipgloss"
var (
// Colors
accentColor = lipgloss.Color("#6c8ebf")
dimColor = lipgloss.Color("#555577")
errorColor = lipgloss.Color("#ff5555")
splashColor = lipgloss.Color("#7C6AEF")
separatorColor = lipgloss.Color("#333355")
citationColor = lipgloss.Color("#666688")
// Styles
userPrefixStyle = lipgloss.NewStyle().Foreground(dimColor)
agentDot = lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("◉")
infoStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#b0b0cc"))
dimInfoStyle = lipgloss.NewStyle().Foreground(dimColor)
statusMsgStyle = dimInfoStyle // used for slash menu descriptions, file badges
errorStyle = lipgloss.NewStyle().Foreground(errorColor).Bold(true)
warnStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
citationStyle = lipgloss.NewStyle().Foreground(citationColor)
statusBarStyle = lipgloss.NewStyle().Foreground(dimColor)
inputPrompt = lipgloss.NewStyle().Foreground(accentColor).Render(" ")
splashStyle = lipgloss.NewStyle().Foreground(splashColor).Bold(true)
taglineStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#A0A0A0"))
hintStyle = lipgloss.NewStyle().Foreground(dimColor)
)


@@ -0,0 +1,419 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/glamour"
"github.com/charmbracelet/glamour/styles"
"github.com/charmbracelet/lipgloss"
)
// entryKind is the type of chat entry.
type entryKind int
const (
entryUser entryKind = iota
entryAgent
entryInfo
entryError
entryCitation
)
// chatEntry is a single rendered entry in the chat history.
type chatEntry struct {
kind entryKind
content string // raw content (for agent: the markdown source)
rendered string // pre-rendered output
citations []string // citation lines (for citation entries)
}
// pickerKind distinguishes what the picker is selecting.
type pickerKind int
const (
pickerSession pickerKind = iota
pickerAgent
)
// pickerItem is a selectable item in the picker.
type pickerItem struct {
id string
label string
}
// viewport manages the chat display.
type viewport struct {
entries []chatEntry
width int
streaming bool
streamBuf string
showSources bool
renderer *glamour.TermRenderer
pickerItems []pickerItem
pickerActive bool
pickerIndex int
pickerType pickerKind
scrollOffset int // lines scrolled up from bottom (0 = pinned to bottom)
lastMaxScroll int // cached from last render for clamping in scrollUp
}
// newMarkdownRenderer creates a Glamour renderer with zero left margin.
func newMarkdownRenderer(width int) *glamour.TermRenderer {
style := styles.DarkStyleConfig
zero := uint(0)
style.Document.Margin = &zero
r, _ := glamour.NewTermRenderer(
glamour.WithStyles(style),
glamour.WithWordWrap(width-4),
)
return r
}
func newViewport(width int) *viewport {
return &viewport{
width: width,
renderer: newMarkdownRenderer(width),
}
}
func (v *viewport) addSplash(height int) {
splash := renderSplash(v.width, height)
v.entries = append(v.entries, chatEntry{
kind: entryInfo,
rendered: splash,
})
}
func (v *viewport) setWidth(w int) {
v.width = w
v.renderer = newMarkdownRenderer(w)
}
func (v *viewport) addUserMessage(msg string) {
rendered := "\n" + userPrefixStyle.Render(" ") + msg
v.entries = append(v.entries, chatEntry{
kind: entryUser,
content: msg,
rendered: rendered,
})
}
func (v *viewport) startAgent() {
v.streaming = true
v.streamBuf = ""
// Add a blank-line spacer entry before the agent message
v.entries = append(v.entries, chatEntry{kind: entryInfo, rendered: ""})
}
func (v *viewport) appendToken(token string) {
v.streamBuf += token
// No scroll adjustment needed here: scrollOffset == 0 already means
// pinned to bottom, and a non-zero offset preserves the user's position.
}
func (v *viewport) finishAgent() {
if v.streamBuf == "" {
v.streaming = false
// Remove the blank spacer entry added by startAgent()
if len(v.entries) > 0 && v.entries[len(v.entries)-1].kind == entryInfo && v.entries[len(v.entries)-1].rendered == "" {
v.entries = v.entries[:len(v.entries)-1]
}
return
}
// Render markdown with Glamour (zero left margin style)
rendered := v.renderMarkdown(v.streamBuf)
rendered = strings.TrimLeft(rendered, "\n")
rendered = strings.TrimRight(rendered, "\n")
lines := strings.Split(rendered, "\n")
// Prefix first line with dot, indent continuation lines
if len(lines) > 0 {
lines[0] = agentDot + " " + lines[0]
for i := 1; i < len(lines); i++ {
lines[i] = " " + lines[i]
}
}
rendered = strings.Join(lines, "\n")
v.entries = append(v.entries, chatEntry{
kind: entryAgent,
content: v.streamBuf,
rendered: rendered,
})
v.streaming = false
v.streamBuf = ""
}
func (v *viewport) renderMarkdown(md string) string {
if v.renderer == nil {
return md
}
out, err := v.renderer.Render(md)
if err != nil {
return md
}
return out
}
func (v *viewport) addInfo(msg string) {
rendered := infoStyle.Render("● " + msg)
v.entries = append(v.entries, chatEntry{
kind: entryInfo,
content: msg,
rendered: rendered,
})
}
func (v *viewport) addWarning(msg string) {
rendered := warnStyle.Render("● " + msg)
v.entries = append(v.entries, chatEntry{
kind: entryError,
content: msg,
rendered: rendered,
})
}
func (v *viewport) addError(msg string) {
rendered := errorStyle.Render("● Error: ") + msg
v.entries = append(v.entries, chatEntry{
kind: entryError,
content: msg,
rendered: rendered,
})
}
func (v *viewport) addCitations(citations map[int]string) {
if len(citations) == 0 {
return
}
keys := make([]int, 0, len(citations))
for k := range citations {
keys = append(keys, k)
}
sort.Ints(keys)
var parts []string
for _, num := range keys {
parts = append(parts, fmt.Sprintf("[%d] %s", num, citations[num]))
}
text := fmt.Sprintf("Sources (%d): %s", len(citations), strings.Join(parts, " "))
var citLines []string
citLines = append(citLines, text)
v.entries = append(v.entries, chatEntry{
kind: entryCitation,
content: text,
rendered: citationStyle.Render("● "+text),
citations: citLines,
})
}
func (v *viewport) showPicker(kind pickerKind, items []pickerItem) {
v.pickerItems = items
v.pickerType = kind
v.pickerActive = true
v.pickerIndex = 0
}
func (v *viewport) scrollUp(n int) {
v.scrollOffset += n
if v.scrollOffset > v.lastMaxScroll {
v.scrollOffset = v.lastMaxScroll
}
}
func (v *viewport) scrollDown(n int) {
v.scrollOffset -= n
if v.scrollOffset < 0 {
v.scrollOffset = 0
}
}
func (v *viewport) clearAll() {
v.entries = nil
v.streaming = false
v.streamBuf = ""
v.pickerItems = nil
v.pickerActive = false
v.scrollOffset = 0
}
func (v *viewport) clearDisplay() {
v.entries = nil
v.scrollOffset = 0
v.streaming = false
v.streamBuf = ""
}
// pickerTitle returns a title for the current picker kind.
func (v *viewport) pickerTitle() string {
switch v.pickerType {
case pickerAgent:
return "Select Agent"
case pickerSession:
return "Resume Session"
default:
return "Select"
}
}
// renderPicker renders the picker as a bordered overlay.
func (v *viewport) renderPicker(width, height int) string {
title := v.pickerTitle()
// Determine picker dimensions
maxItems := len(v.pickerItems)
panelWidth := width - 4
if panelWidth < 30 {
panelWidth = 30
}
if panelWidth > 70 {
panelWidth = 70
}
innerWidth := panelWidth - 4 // border + padding
// Visible window of items (scroll if too many)
maxVisible := height - 6 // room for border, title, hint
if maxVisible < 3 {
maxVisible = 3
}
if maxVisible > maxItems {
maxVisible = maxItems
}
// Calculate scroll window around current index
startIdx := 0
if v.pickerIndex >= maxVisible {
startIdx = v.pickerIndex - maxVisible + 1
}
endIdx := startIdx + maxVisible
if endIdx > maxItems {
endIdx = maxItems
startIdx = endIdx - maxVisible
if startIdx < 0 {
startIdx = 0
}
}
var itemLines []string
for i := startIdx; i < endIdx; i++ {
item := v.pickerItems[i]
label := item.label
labelRunes := []rune(label)
if len(labelRunes) > innerWidth-4 {
label = string(labelRunes[:innerWidth-7]) + "..."
}
if i == v.pickerIndex {
line := lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("> " + label)
itemLines = append(itemLines, line)
} else {
itemLines = append(itemLines, " "+label)
}
}
hint := lipgloss.NewStyle().Foreground(dimColor).Render("↑↓ navigate • enter select • esc cancel")
body := strings.Join(itemLines, "\n") + "\n\n" + hint
panel := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accentColor).
Padding(1, 2).
Width(panelWidth).
Render(body)
titleRendered := lipgloss.NewStyle().
Foreground(accentColor).
Bold(true).
Render(" " + title + " ")
// Build top border manually to avoid ANSI-corrupted rune slicing.
// panelWidth+2 accounts for the left and right border characters.
borderColor := lipgloss.NewStyle().Foreground(accentColor)
titleWidth := lipgloss.Width(titleRendered)
rightDashes := panelWidth + 2 - 3 - titleWidth // total - "╭─" - "╮" - title
if rightDashes < 0 {
rightDashes = 0
}
topBorder := borderColor.Render("╭─") + titleRendered +
borderColor.Render(strings.Repeat("─", rightDashes)+"╮")
panelLines := strings.Split(panel, "\n")
if len(panelLines) > 0 {
panelLines[0] = topBorder
}
panel = strings.Join(panelLines, "\n")
// Center the panel in the viewport
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, panel)
}
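The scroll-window calculation in `renderPicker` keeps the selected item visible by sliding a fixed-size window over the item list and clamping it at both ends. Sketched as a standalone function (`pickerWindow` is an illustrative name):

```go
package main

import "fmt"

// pickerWindow returns the half-open range [startIdx, endIdx) of picker
// items to draw so the selected index stays inside the visible window.
func pickerWindow(index, maxVisible, total int) (startIdx, endIdx int) {
	startIdx = 0
	if index >= maxVisible {
		startIdx = index - maxVisible + 1 // slide window down with the selection
	}
	endIdx = startIdx + maxVisible
	if endIdx > total {
		endIdx = total // clamp at the bottom of the list
		startIdx = endIdx - maxVisible
		if startIdx < 0 {
			startIdx = 0
		}
	}
	return startIdx, endIdx
}

func main() {
	fmt.Println(pickerWindow(0, 5, 12))  // at top: 0 5
	fmt.Println(pickerWindow(7, 5, 12))  // follows selection: 3 8
	fmt.Println(pickerWindow(11, 5, 12)) // at bottom: 7 12
}
```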
// view renders the full viewport content.
func (v *viewport) view(height int) string {
// If picker is active, render it as an overlay
if v.pickerActive && len(v.pickerItems) > 0 {
return v.renderPicker(v.width, height)
}
var lines []string
for _, e := range v.entries {
if e.kind == entryCitation && !v.showSources {
continue
}
lines = append(lines, e.rendered)
}
// Streaming buffer (plain text, not markdown)
if v.streaming && v.streamBuf != "" {
bufLines := strings.Split(v.streamBuf, "\n")
if len(bufLines) > 0 {
bufLines[0] = agentDot + " " + bufLines[0]
for i := 1; i < len(bufLines); i++ {
bufLines[i] = " " + bufLines[i]
}
}
lines = append(lines, strings.Join(bufLines, "\n"))
} else if v.streaming {
lines = append(lines, agentDot+" ")
}
content := strings.Join(lines, "\n")
contentLines := strings.Split(content, "\n")
total := len(contentLines)
// Cache max scroll for clamping in scrollUp()
maxScroll := total - height
if maxScroll < 0 {
maxScroll = 0
}
v.lastMaxScroll = maxScroll
scrollOffset := v.scrollOffset
if scrollOffset > maxScroll {
scrollOffset = maxScroll
}
if total <= height {
// Content fits — pad with empty lines at top to push content down
padding := make([]string, height-total)
for i := range padding {
padding[i] = ""
}
contentLines = append(padding, contentLines...)
} else {
// Show a window: end is (total - scrollOffset), start is (end - height)
end := total - scrollOffset
start := end - height
if start < 0 {
start = 0
}
contentLines = contentLines[start:end]
}
return strings.Join(contentLines, "\n")
}
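The windowing in `view()` treats `scrollOffset` as lines scrolled up from the bottom, so offset 0 is pinned to the latest output and large offsets clamp at the top. The same arithmetic, extracted into a pure function for illustration (`window` is not part of the package):

```go
package main

import "fmt"

// window computes the visible slice [start, end) of content lines for a
// viewport of the given height, with scrollOffset lines scrolled up from
// the bottom (0 = pinned to bottom), clamping the offset as view() does.
func window(total, height, scrollOffset int) (start, end int) {
	maxScroll := total - height
	if maxScroll < 0 {
		maxScroll = 0
	}
	if scrollOffset > maxScroll {
		scrollOffset = maxScroll
	}
	end = total - scrollOffset
	start = end - height
	if start < 0 {
		start = 0
	}
	return start, end
}

func main() {
	fmt.Println(window(100, 20, 0))   // pinned to bottom: 80 100
	fmt.Println(window(100, 20, 30))  // scrolled up 30: 50 70
	fmt.Println(window(100, 20, 999)) // clamped at top: 0 20
}
```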


@@ -0,0 +1,264 @@
package tui
import (
"regexp"
"strings"
"testing"
)
// stripANSI removes ANSI escape sequences for test comparisons.
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
func stripANSI(s string) string {
return ansiRegex.ReplaceAllString(s, "")
}
func TestAddUserMessage(t *testing.T) {
v := newViewport(80)
v.addUserMessage("hello world")
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryUser {
t.Errorf("expected entryUser, got %d", e.kind)
}
if e.content != "hello world" {
t.Errorf("expected content 'hello world', got %q", e.content)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "hello world") {
t.Errorf("expected rendered to contain message text, got %q", plain)
}
}
func TestStartAndFinishAgent(t *testing.T) {
v := newViewport(80)
v.startAgent()
if !v.streaming {
t.Error("expected streaming to be true after startAgent")
}
if len(v.entries) != 1 {
t.Fatalf("expected 1 spacer entry, got %d", len(v.entries))
}
if v.entries[0].rendered != "" {
t.Errorf("expected empty spacer, got %q", v.entries[0].rendered)
}
v.appendToken("Hello ")
v.appendToken("world")
if v.streamBuf != "Hello world" {
t.Errorf("expected streamBuf 'Hello world', got %q", v.streamBuf)
}
v.finishAgent()
if v.streaming {
t.Error("expected streaming to be false after finishAgent")
}
if v.streamBuf != "" {
t.Errorf("expected empty streamBuf after finish, got %q", v.streamBuf)
}
if len(v.entries) != 2 {
t.Fatalf("expected 2 entries (spacer + agent), got %d", len(v.entries))
}
e := v.entries[1]
if e.kind != entryAgent {
t.Errorf("expected entryAgent, got %d", e.kind)
}
if e.content != "Hello world" {
t.Errorf("expected content 'Hello world', got %q", e.content)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "Hello world") {
t.Errorf("expected rendered to contain message text, got %q", plain)
}
}
func TestFinishAgentNoPadding(t *testing.T) {
v := newViewport(80)
v.startAgent()
v.appendToken("Test message")
v.finishAgent()
e := v.entries[1]
// First line should not start with plain spaces (ANSI codes are OK)
plain := stripANSI(e.rendered)
lines := strings.Split(plain, "\n")
if strings.HasPrefix(lines[0], " ") {
t.Errorf("first line should not start with spaces, got %q", lines[0])
}
}
func TestFinishAgentMultiline(t *testing.T) {
v := newViewport(80)
v.startAgent()
v.appendToken("Line one\n\nLine three")
v.finishAgent()
e := v.entries[1]
plain := stripANSI(e.rendered)
// Glamour may merge or reformat lines; just check content is present
if !strings.Contains(plain, "Line one") {
t.Errorf("expected 'Line one' in rendered, got %q", plain)
}
if !strings.Contains(plain, "Line three") {
t.Errorf("expected 'Line three' in rendered, got %q", plain)
}
}
func TestFinishAgentEmpty(t *testing.T) {
v := newViewport(80)
v.startAgent()
v.finishAgent()
if v.streaming {
t.Error("expected streaming to be false")
}
if len(v.entries) != 0 {
t.Errorf("expected 0 entries (spacer removed), got %d", len(v.entries))
}
}
func TestAddInfo(t *testing.T) {
v := newViewport(80)
v.addInfo("test info")
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryInfo {
t.Errorf("expected entryInfo, got %d", e.kind)
}
plain := stripANSI(e.rendered)
if strings.HasPrefix(plain, " ") {
t.Errorf("info should not have leading spaces, got %q", plain)
}
}
func TestAddError(t *testing.T) {
v := newViewport(80)
v.addError("something broke")
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryError {
t.Errorf("expected entryError, got %d", e.kind)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "something broke") {
t.Errorf("expected error message in rendered, got %q", plain)
}
}
func TestAddCitations(t *testing.T) {
v := newViewport(80)
v.addCitations(map[int]string{1: "doc-a", 2: "doc-b"})
if len(v.entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(v.entries))
}
e := v.entries[0]
if e.kind != entryCitation {
t.Errorf("expected entryCitation, got %d", e.kind)
}
plain := stripANSI(e.rendered)
if !strings.Contains(plain, "Sources (2)") {
t.Errorf("expected sources count in rendered, got %q", plain)
}
if strings.HasPrefix(plain, " ") {
t.Errorf("citation should not have leading spaces, got %q", plain)
}
}
func TestAddCitationsEmpty(t *testing.T) {
v := newViewport(80)
v.addCitations(map[int]string{})
if len(v.entries) != 0 {
t.Errorf("expected no entries for empty citations, got %d", len(v.entries))
}
}
func TestCitationVisibility(t *testing.T) {
v := newViewport(80)
v.addInfo("hello")
v.addCitations(map[int]string{1: "doc"})
v.showSources = false
view := v.view(20)
plain := stripANSI(view)
if strings.Contains(plain, "Sources") {
t.Error("expected citations hidden when showSources=false")
}
v.showSources = true
view = v.view(20)
plain = stripANSI(view)
if !strings.Contains(plain, "Sources") {
t.Error("expected citations visible when showSources=true")
}
}
func TestClearAll(t *testing.T) {
v := newViewport(80)
v.addUserMessage("test")
v.startAgent()
v.appendToken("response")
v.clearAll()
if len(v.entries) != 0 {
t.Errorf("expected no entries after clearAll, got %d", len(v.entries))
}
if v.streaming {
t.Error("expected streaming=false after clearAll")
}
if v.streamBuf != "" {
t.Errorf("expected empty streamBuf after clearAll, got %q", v.streamBuf)
}
}
func TestClearDisplay(t *testing.T) {
v := newViewport(80)
v.addUserMessage("test")
v.clearDisplay()
if len(v.entries) != 0 {
t.Errorf("expected no entries after clearDisplay, got %d", len(v.entries))
}
}
func TestViewPadsShortContent(t *testing.T) {
v := newViewport(80)
v.addInfo("hello")
view := v.view(10)
lines := strings.Split(view, "\n")
if len(lines) != 10 {
t.Errorf("expected 10 lines (padded), got %d", len(lines))
}
}
func TestViewTruncatesTallContent(t *testing.T) {
v := newViewport(80)
for i := 0; i < 20; i++ {
v.addInfo("line")
}
view := v.view(5)
lines := strings.Split(view, "\n")
if len(lines) != 5 {
t.Errorf("expected 5 lines (truncated), got %d", len(lines))
}
}


@@ -0,0 +1,29 @@
// Package util provides shared utility functions.
package util
import (
"os/exec"
"runtime"
)
// OpenBrowser opens the given URL in the user's default browser.
// Returns true if the browser was launched successfully.
func OpenBrowser(url string) bool {
var cmd *exec.Cmd
switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
case "linux":
cmd = exec.Command("xdg-open", url)
case "windows":
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
}
if cmd != nil {
if err := cmd.Start(); err == nil {
// Reap the child process to avoid zombies.
go func() { _ = cmd.Wait() }()
return true
}
}
return false
}


@@ -0,0 +1,13 @@
// Package util provides shared utilities for the Onyx CLI.
package util
import "github.com/charmbracelet/lipgloss"
// Shared text styles used across the CLI.
var (
BoldStyle = lipgloss.NewStyle().Bold(true)
DimStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#555577"))
GreenStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#00cc66")).Bold(true)
RedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ff5555")).Bold(true)
YellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
)

cli/main.go Normal file

@@ -0,0 +1,23 @@
package main
import (
"fmt"
"os"
"github.com/onyx-dot-app/onyx/cli/cmd"
)
var (
version = "dev"
commit = "none"
)
func main() {
cmd.Version = version
cmd.Commit = commit
if err := cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}


@@ -42,7 +42,7 @@ import SvgStar from "@opal/icons/star";
## Usage inside Content
Tag can be rendered as an accessory inside `Content`'s ContentMd via the `tag` prop:
Tag can be rendered as an accessory inside `Content`'s LabelLayout via the `tag` prop:
```tsx
import { Content } from "@opal/layouts";


@@ -1,6 +1,6 @@
/* Hoverable — item transitions */
.hoverable-item {
transition: opacity 150ms ease-in-out;
transition: opacity 200ms ease-in-out;
}
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {


@@ -1,200 +0,0 @@
"use client";
import { Button } from "@opal/components/buttons/Button/components";
import type { SizeVariant } from "@opal/shared";
import SvgEdit from "@opal/icons/edit";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type ContentLgSizePreset = "headline" | "section";
interface ContentLgPresetConfig {
/** Icon width/height (CSS value). */
iconSize: string;
/** Tailwind padding class for the icon container. */
iconContainerPadding: string;
/** Gap between icon container and content (CSS value). */
gap: string;
/** Tailwind font class for the title. */
titleFont: string;
/** Title line-height — also used as icon container min-height (CSS value). */
lineHeight: string;
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
editButtonSize: SizeVariant;
/** Tailwind padding class for the edit button container. */
editButtonPadding: string;
}
interface ContentLgProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
/** Optional description below the title. */
description?: string;
/** Enable inline editing of the title. */
editable?: boolean;
/** Called when the user commits an edit. */
onTitleChange?: (newTitle: string) => void;
/** Size preset. Default: `"headline"`. */
sizePreset?: ContentLgSizePreset;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const CONTENT_LG_PRESETS: Record<ContentLgSizePreset, ContentLgPresetConfig> = {
headline: {
iconSize: "2rem",
iconContainerPadding: "p-0.5",
gap: "0.25rem",
titleFont: "font-heading-h2",
lineHeight: "2.25rem",
editButtonSize: "md",
editButtonPadding: "p-1",
},
section: {
iconSize: "1.25rem",
iconContainerPadding: "p-1",
gap: "0rem",
titleFont: "font-heading-h3-muted",
lineHeight: "1.75rem",
editButtonSize: "sm",
editButtonPadding: "p-0.5",
},
};
// ---------------------------------------------------------------------------
// ContentLg
// ---------------------------------------------------------------------------
function ContentLg({
sizePreset = "headline",
icon: Icon,
title,
description,
editable,
onTitleChange,
}: ContentLgProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const inputRef = useRef<HTMLInputElement>(null);
const config = CONTENT_LG_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
setEditing(false);
}
return (
<div className="opal-content-lg" style={{ gap: config.gap }}>
{Icon && (
<div
className={cn(
"opal-content-lg-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className="opal-content-lg-icon"
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<div className="opal-content-lg-body">
<div className="opal-content-lg-title-row">
{editing ? (
<div className="opal-content-lg-input-sizer">
<span
className={cn("opal-content-lg-input-mirror", config.titleFont)}
>
{editValue || "\u00A0"}
</span>
<input
ref={inputRef}
className={cn(
"opal-content-lg-input",
config.titleFont,
"text-text-04"
)}
value={editValue}
onChange={(e) => setEditValue(e.target.value)}
size={1}
autoFocus
onFocus={(e) => e.currentTarget.select()}
onBlur={commit}
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditing(false);
}
}}
style={{ height: config.lineHeight }}
/>
</div>
) : (
<span
className={cn(
"opal-content-lg-title",
config.titleFont,
"text-text-04",
editable && "cursor-pointer"
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
>
{title}
</span>
)}
{editable && !editing && (
<div
className={cn(
"opal-content-lg-edit-button",
config.editButtonPadding
)}
>
<Button
icon={SvgEdit}
prominence="internal"
size={config.editButtonSize}
tooltip="Edit"
tooltipSide="right"
onClick={startEditing}
/>
</div>
)}
</div>
{description && (
<div className="opal-content-lg-description font-secondary-body text-text-03">
{description}
</div>
)}
</div>
</div>
);
}
export { ContentLg, type ContentLgProps, type ContentLgSizePreset };


@@ -1,279 +0,0 @@
"use client";
import { Button } from "@opal/components/buttons/Button/components";
import { Tag, type TagProps } from "@opal/components/Tag/components";
import type { SizeVariant } from "@opal/shared";
import SvgAlertCircle from "@opal/icons/alert-circle";
import SvgAlertTriangle from "@opal/icons/alert-triangle";
import SvgEdit from "@opal/icons/edit";
import SvgXOctagon from "@opal/icons/x-octagon";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type ContentMdSizePreset = "main-content" | "main-ui" | "secondary";
type ContentMdAuxIcon = "info-gray" | "info-blue" | "warning" | "error";
interface ContentMdPresetConfig {
iconSize: string;
iconContainerPadding: string;
iconColorClass: string;
titleFont: string;
lineHeight: string;
gap: string;
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
editButtonSize: SizeVariant;
editButtonPadding: string;
optionalFont: string;
/** Aux icon size = lineHeight − 2 × p-0.5 (CSS value). */
auxIconSize: string;
}
interface ContentMdProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
/** Optional description text below the title. */
description?: string;
/** Enable inline editing of the title. */
editable?: boolean;
/** Called when the user commits an edit. */
onTitleChange?: (newTitle: string) => void;
/** When `true`, renders "(Optional)" beside the title. */
optional?: boolean;
/** Auxiliary status icon rendered beside the title. */
auxIcon?: ContentMdAuxIcon;
/** Tag rendered beside the title. */
tag?: TagProps;
/** Size preset. Default: `"main-ui"`. */
sizePreset?: ContentMdSizePreset;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
"main-content": {
iconSize: "1rem",
iconContainerPadding: "p-1",
iconColorClass: "text-text-04",
titleFont: "font-main-content-emphasis",
lineHeight: "1.5rem",
gap: "0.125rem",
editButtonSize: "sm",
editButtonPadding: "p-0",
optionalFont: "font-main-content-muted",
auxIconSize: "1.25rem",
},
"main-ui": {
iconSize: "1rem",
iconContainerPadding: "p-0.5",
iconColorClass: "text-text-03",
titleFont: "font-main-ui-action",
lineHeight: "1.25rem",
gap: "0.25rem",
editButtonSize: "xs",
editButtonPadding: "p-0",
optionalFont: "font-main-ui-muted",
auxIconSize: "1rem",
},
secondary: {
iconSize: "0.75rem",
iconContainerPadding: "p-0.5",
iconColorClass: "text-text-04",
titleFont: "font-secondary-action",
lineHeight: "1rem",
gap: "0.125rem",
editButtonSize: "2xs",
editButtonPadding: "p-0",
optionalFont: "font-secondary-action",
auxIconSize: "0.75rem",
},
};
// ---------------------------------------------------------------------------
// ContentMd
// ---------------------------------------------------------------------------
const AUX_ICON_CONFIG: Record<
ContentMdAuxIcon,
{ icon: IconFunctionComponent; colorClass: string }
> = {
"info-gray": { icon: SvgAlertCircle, colorClass: "text-text-02" },
"info-blue": { icon: SvgAlertCircle, colorClass: "text-status-info-05" },
warning: { icon: SvgAlertTriangle, colorClass: "text-status-warning-05" },
error: { icon: SvgXOctagon, colorClass: "text-status-error-05" },
};
function ContentMd({
icon: Icon,
title,
description,
editable,
onTitleChange,
optional,
auxIcon,
tag,
sizePreset = "main-ui",
}: ContentMdProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const inputRef = useRef<HTMLInputElement>(null);
const config = CONTENT_MD_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
setEditing(false);
}
return (
<div className="opal-content-md" style={{ gap: config.gap }}>
{Icon && (
<div
className={cn(
"opal-content-md-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className={cn("opal-content-md-icon", config.iconColorClass)}
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<div className="opal-content-md-body">
<div className="opal-content-md-title-row">
{editing ? (
<div className="opal-content-md-input-sizer">
<span
className={cn("opal-content-md-input-mirror", config.titleFont)}
>
{editValue || "\u00A0"}
</span>
<input
ref={inputRef}
className={cn(
"opal-content-md-input",
config.titleFont,
"text-text-04"
)}
value={editValue}
onChange={(e) => setEditValue(e.target.value)}
size={1}
autoFocus
onFocus={(e) => e.currentTarget.select()}
onBlur={commit}
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditing(false);
}
}}
style={{ height: config.lineHeight }}
/>
</div>
) : (
<span
className={cn(
"opal-content-md-title",
config.titleFont,
"text-text-04",
editable && "cursor-pointer"
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
>
{title}
</span>
)}
{optional && (
<span
className={cn(config.optionalFont, "text-text-03 shrink-0")}
style={{ height: config.lineHeight }}
>
(Optional)
</span>
)}
{auxIcon &&
(() => {
const { icon: AuxIcon, colorClass } = AUX_ICON_CONFIG[auxIcon];
return (
<div
className="opal-content-md-aux-icon shrink-0 p-0.5"
style={{ height: config.lineHeight }}
>
<AuxIcon
className={colorClass}
style={{
width: config.auxIconSize,
height: config.auxIconSize,
}}
/>
</div>
);
})()}
{tag && <Tag {...tag} />}
{editable && !editing && (
<div
className={cn(
"opal-content-md-edit-button",
config.editButtonPadding
)}
>
<Button
icon={SvgEdit}
prominence="internal"
size={config.editButtonSize}
tooltip="Edit"
tooltipSide="right"
onClick={startEditing}
/>
</div>
)}
</div>
{description && (
<div className="opal-content-md-description font-secondary-body text-text-03">
{description}
</div>
)}
</div>
</div>
);
}
export {
ContentMd,
type ContentMdProps,
type ContentMdSizePreset,
type ContentMdAuxIcon,
};


@@ -1,129 +0,0 @@
"use client";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type ContentSmSizePreset = "main-content" | "main-ui" | "secondary";
type ContentSmOrientation = "vertical" | "inline" | "reverse";
type ContentSmProminence = "default" | "muted";
interface ContentSmPresetConfig {
/** Icon width/height (CSS value). */
iconSize: string;
/** Tailwind padding class for the icon container. */
iconContainerPadding: string;
/** Tailwind font class for the title. */
titleFont: string;
/** Title line-height — also used as icon container min-height (CSS value). */
lineHeight: string;
/** Gap between icon container and title (CSS value). */
gap: string;
}
/** Props for {@link ContentSm}. Does not support editing or descriptions. */
interface ContentSmProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text (read-only — editing is not supported). */
title: string;
/** Size preset. Default: `"main-ui"`. */
sizePreset?: ContentSmSizePreset;
/** Layout orientation. Default: `"inline"`. */
orientation?: ContentSmOrientation;
/** Title prominence. Default: `"default"`. */
prominence?: ContentSmProminence;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const CONTENT_SM_PRESETS: Record<ContentSmSizePreset, ContentSmPresetConfig> = {
"main-content": {
iconSize: "1rem",
iconContainerPadding: "p-1",
titleFont: "font-main-content-body",
lineHeight: "1.5rem",
gap: "0.125rem",
},
"main-ui": {
iconSize: "1rem",
iconContainerPadding: "p-0.5",
titleFont: "font-main-ui-action",
lineHeight: "1.25rem",
gap: "0.25rem",
},
secondary: {
iconSize: "0.75rem",
iconContainerPadding: "p-0.5",
titleFont: "font-secondary-action",
lineHeight: "1rem",
gap: "0.125rem",
},
};
// ---------------------------------------------------------------------------
// ContentSm
// ---------------------------------------------------------------------------
function ContentSm({
icon: Icon,
title,
sizePreset = "main-ui",
orientation = "inline",
prominence = "default",
}: ContentSmProps) {
const config = CONTENT_SM_PRESETS[sizePreset];
const titleColorClass =
prominence === "muted" ? "text-text-03" : "text-text-04";
return (
<div
className="opal-content-sm"
data-orientation={orientation}
style={{ gap: config.gap }}
>
{Icon && (
<div
className={cn(
"opal-content-sm-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className="opal-content-sm-icon text-text-03"
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<span
className={cn(
"opal-content-sm-title",
config.titleFont,
titleColorClass
)}
style={{ height: config.lineHeight }}
>
{title}
</span>
</div>
);
}
export {
ContentSm,
type ContentSmProps,
type ContentSmSizePreset,
type ContentSmOrientation,
type ContentSmProminence,
};


@@ -1,258 +0,0 @@
"use client";
import { Button } from "@opal/components/buttons/Button/components";
import type { SizeVariant } from "@opal/shared";
import SvgEdit from "@opal/icons/edit";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type ContentXlSizePreset = "headline" | "section";
interface ContentXlPresetConfig {
/** Icon width/height (CSS value). */
iconSize: string;
/** Tailwind padding class for the icon container. */
iconContainerPadding: string;
/** More-icon-1 width/height (CSS value). */
moreIcon1Size: string;
/** Tailwind padding class for the more-icon-1 container. */
moreIcon1ContainerPadding: string;
/** More-icon-2 width/height (CSS value). */
moreIcon2Size: string;
/** Tailwind padding class for the more-icon-2 container. */
moreIcon2ContainerPadding: string;
/** Tailwind font class for the title. */
titleFont: string;
/** Title line-height — also used as icon container min-height (CSS value). */
lineHeight: string;
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
editButtonSize: SizeVariant;
/** Tailwind padding class for the edit button container. */
editButtonPadding: string;
}
interface ContentXlProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
/** Optional description below the title. */
description?: string;
/** Enable inline editing of the title. */
editable?: boolean;
/** Called when the user commits an edit. */
onTitleChange?: (newTitle: string) => void;
/** Size preset. Default: `"headline"`. */
sizePreset?: ContentXlSizePreset;
/** Optional secondary icon rendered in the icon row. */
moreIcon1?: IconFunctionComponent;
/** Optional tertiary icon rendered in the icon row. */
moreIcon2?: IconFunctionComponent;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const CONTENT_XL_PRESETS: Record<ContentXlSizePreset, ContentXlPresetConfig> = {
headline: {
iconSize: "2rem",
iconContainerPadding: "p-0.5",
moreIcon1Size: "1rem",
moreIcon1ContainerPadding: "p-0.5",
moreIcon2Size: "2rem",
moreIcon2ContainerPadding: "p-0.5",
titleFont: "font-heading-h2",
lineHeight: "2.25rem",
editButtonSize: "md",
editButtonPadding: "p-1",
},
section: {
iconSize: "1.5rem",
iconContainerPadding: "p-0.5",
moreIcon1Size: "0.75rem",
moreIcon1ContainerPadding: "p-0.5",
moreIcon2Size: "1.5rem",
moreIcon2ContainerPadding: "p-0.5",
titleFont: "font-heading-h3",
lineHeight: "1.75rem",
editButtonSize: "sm",
editButtonPadding: "p-0.5",
},
};
// ---------------------------------------------------------------------------
// ContentXl
// ---------------------------------------------------------------------------
function ContentXl({
sizePreset = "headline",
icon: Icon,
title,
description,
editable,
onTitleChange,
moreIcon1: MoreIcon1,
moreIcon2: MoreIcon2,
}: ContentXlProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const inputRef = useRef<HTMLInputElement>(null);
const config = CONTENT_XL_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
setEditing(false);
}
return (
<div className="opal-content-xl">
{(Icon || MoreIcon1 || MoreIcon2) && (
<div className="opal-content-xl-icon-row">
{Icon && (
<div
className={cn(
"opal-content-xl-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className="opal-content-xl-icon"
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
{MoreIcon1 && (
<div
className={cn(
"opal-content-xl-more-icon-container shrink-0",
config.moreIcon1ContainerPadding
)}
>
<MoreIcon1
className="opal-content-xl-icon"
style={{
width: config.moreIcon1Size,
height: config.moreIcon1Size,
}}
/>
</div>
)}
{MoreIcon2 && (
<div
className={cn(
"opal-content-xl-more-icon-container shrink-0",
config.moreIcon2ContainerPadding
)}
>
<MoreIcon2
className="opal-content-xl-icon"
style={{
width: config.moreIcon2Size,
height: config.moreIcon2Size,
}}
/>
</div>
)}
</div>
)}
<div className="opal-content-xl-body">
<div className="opal-content-xl-title-row">
{editing ? (
<div className="opal-content-xl-input-sizer">
<span
className={cn("opal-content-xl-input-mirror", config.titleFont)}
>
{editValue || "\u00A0"}
</span>
<input
ref={inputRef}
className={cn(
"opal-content-xl-input",
config.titleFont,
"text-text-04"
)}
value={editValue}
onChange={(e) => setEditValue(e.target.value)}
size={1}
autoFocus
onFocus={(e) => e.currentTarget.select()}
onBlur={commit}
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditing(false);
}
}}
style={{ height: config.lineHeight }}
/>
</div>
) : (
<span
className={cn(
"opal-content-xl-title",
config.titleFont,
"text-text-04",
editable && "cursor-pointer"
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
>
{title}
</span>
)}
{editable && !editing && (
<div
className={cn(
"opal-content-xl-edit-button",
config.editButtonPadding
)}
>
<Button
icon={SvgEdit}
prominence="internal"
size={config.editButtonSize}
tooltip="Edit"
tooltipSide="right"
onClick={startEditing}
/>
</div>
)}
</div>
{description && (
<div className="opal-content-xl-description font-secondary-body text-text-03">
{description}
</div>
)}
</div>
</div>
);
}
export { ContentXl, type ContentXlProps, type ContentXlSizePreset };


@@ -8,21 +8,14 @@ A two-axis layout component for displaying icon + title + description rows. Rout
### `sizePreset` — controls sizing (icon, padding, gap, font)
#### ContentXl presets (variant="heading")
| Preset | Icon | Icon padding | moreIcon1 | mI1 padding | moreIcon2 | mI2 padding | Title font | Line-height |
|---|---|---|---|---|---|---|---|---|
| `headline` | 2rem (32px) | `p-0.5` (2px) | 1rem (16px) | `p-0.5` (2px) | 2rem (32px) | `p-0.5` (2px) | `font-heading-h2` | 2.25rem (36px) |
| `section` | 1.5rem (24px) | `p-0.5` (2px) | 0.75rem (12px) | `p-0.5` (2px) | 1.5rem (24px) | `p-0.5` (2px) | `font-heading-h3` | 1.75rem (28px) |
#### ContentLg presets (variant="section")
#### HeadingLayout presets
| Preset | Icon | Icon padding | Gap | Title font | Line-height |
|---|---|---|---|---|---|
| `headline` | 2rem (32px) | `p-0.5` (2px) | 0.25rem (4px) | `font-heading-h2` | 2.25rem (36px) |
| `section` | 1.25rem (20px) | `p-1` (4px) | 0rem | `font-heading-h3-muted` | 1.75rem (28px) |
| `section` | 1.25rem (20px) | `p-1` (4px) | 0rem | `font-heading-h3` | 1.75rem (28px) |
#### ContentMd presets
#### LabelLayout presets
| Preset | Icon | Icon padding | Icon color | Gap | Title font | Line-height |
|---|---|---|---|---|---|---|
@@ -36,18 +29,18 @@ A two-axis layout component for displaying icon + title + description rows. Rout
| variant | Description |
|---|---|
| `heading` | Icon on **top** (flex-col) — ContentXl |
| `section` | Icon **inline** (flex-row) — ContentLg or ContentMd |
| `body` | Body text layout — ContentSm |
| `heading` | Icon on **top** (flex-col) — HeadingLayout |
| `section` | Icon **inline** (flex-row) — HeadingLayout or LabelLayout |
| `body` | Body text layout — BodyLayout (future) |
### Valid Combinations -> Internal Routing
| sizePreset | variant | Routes to |
|---|---|---|
| `headline` / `section` | `heading` | **ContentXl** (icon on top) |
| `headline` / `section` | `section` | **ContentLg** (icon inline) |
| `main-content` / `main-ui` / `secondary` | `section` | **ContentMd** |
| `main-content` / `main-ui` / `secondary` | `body` | **ContentSm** |
| `headline` / `section` | `heading` | **HeadingLayout** (icon on top) |
| `headline` / `section` | `section` | **HeadingLayout** (icon inline) |
| `main-content` / `main-ui` / `secondary` | `section` | **LabelLayout** |
| `main-content` / `main-ui` / `secondary` | `body` | BodyLayout (future) |
Invalid combinations (e.g. `sizePreset="headline" + variant="body"`) are excluded at the type level.
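The type-level exclusion above can be sketched as a discriminated union. This is a minimal standalone sketch, not the PR's actual types: the prop shapes are simplified and the `routeLayout` helper is hypothetical, added only to show how runtime routing mirrors the union.

```typescript
// Simplified sketch of the sizePreset × variant union (hypothetical names).
type HeadingProps = { sizePreset: "headline" | "section"; variant: "heading" | "section" };
type LabelProps = { sizePreset: "main-content" | "main-ui" | "secondary"; variant: "section" };
type BodyProps = { sizePreset: "main-content" | "main-ui" | "secondary"; variant: "body" };
type Props = HeadingProps | LabelProps | BodyProps;

// Runtime routing mirrors the type-level union.
function routeLayout(p: Props): string {
  // headline/section presets always route to the heading layout
  if (p.sizePreset === "headline" || p.sizePreset === "section") return "HeadingLayout";
  // remaining presets split on variant
  return p.variant === "body" ? "BodyLayout" : "LabelLayout";
}

console.log(routeLayout({ sizePreset: "headline", variant: "heading" })); // HeadingLayout
console.log(routeLayout({ sizePreset: "main-ui", variant: "body" }));     // BodyLayout
// routeLayout({ sizePreset: "headline", variant: "body" }) is a compile error:
// no member of the union allows that combination.
```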
@@ -62,20 +55,14 @@ Invalid combinations (e.g. `sizePreset="headline" + variant="body"`) are exclude
| `description` | `string` | — | Optional description below the title |
| `editable` | `boolean` | `false` | Enable inline editing of the title |
| `onTitleChange` | `(newTitle: string) => void` | — | Called when user commits an edit |
| `moreIcon1` | `IconFunctionComponent` | — | Secondary icon in icon row (ContentXl only) |
| `moreIcon2` | `IconFunctionComponent` | — | Tertiary icon in icon row (ContentXl only) |
## Internal Layouts
### ContentXl
### HeadingLayout
For `headline` / `section` presets with `variant="heading"`. Icon row on top (flex-col), supports `moreIcon1` and `moreIcon2` in the icon row. Description is always `font-secondary-body text-text-03`.
For `headline` / `section` presets. Supports `variant="heading"` (icon on top) and `variant="section"` (icon inline). Description is always `font-secondary-body text-text-03`.
### ContentLg
For `headline` / `section` presets with `variant="section"`. Always inline (flex-row). Description is always `font-secondary-body text-text-03`.
### ContentMd
### LabelLayout
For `main-content` / `main-ui` / `secondary` presets. Always inline. Both `icon` and `description` are optional. Description is always `font-secondary-body text-text-03`.
@@ -85,7 +72,7 @@ For `main-content` / `main-ui` / `secondary` presets. Always inline. Both `icon`
import { Content } from "@opal/layouts";
import SvgSearch from "@opal/icons/search";
// ContentXl — headline, icon on top
// HeadingLayout — headline, icon on top
<Content
icon={SvgSearch}
sizePreset="headline"
@@ -94,17 +81,7 @@ import SvgSearch from "@opal/icons/search";
description="Configure your agent's behavior"
/>
// ContentXl — with more icons
<Content
icon={SvgSearch}
sizePreset="headline"
variant="heading"
title="Agent Settings"
moreIcon1={SvgStar}
moreIcon2={SvgLock}
/>
// ContentLg — section, icon inline
// HeadingLayout — section, icon inline
<Content
icon={SvgSearch}
sizePreset="section"
@@ -113,7 +90,7 @@ import SvgSearch from "@opal/icons/search";
description="Connected integrations"
/>
// ContentMd — with icon and description
// LabelLayout — with icon and description
<Content
icon={SvgSearch}
sizePreset="main-ui"
@@ -121,7 +98,7 @@ import SvgSearch from "@opal/icons/search";
description="Agent system prompt"
/>
// ContentMd — title only (no icon, no description)
// LabelLayout — title only (no icon, no description)
<Content
sizePreset="main-content"
title="Featured Agent"


@@ -1,24 +1,21 @@
import "@opal/layouts/Content/styles.css";
import {
ContentSm,
type ContentSmOrientation,
type ContentSmProminence,
} from "@opal/layouts/Content/ContentSm";
BodyLayout,
type BodyOrientation,
type BodyProminence,
} from "@opal/layouts/Content/BodyLayout";
import {
ContentXl,
type ContentXlProps,
} from "@opal/layouts/Content/ContentXl";
HeadingLayout,
type HeadingLayoutProps,
} from "@opal/layouts/Content/HeadingLayout";
import {
ContentLg,
type ContentLgProps,
} from "@opal/layouts/Content/ContentLg";
import {
ContentMd,
type ContentMdProps,
} from "@opal/layouts/Content/ContentMd";
LabelLayout,
type LabelLayoutProps,
} from "@opal/layouts/Content/LabelLayout";
import type { TagProps } from "@opal/components/Tag/components";
import type { IconFunctionComponent } from "@opal/types";
import { widthVariants, type WidthVariant } from "@opal/shared";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
// Shared types
@@ -65,25 +62,14 @@ interface ContentBaseProps {
// Discriminated union: valid sizePreset × variant combinations
// ---------------------------------------------------------------------------
type XlContentProps = ContentBaseProps & {
type HeadingContentProps = ContentBaseProps & {
/** Size preset. Default: `"headline"`. */
sizePreset?: "headline" | "section";
/** Variant. Default: `"heading"` for heading-eligible presets. */
variant?: "heading";
/** Optional secondary icon rendered in the icon row (ContentXl only). */
moreIcon1?: IconFunctionComponent;
/** Optional tertiary icon rendered in the icon row (ContentXl only). */
moreIcon2?: IconFunctionComponent;
variant?: "heading" | "section";
};
type LgContentProps = ContentBaseProps & {
/** Size preset. Default: `"headline"`. */
sizePreset?: "headline" | "section";
/** Variant. */
variant: "section";
};
type MdContentProps = ContentBaseProps & {
type LabelContentProps = ContentBaseProps & {
sizePreset: "main-content" | "main-ui" | "secondary";
variant?: "section";
/** When `true`, renders "(Optional)" beside the title in the muted font variant. */
@@ -94,24 +80,20 @@ type MdContentProps = ContentBaseProps & {
tag?: TagProps;
};
/** ContentSm does not support descriptions or inline editing. */
type SmContentProps = Omit<
/** BodyLayout does not support descriptions or inline editing. */
type BodyContentProps = Omit<
ContentBaseProps,
"description" | "editable" | "onTitleChange"
> & {
sizePreset: "main-content" | "main-ui" | "secondary";
variant: "body";
/** Layout orientation. Default: `"inline"`. */
orientation?: ContentSmOrientation;
orientation?: BodyOrientation;
/** Title prominence. Default: `"default"`. */
prominence?: ContentSmProminence;
prominence?: BodyProminence;
};
type ContentProps =
| XlContentProps
| LgContentProps
| MdContentProps
| SmContentProps;
type ContentProps = HeadingContentProps | LabelContentProps | BodyContentProps;
// ---------------------------------------------------------------------------
// Content — routes to the appropriate internal layout
@@ -129,42 +111,34 @@ function Content(props: ContentProps) {
let layout: React.ReactNode = null;
// ContentXl / ContentLg: headline/section presets
// Heading layout: headline/section presets with heading/section variant
if (sizePreset === "headline" || sizePreset === "section") {
if (variant === "heading") {
layout = (
<ContentXl
sizePreset={sizePreset}
{...(rest as Omit<ContentXlProps, "sizePreset">)}
/>
);
} else {
layout = (
<ContentLg
sizePreset={sizePreset}
{...(rest as Omit<ContentLgProps, "sizePreset">)}
/>
);
}
}
// ContentMd: main-content/main-ui/secondary with section variant
else if (variant === "section" || variant === "heading") {
layout = (
<ContentMd
<HeadingLayout
sizePreset={sizePreset}
{...(rest as Omit<ContentMdProps, "sizePreset">)}
variant={variant as HeadingLayoutProps["variant"]}
{...rest}
/>
);
}
// ContentSm: main-content/main-ui/secondary with body variant
// Label layout: main-content/main-ui/secondary with section variant
else if (variant === "section" || variant === "heading") {
layout = (
<LabelLayout
sizePreset={sizePreset}
{...(rest as Omit<LabelLayoutProps, "sizePreset">)}
/>
);
}
// Body layout: main-content/main-ui/secondary with body variant
else if (variant === "body") {
layout = (
<ContentSm
<BodyLayout
sizePreset={sizePreset}
{...(rest as Omit<
React.ComponentProps<typeof ContentSm>,
React.ComponentProps<typeof BodyLayout>,
"sizePreset"
>)}
/>
@@ -193,8 +167,7 @@ export {
type ContentProps,
type SizePreset,
type ContentVariant,
type XlContentProps,
type LgContentProps,
type MdContentProps,
type SmContentProps,
type HeadingContentProps,
type LabelContentProps,
type BodyContentProps,
};


@@ -1,145 +1,41 @@
/* ===========================================================================
Content — ContentXl
/* ---------------------------------------------------------------------------
Content — HeadingLayout
Icon row on top (flex-col). Icon row contains main icon + optional
moreIcon1 / moreIcon2 in a flex-row.
Two icon placement modes (driven by variant):
left (variant="section") : flex-row — icon beside content
top (variant="heading") : flex-col — icon above content
Sizing (icon size, gap, padding, font, line-height) is driven by the
sizePreset prop via inline styles + Tailwind classes in the component.
=========================================================================== */
/* ---------------------------------------------------------------------------
Layout — flex-col (icon row above body)
--------------------------------------------------------------------------- */
.opal-content-xl {
@apply flex flex-col items-start;
}
/* ---------------------------------------------------------------------------
Icon row — flex-row containing main icon + more icons
Layout — icon placement
--------------------------------------------------------------------------- */
.opal-content-xl-icon-row {
@apply flex flex-row items-center;
.opal-content-heading {
@apply flex items-start;
}
/* ---------------------------------------------------------------------------
Icons
--------------------------------------------------------------------------- */
.opal-content-xl-icon-container {
display: flex;
align-items: center;
justify-content: center;
.opal-content-heading[data-icon-placement="left"] {
@apply flex-row;
}
.opal-content-xl-more-icon-container {
display: flex;
align-items: center;
justify-content: center;
}
.opal-content-xl-icon {
color: var(--text-04);
}
/* ---------------------------------------------------------------------------
Body column
--------------------------------------------------------------------------- */
.opal-content-xl-body {
@apply flex flex-1 flex-col items-start;
min-width: 0.0625rem;
}
/* ---------------------------------------------------------------------------
Title row — title (or input) + edit button
--------------------------------------------------------------------------- */
.opal-content-xl-title-row {
@apply flex items-center w-full;
gap: 0.25rem;
}
.opal-content-xl-title {
@apply text-left overflow-hidden;
display: -webkit-box;
-webkit-box-orient: vertical;
-webkit-line-clamp: 1;
padding: 0 0.125rem;
min-width: 0.0625rem;
}
.opal-content-xl-input-sizer {
display: inline-grid;
align-items: stretch;
}
.opal-content-xl-input-sizer > * {
grid-area: 1 / 1;
padding: 0 0.125rem;
min-width: 0.0625rem;
}
.opal-content-xl-input-mirror {
visibility: hidden;
white-space: pre;
}
.opal-content-xl-input {
@apply bg-transparent outline-none border-none;
}
/* ---------------------------------------------------------------------------
Edit button — visible only on hover of the outer container
--------------------------------------------------------------------------- */
.opal-content-xl-edit-button {
@apply opacity-0 transition-opacity shrink-0;
}
.opal-content-xl:hover .opal-content-xl-edit-button {
@apply opacity-100;
}
/* ---------------------------------------------------------------------------
Description
--------------------------------------------------------------------------- */
.opal-content-xl-description {
@apply text-left w-full;
padding: 0 0.125rem;
}
/* ===========================================================================
Content — ContentLg
Always inline (flex-row) — icon beside content.
Sizing (icon size, gap, padding, font, line-height) is driven by the
sizePreset prop via inline styles + Tailwind classes in the component.
=========================================================================== */
/* ---------------------------------------------------------------------------
Layout
--------------------------------------------------------------------------- */
.opal-content-lg {
@apply flex flex-row items-start;
.opal-content-heading[data-icon-placement="top"] {
@apply flex-col;
}
/* ---------------------------------------------------------------------------
Icon
--------------------------------------------------------------------------- */
.opal-content-lg-icon-container {
.opal-content-heading-icon-container {
display: flex;
align-items: center;
justify-content: center;
}
.opal-content-lg-icon {
.opal-content-heading-icon {
color: var(--text-04);
}
@@ -147,7 +43,7 @@
Body column
--------------------------------------------------------------------------- */
.opal-content-lg-body {
.opal-content-heading-body {
@apply flex flex-1 flex-col items-start;
min-width: 0.0625rem;
}
@@ -156,12 +52,12 @@
Title row — title (or input) + edit button
--------------------------------------------------------------------------- */
.opal-content-lg-title-row {
.opal-content-heading-title-row {
@apply flex items-center w-full;
gap: 0.25rem;
}
.opal-content-lg-title {
.opal-content-heading-title {
@apply text-left overflow-hidden;
display: -webkit-box;
-webkit-box-orient: vertical;
@@ -170,23 +66,23 @@
min-width: 0.0625rem;
}
.opal-content-lg-input-sizer {
.opal-content-heading-input-sizer {
display: inline-grid;
align-items: stretch;
}
.opal-content-lg-input-sizer > * {
.opal-content-heading-input-sizer > * {
grid-area: 1 / 1;
padding: 0 0.125rem;
min-width: 0.0625rem;
}
.opal-content-lg-input-mirror {
.opal-content-heading-input-mirror {
visibility: hidden;
white-space: pre;
}
.opal-content-lg-input {
.opal-content-heading-input {
@apply bg-transparent outline-none border-none;
}
@@ -194,11 +90,11 @@
Edit button — visible only on hover of the outer container
--------------------------------------------------------------------------- */
.opal-content-lg-edit-button {
.opal-content-heading-edit-button {
@apply opacity-0 transition-opacity shrink-0;
}
.opal-content-lg:hover .opal-content-lg-edit-button {
.opal-content-heading:hover .opal-content-heading-edit-button {
@apply opacity-100;
}
@@ -206,13 +102,13 @@
Description
--------------------------------------------------------------------------- */
.opal-content-lg-description {
.opal-content-heading-description {
@apply text-left w-full;
padding: 0 0.125rem;
}
/* ===========================================================================
Content — ContentMd
Content — LabelLayout
Always inline (flex-row). Icon color varies per sizePreset and is applied
via Tailwind class from the component.
@@ -222,7 +118,7 @@
Layout
--------------------------------------------------------------------------- */
.opal-content-md {
.opal-content-label {
@apply flex flex-row items-start;
}
@@ -230,7 +126,7 @@
Icon
--------------------------------------------------------------------------- */
.opal-content-md-icon-container {
.opal-content-label-icon-container {
display: flex;
align-items: center;
justify-content: center;
@@ -240,7 +136,7 @@
Body column
--------------------------------------------------------------------------- */
.opal-content-md-body {
.opal-content-label-body {
@apply flex flex-1 flex-col items-start;
min-width: 0.0625rem;
}
@@ -249,12 +145,12 @@
Title row — title (or input) + edit button
--------------------------------------------------------------------------- */
.opal-content-md-title-row {
.opal-content-label-title-row {
@apply flex items-center w-full;
gap: 0.25rem;
}
.opal-content-md-title {
.opal-content-label-title {
@apply text-left overflow-hidden;
display: -webkit-box;
-webkit-box-orient: vertical;
@@ -263,23 +159,23 @@
min-width: 0.0625rem;
}
.opal-content-md-input-sizer {
.opal-content-label-input-sizer {
display: inline-grid;
align-items: stretch;
}
.opal-content-md-input-sizer > * {
.opal-content-label-input-sizer > * {
grid-area: 1 / 1;
padding: 0 0.125rem;
min-width: 0.0625rem;
}
.opal-content-md-input-mirror {
.opal-content-label-input-mirror {
visibility: hidden;
white-space: pre;
}
.opal-content-md-input {
.opal-content-label-input {
@apply bg-transparent outline-none border-none;
}
@@ -287,7 +183,7 @@
Aux icon
--------------------------------------------------------------------------- */
.opal-content-md-aux-icon {
.opal-content-label-aux-icon {
display: flex;
align-items: center;
justify-content: center;
@@ -297,11 +193,11 @@
Edit button — visible only on hover of the outer container
--------------------------------------------------------------------------- */
.opal-content-md-edit-button {
.opal-content-label-edit-button {
@apply opacity-0 transition-opacity shrink-0;
}
.opal-content-md:hover .opal-content-md-edit-button {
.opal-content-label:hover .opal-content-label-edit-button {
@apply opacity-100;
}
@@ -309,13 +205,13 @@
Description
--------------------------------------------------------------------------- */
.opal-content-md-description {
.opal-content-label-description {
@apply text-left w-full;
padding: 0 0.125rem;
}
/* ===========================================================================
Content — ContentSm
Content — BodyLayout
Three orientation modes (driven by orientation prop):
inline : flex-row — icon left, title right
@@ -330,19 +226,19 @@
Layout — orientation
--------------------------------------------------------------------------- */
.opal-content-sm {
.opal-content-body {
@apply flex items-start;
}
.opal-content-sm[data-orientation="inline"] {
.opal-content-body[data-orientation="inline"] {
@apply flex-row;
}
.opal-content-sm[data-orientation="vertical"] {
.opal-content-body[data-orientation="vertical"] {
@apply flex-col;
}
.opal-content-sm[data-orientation="reverse"] {
.opal-content-body[data-orientation="reverse"] {
@apply flex-row-reverse;
}
@@ -350,7 +246,7 @@
Icon
--------------------------------------------------------------------------- */
.opal-content-sm-icon-container {
.opal-content-body-icon-container {
display: flex;
align-items: center;
justify-content: center;
@@ -360,7 +256,7 @@
Title
--------------------------------------------------------------------------- */
.opal-content-sm-title {
.opal-content-body-title {
@apply text-left overflow-hidden;
display: -webkit-box;
-webkit-box-orient: vertical;


@@ -8,7 +8,7 @@ Layout primitives for composing icon + title + description rows. These component
| Component | Description | Docs |
|---|---|---|
| [`Content`](./Content/README.md) | Icon + title + description row. Routes to an internal layout (`ContentLg`, `ContentMd`, or `ContentSm`) based on `sizePreset` and `variant`. | [Content README](./Content/README.md) |
| [`Content`](./Content/README.md) | Icon + title + description row. Routes to an internal layout (`HeadingLayout`, `LabelLayout`, or `BodyLayout`) based on `sizePreset` and `variant`. | [Content README](./Content/README.md) |
| [`ContentAction`](./ContentAction/README.md) | Wraps `Content` in a flex-row with an optional `rightChildren` slot for action buttons. Adds padding alignment via the shared `SizeVariant` scale. | [ContentAction README](./ContentAction/README.md) |
## Quick Start
@@ -88,7 +88,6 @@ These are not exported — `Content` routes to them automatically:
| Layout | Used when | File |
|---|---|---|
| `ContentXl` | `sizePreset` is `headline` or `section` with `variant="heading"` | `Content/ContentXl.tsx` |
| `ContentLg` | `sizePreset` is `headline` or `section` with `variant="section"` | `Content/ContentLg.tsx` |
| `ContentMd` | `sizePreset` is `main-content`, `main-ui`, or `secondary` with `variant="section"` | `Content/ContentMd.tsx` |
| `ContentSm` | `variant="body"` | `Content/ContentSm.tsx` |
| `HeadingLayout` | `sizePreset` is `headline` or `section` | `Content/HeadingLayout.tsx` |
| `LabelLayout` | `sizePreset` is `main-content`, `main-ui`, or `secondary` with `variant="section"` | `Content/LabelLayout.tsx` |
| `BodyLayout` | `variant="body"` | `Content/BodyLayout.tsx` |
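The routing in the updated table can be sketched as a small selector. This is an illustrative reduction of the rules above, not the actual `Content` implementation; the prop and layout names come straight from the table:

```typescript
// Hypothetical sketch of Content's layout routing, derived from the table
// above. The function name and the string return values are illustrative.
type SizePreset = "headline" | "section" | "main-content" | "main-ui" | "secondary";
type Variant = "heading" | "section" | "body";

function pickLayout(sizePreset: SizePreset, variant: Variant): string {
  // variant="body" always wins, regardless of sizePreset
  if (variant === "body") return "BodyLayout";
  // large presets route to the heading layout
  if (sizePreset === "headline" || sizePreset === "section") return "HeadingLayout";
  // main-content / main-ui / secondary with variant="section"
  return "LabelLayout";
}
```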


@@ -16,7 +16,7 @@
// - Interactive.Container (height + min-width + padding)
// - Button (icon sizing)
// - ContentAction (padding only)
// - Content (ContentLg / ContentMd) (edit-button size)
// - Content (HeadingLayout / LabelLayout) (edit-button size)
// ---------------------------------------------------------------------------
/**


@@ -2,7 +2,7 @@
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
import { LLMProviderView } from "@/interfaces/llm";
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
import { getImageGenForm } from "./forms";


@@ -7,7 +7,7 @@ import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
import {
IMAGE_PROVIDER_GROUPS,
ImageProvider,
@@ -23,14 +23,13 @@ import Message from "@/refresh-components/messages/Message";
export default function ImageGenerationContent() {
const {
data: llmProviderResponse,
data: llmProviders = [],
error: llmError,
mutate: refetchProviders,
} = useSWR<LLMProviderResponse<LLMProviderView>>(
} = useSWR<LLMProviderView[]>(
"/api/admin/llm/provider?include_image_gen=true",
errorHandlingFetcher
);
const llmProviders = llmProviderResponse?.providers ?? [];
const {
data: configs = [],


@@ -1,6 +1,6 @@
import { FormikProps } from "formik";
import { ImageProvider } from "../constants";
import { LLMProviderView } from "@/interfaces/llm";
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
import {
ImageGenerationConfigView,
ImageGenerationCredentials,


@@ -0,0 +1,84 @@
"use client";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR from "swr";
import { Callout } from "@/components/ui/callout";
import Text from "@/refresh-components/texts/Text";
import Title from "@/components/ui/title";
import { ThreeDotsLoader } from "@/components/Loading";
import { LLMProviderView } from "./interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { OpenAIForm } from "./forms/OpenAIForm";
import { AnthropicForm } from "./forms/AnthropicForm";
import { OllamaForm } from "./forms/OllamaForm";
import { AzureForm } from "./forms/AzureForm";
import { BedrockForm } from "./forms/BedrockForm";
import { VertexAIForm } from "./forms/VertexAIForm";
import { OpenRouterForm } from "./forms/OpenRouterForm";
import { getFormForExistingProvider } from "./forms/getForm";
import { CustomForm } from "./forms/CustomForm";
export function LLMConfiguration() {
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
LLM_PROVIDERS_ADMIN_URL,
errorHandlingFetcher
);
if (!existingLlmProviders) {
return <ThreeDotsLoader />;
}
const isFirstProvider = existingLlmProviders.length === 0;
return (
<>
<Title className="mb-2">Enabled LLM Providers</Title>
{existingLlmProviders.length > 0 ? (
<>
<Text as="p" className="mb-4">
If multiple LLM providers are enabled, the default provider will be
used for all &quot;Default&quot; Assistants. For user-created
Assistants, you can select the LLM provider/model that best fits the
use case!
</Text>
<div className="flex flex-col gap-y-4">
{[...existingLlmProviders]
.sort((a, b) => {
if (a.is_default_provider && !b.is_default_provider) return -1;
if (!a.is_default_provider && b.is_default_provider) return 1;
return 0;
})
.map((llmProvider) => (
<div key={llmProvider.id}>
{getFormForExistingProvider(llmProvider)}
</div>
))}
</div>
</>
) : (
<Callout type="warning" title="No LLM providers configured yet">
Please set one up below in order to start using Onyx!
</Callout>
)}
<Title className="mb-2 mt-6">Add LLM Provider</Title>
<Text as="p" className="mb-4">
Add a new LLM provider by either selecting from one of the default
providers or by specifying your own custom LLM provider.
</Text>
<div className="flex flex-col gap-y-4">
<OpenAIForm shouldMarkAsDefault={isFirstProvider} />
<AnthropicForm shouldMarkAsDefault={isFirstProvider} />
<OllamaForm shouldMarkAsDefault={isFirstProvider} />
<AzureForm shouldMarkAsDefault={isFirstProvider} />
<BedrockForm shouldMarkAsDefault={isFirstProvider} />
<VertexAIForm shouldMarkAsDefault={isFirstProvider} />
<OpenRouterForm shouldMarkAsDefault={isFirstProvider} />
<CustomForm shouldMarkAsDefault={isFirstProvider} />
</div>
</>
);
}
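The "default provider first" ordering used when rendering the provider list above can be isolated as a pure comparator. A minimal sketch, assuming a narrowed provider shape (the real `LLMProviderView` carries many more fields):

```typescript
// Sketch of the default-first sort used in LLMConfiguration above.
// ProviderLike is an assumed minimal shape for illustration.
interface ProviderLike {
  id: number;
  is_default_provider: boolean | null;
}

function sortDefaultFirst<T extends ProviderLike>(providers: T[]): T[] {
  // Copy before sorting so the SWR cache array is never mutated in place.
  return [...providers].sort((a, b) => {
    if (a.is_default_provider && !b.is_default_provider) return -1;
    if (!a.is_default_provider && b.is_default_provider) return 1;
    return 0; // Array.prototype.sort is stable, so original order is kept
  });
}
```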


@@ -1,14 +1,13 @@
"use client";
import { ArrayHelpers, FieldArray, FormikProps, useField } from "formik";
import { ModelConfiguration } from "@/interfaces/llm";
import { ModelConfiguration } from "./interfaces";
import { ManualErrorMessage, TextFormField } from "@/components/Field";
import { useEffect, useState } from "react";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import { SvgX } from "@opal/icons";
import Text from "@/refresh-components/texts/Text";
function ModelConfigurationRow({
name,
index,


@@ -1,5 +1,4 @@
export const LLM_ADMIN_URL = "/api/admin/llm";
export const LLM_PROVIDERS_ADMIN_URL = `${LLM_ADMIN_URL}/provider`;
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
"/api/admin/llm/provider-contextual-cost";


@@ -1,5 +1,5 @@
import { Form, Formik } from "formik";
import { LLMProviderFormProps } from "@/interfaces/llm";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import {
ProviderFormEntrypointWrapper,
@@ -21,19 +21,15 @@ import { DisplayModels } from "./components/DisplayModels";
export const ANTHROPIC_PROVIDER_NAME = "anthropic";
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
export function AnthropicModal({
export function AnthropicForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="Anthropic"
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -56,6 +52,7 @@ export function AnthropicModal({
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? undefined,
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
// Default to auto mode for new Anthropic providers


@@ -1,6 +1,6 @@
import { Form, Formik } from "formik";
import { TextFormField } from "@/components/Field";
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
import * as Yup from "yup";
import {
ProviderFormEntrypointWrapper,
@@ -28,7 +28,7 @@ import Separator from "@/refresh-components/Separator";
export const AZURE_PROVIDER_NAME = "azure";
const AZURE_DISPLAY_NAME = "Microsoft Azure Cloud";
interface AzureModalValues extends BaseLLMFormValues {
interface AzureFormValues extends BaseLLMFormValues {
api_key: string;
target_uri: string;
api_base?: string;
@@ -47,19 +47,15 @@ const buildTargetUri = (existingLlmProvider?: LLMProviderView): string => {
return `${existingLlmProvider.api_base}/openai/deployments/${deploymentName}/chat/completions?api-version=${existingLlmProvider.api_version}`;
};
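The hunk below parses `target_uri` back into its parts on submit. A hedged sketch of that inverse operation, matching the URI template produced by `buildTargetUri` above (the function name and minimal error handling are assumptions):

```typescript
// Hypothetical inverse of buildTargetUri: split an Azure target URI back
// into api_base, deployment_name, and api-version. The path shape
// /openai/deployments/<name>/chat/completions mirrors the template above.
function parseTargetUri(targetUri: string): {
  api_base: string;
  deployment_name: string;
  api_version: string | null;
} {
  const url = new URL(targetUri);
  const match = url.pathname.match(/\/openai\/deployments\/([^/]+)\//);
  if (!match) throw new Error(`Unrecognized Azure target URI: ${targetUri}`);
  return {
    api_base: `${url.protocol}//${url.host}`,
    deployment_name: match[1]!,
    api_version: url.searchParams.get("api-version"),
  };
}
```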
export function AzureModal({
export function AzureForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName={AZURE_DISPLAY_NAME}
providerEndpoint={AZURE_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -74,7 +70,7 @@ export function AzureModal({
existingLlmProvider,
wellKnownLLMProvider
);
const initialValues: AzureModalValues = {
const initialValues: AzureFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
@@ -101,7 +97,7 @@ export function AzureModal({
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
// Parse target_uri to extract api_base, api_version, and deployment_name
let processedValues: AzureModalValues = { ...values };
let processedValues: AzureFormValues = { ...values };
if (values.target_uri) {
try {


@@ -8,7 +8,7 @@ import {
LLMProviderFormProps,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
} from "../interfaces";
import * as Yup from "yup";
import {
ProviderFormEntrypointWrapper,
@@ -27,7 +27,7 @@ import {
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
import { fetchBedrockModels } from "../utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import Tabs from "@/refresh-components/Tabs";
@@ -65,7 +65,7 @@ const FIELD_AWS_ACCESS_KEY_ID = "custom_config.AWS_ACCESS_KEY_ID";
const FIELD_AWS_SECRET_ACCESS_KEY = "custom_config.AWS_SECRET_ACCESS_KEY";
const FIELD_AWS_BEARER_TOKEN_BEDROCK = "custom_config.AWS_BEARER_TOKEN_BEDROCK";
interface BedrockModalValues extends BaseLLMFormValues {
interface BedrockFormValues extends BaseLLMFormValues {
custom_config: {
AWS_REGION_NAME: string;
BEDROCK_AUTH_METHOD?: string;
@@ -75,8 +75,8 @@ interface BedrockModalValues extends BaseLLMFormValues {
};
}
interface BedrockModalInternalsProps {
formikProps: FormikProps<BedrockModalValues>;
interface BedrockFormInternalsProps {
formikProps: FormikProps<BedrockFormValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
@@ -87,7 +87,7 @@ interface BedrockModalInternalsProps {
onClose: () => void;
}
function BedrockModalInternals({
function BedrockFormInternals({
formikProps,
existingLlmProvider,
fetchedModels,
@@ -97,7 +97,7 @@ function BedrockModalInternals({
testError,
mutate,
onClose,
}: BedrockModalInternalsProps) {
}: BedrockFormInternalsProps) {
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
// Clean up unused auth fields when tab changes
@@ -258,11 +258,9 @@ function BedrockModalInternals({
);
}
export function BedrockModal({
export function BedrockForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
@@ -270,8 +268,6 @@ export function BedrockModal({
<ProviderFormEntrypointWrapper
providerName={BEDROCK_DISPLAY_NAME}
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -286,7 +282,7 @@ export function BedrockModal({
existingLlmProvider,
wellKnownLLMProvider
);
const initialValues: BedrockModalValues = {
const initialValues: BedrockFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
@@ -356,7 +352,7 @@ export function BedrockModal({
}}
>
{(formikProps) => (
<BedrockModalInternals
<BedrockFormInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}


@@ -7,7 +7,7 @@
*/
import React from "react";
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
import { CustomModal } from "./CustomModal";
import { CustomForm } from "./CustomForm";
import { toast } from "@/hooks/useToast";
// Mock SWR's mutate function and useSWR
@@ -116,10 +116,11 @@ describe("Custom LLM Provider Configuration Workflow", () => {
name: "My Custom Provider",
provider: "openai",
api_key: "test-key",
default_model_name: "gpt-4",
}),
} as Response);
render(<CustomModal />);
render(<CustomForm />);
await openModalAndFillBasicFields(user, {
name: "My Custom Provider",
@@ -176,7 +177,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ detail: "Invalid API key" }),
} as Response);
render(<CustomModal />);
render(<CustomForm />);
await openModalAndFillBasicFields(user, {
name: "Bad Provider",
@@ -223,13 +224,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
api_key: "old-key",
api_base: "",
api_version: "",
default_model_name: "claude-3-opus",
model_configurations: [
{
name: "claude-3-opus",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
supports_image_input: null,
},
],
custom_config: {},
@@ -238,6 +239,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
groups: [],
personas: [],
deployment_name: null,
is_default_provider: false,
default_vision_model: null,
is_default_vision_provider: null,
};
// Mock POST /api/admin/llm/test
@@ -252,7 +256,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ ...existingProvider, api_key: "new-key" }),
} as Response);
render(<CustomModal existingLlmProvider={existingProvider} />);
render(<CustomForm existingLlmProvider={existingProvider} />);
// For existing provider, click "Edit" button to open modal
const editButton = screen.getByRole("button", { name: /edit/i });
@@ -303,13 +307,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
api_key: "old-key",
api_base: "https://example-openai-compatible.local/v1",
api_version: "",
default_model_name: "gpt-oss-20b-bw-failover",
model_configurations: [
{
name: "gpt-oss-20b-bw-failover",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
supports_image_input: null,
},
],
custom_config: {},
@@ -318,6 +322,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
groups: [],
personas: [],
deployment_name: null,
is_default_provider: false,
default_vision_model: null,
is_default_vision_provider: null,
};
// Mock POST /api/admin/llm/test
@@ -336,21 +343,19 @@ describe("Custom LLM Provider Configuration Workflow", () => {
name: "gpt-oss-20b-bw-failover",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
supports_image_input: null,
},
{
name: "nemotron",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
supports_image_input: null,
},
],
}),
} as Response);
render(<CustomModal existingLlmProvider={existingProvider} />);
render(<CustomForm existingLlmProvider={existingProvider} />);
const editButton = screen.getByRole("button", { name: /edit/i });
await user.click(editButton);
@@ -418,7 +423,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({}),
} as Response);
render(<CustomModal shouldMarkAsDefault={true} />);
render(<CustomForm shouldMarkAsDefault={true} />);
await openModalAndFillBasicFields(user, {
name: "New Default Provider",
@@ -458,7 +463,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ detail: "Database error" }),
} as Response);
render(<CustomModal />);
render(<CustomForm />);
await openModalAndFillBasicFields(user, {
name: "Test Provider",
@@ -494,7 +499,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ id: 1, name: "Provider with Custom Config" }),
} as Response);
render(<CustomModal />);
render(<CustomForm />);
// Open modal
const openButton = screen.getByRole("button", {


@@ -7,7 +7,7 @@ import {
Formik,
ErrorMessage,
} from "formik";
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
@@ -21,7 +21,7 @@ import {
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { TextFormField } from "@/components/Field";
import { ModelConfigurationField } from "@/app/admin/configuration/llm/ModelConfigurationField";
import { ModelConfigurationField } from "../ModelConfigurationField";
import Text from "@/refresh-components/texts/Text";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import IconButton from "@/refresh-components/buttons/IconButton";
@@ -38,11 +38,9 @@ function customConfigProcessing(customConfigsList: [string, string][]) {
return customConfig;
}
export function CustomModal({
export function CustomForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
@@ -50,8 +48,6 @@ export function CustomModal({
existingLlmProvider={existingLlmProvider}
buttonMode={!existingLlmProvider}
buttonText="Add Custom LLM Provider"
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -72,15 +68,7 @@ export function CustomModal({
...modelConfiguration,
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
})
) ?? [
{
name: "",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
},
],
) ?? [{ name: "", is_visible: true, max_input_tokens: null }],
custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config)
: [],
@@ -124,8 +112,7 @@ export function CustomModal({
name: mc.name,
is_visible: mc.is_visible,
max_input_tokens: mc.max_input_tokens ?? null,
supports_image_input: mc.supports_image_input ?? false,
supports_reasoning: mc.supports_reasoning ?? false,
supports_image_input: null,
}))
.filter(
(mc) => mc.name === values.default_model_name || mc.is_visible
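The submit-time filter above drops hidden models but always keeps the default model, even when it is not visible. A minimal sketch of that rule, assuming a narrowed row shape:

```typescript
// Sketch of the visibility filter from CustomForm's submit path above.
// ModelRow is an assumed minimal shape for illustration.
interface ModelRow {
  name: string;
  is_visible: boolean;
}

function visibleOrDefault(rows: ModelRow[], defaultModelName: string): ModelRow[] {
  // The default model must always be submitted, visible or not.
  return rows.filter((mc) => mc.name === defaultModelName || mc.is_visible);
}
```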


@@ -5,7 +5,7 @@ import {
LLMProviderFormProps,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
} from "../interfaces";
import * as Yup from "yup";
import {
ProviderFormEntrypointWrapper,
@@ -24,20 +24,20 @@ import {
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { useEffect, useState } from "react";
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
import { fetchOllamaModels } from "../utils";
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
interface OllamaModalValues extends BaseLLMFormValues {
interface OllamaFormValues extends BaseLLMFormValues {
api_base: string;
custom_config: {
OLLAMA_API_KEY?: string;
};
}
interface OllamaModalContentProps {
formikProps: FormikProps<OllamaModalValues>;
interface OllamaFormContentProps {
formikProps: FormikProps<OllamaFormValues>;
existingLlmProvider?: LLMProviderView;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
@@ -48,7 +48,7 @@ interface OllamaModalContentProps {
isFormValid: boolean;
}
function OllamaModalContent({
function OllamaFormContent({
formikProps,
existingLlmProvider,
fetchedModels,
@@ -58,7 +58,7 @@ function OllamaModalContent({
mutate,
onClose,
isFormValid,
}: OllamaModalContentProps) {
}: OllamaFormContentProps) {
const [isLoadingModels, setIsLoadingModels] = useState(true);
useEffect(() => {
@@ -131,11 +131,9 @@ function OllamaModalContent({
);
}
export function OllamaModal({
export function OllamaForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
@@ -143,8 +141,6 @@ export function OllamaModal({
<ProviderFormEntrypointWrapper
providerName="Ollama"
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -159,7 +155,7 @@ export function OllamaModal({
existingLlmProvider,
wellKnownLLMProvider
);
const initialValues: OllamaModalValues = {
const initialValues: OllamaFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
@@ -216,7 +212,7 @@ export function OllamaModal({
}}
>
{(formikProps) => (
<OllamaModalContent
<OllamaFormContent
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}


@@ -1,6 +1,6 @@
import { Form, Formik } from "formik";
import { LLMProviderFormProps } from "@/interfaces/llm";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
@@ -19,19 +19,15 @@ import { DisplayModels } from "./components/DisplayModels";
export const OPENAI_PROVIDER_NAME = "openai";
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
export function OpenAIModal({
export function OpenAIForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="OpenAI"
providerEndpoint={OPENAI_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -53,6 +49,7 @@ export function OpenAIModal({
),
api_key: existingLlmProvider?.api_key ?? "",
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
// Default to auto mode for new OpenAI providers


@@ -5,7 +5,7 @@ import {
LLMProviderFormProps,
ModelConfiguration,
OpenRouterModelResponse,
} from "@/interfaces/llm";
} from "../interfaces";
import * as Yup from "yup";
import {
ProviderFormEntrypointWrapper,
@@ -32,7 +32,7 @@ const OPENROUTER_DISPLAY_NAME = "OpenRouter";
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
const OPENROUTER_MODELS_API_URL = "/api/admin/llm/openrouter/available-models";
interface OpenRouterModalValues extends BaseLLMFormValues {
interface OpenRouterFormValues extends BaseLLMFormValues {
api_key: string;
api_base: string;
}
@@ -80,7 +80,6 @@ async function fetchOpenRouterModels(params: {
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: false,
}));
return { models };
@@ -91,11 +90,9 @@ async function fetchOpenRouterModels(params: {
}
}
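After this change, `fetchOpenRouterModels` no longer emits `supports_reasoning`. A sketch of the resulting mapping, with the raw input type assumed from the fields the diff touches:

```typescript
// Sketch of the OpenRouter model mapping after the diff above removed
// supports_reasoning. RawOpenRouterModel is an assumed input shape.
interface RawOpenRouterModel {
  name: string;
  max_input_tokens: number | null;
  supports_image_input: boolean | null;
}

function toModelConfigurations(raw: RawOpenRouterModel[]) {
  return raw.map((modelData) => ({
    name: modelData.name,
    is_visible: true, // freshly fetched models start out visible
    max_input_tokens: modelData.max_input_tokens,
    supports_image_input: modelData.supports_image_input,
  }));
}
```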
export function OpenRouterModal({
export function OpenRouterForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
@@ -104,8 +101,6 @@ export function OpenRouterModal({
providerName={OPENROUTER_DISPLAY_NAME}
providerEndpoint={OPENROUTER_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -120,7 +115,7 @@ export function OpenRouterModal({
existingLlmProvider,
wellKnownLLMProvider
);
const initialValues: OpenRouterModalValues = {
const initialValues: OpenRouterFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations


@@ -1,6 +1,6 @@
import { Form, Formik } from "formik";
import { TextFormField, FileUploadFormField } from "@/components/Field";
import { LLMProviderFormProps } from "@/interfaces/llm";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import {
ProviderFormEntrypointWrapper,
@@ -25,26 +25,22 @@ const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
const VERTEXAI_DEFAULT_LOCATION = "global";
interface VertexAIModalValues extends BaseLLMFormValues {
interface VertexAIFormValues extends BaseLLMFormValues {
custom_config: {
vertex_credentials: string;
vertex_location: string;
};
}
export function VertexAIModal({
export function VertexAIForm({
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName={VERTEXAI_DISPLAY_NAME}
providerEndpoint={VERTEXAI_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
open={open}
onOpenChange={onOpenChange}
>
{({
onClose,
@@ -59,12 +55,13 @@ export function VertexAIModal({
existingLlmProvider,
wellKnownLLMProvider
);
const initialValues: VertexAIModalValues = {
const initialValues: VertexAIFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
),
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL,
// Default to auto mode for new Vertex AI providers


@@ -1,4 +1,4 @@
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
import { ModelConfiguration, SimpleKnownModel } from "../../interfaces";
import { FormikProps } from "formik";
import { BaseLLMFormValues } from "../formUtils";


@@ -2,7 +2,7 @@ import { useState, useEffect } from "react";
 import Button from "@/refresh-components/buttons/Button";
 import Text from "@/refresh-components/texts/Text";
 import SimpleTooltip from "@/refresh-components/SimpleTooltip";
-import { ModelConfiguration } from "@/interfaces/llm";
+import { ModelConfiguration } from "../../interfaces";
 
 interface FetchModelsButtonProps {
   onFetch: () => Promise<{ models: ModelConfiguration[]; error?: string }>;

View File

@@ -2,9 +2,8 @@ import { LoadingAnimation } from "@/components/Loading";
 import Text from "@/refresh-components/texts/Text";
 import Button from "@/refresh-components/buttons/Button";
 import { SvgTrash } from "@opal/icons";
-import { LLMProviderView } from "@/interfaces/llm";
-import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
-import { deleteLlmProvider } from "@/lib/llmConfig/svc";
+import { LLMProviderView } from "../../interfaces";
+import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
 
 interface FormActionButtonsProps {
   isTesting: boolean;
@@ -26,14 +25,41 @@ export function FormActionButtons({
   const handleDelete = async () => {
     if (!existingLlmProvider) return;
 
-    try {
-      await deleteLlmProvider(existingLlmProvider.id);
-      mutate(LLM_PROVIDERS_ADMIN_URL);
-      onClose();
-    } catch (e) {
-      const message = e instanceof Error ? e.message : "Unknown error";
-      alert(`Failed to delete provider: ${message}`);
+    const response = await fetch(
+      `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
+      {
+        method: "DELETE",
+      }
+    );
+    if (!response.ok) {
+      const errorMsg = (await response.json()).detail;
+      alert(`Failed to delete provider: ${errorMsg}`);
+      return;
     }
+
+    // If the deleted provider was the default, set the first remaining provider as default
+    if (existingLlmProvider.is_default_provider) {
+      const remainingProvidersResponse = await fetch(LLM_PROVIDERS_ADMIN_URL);
+      if (remainingProvidersResponse.ok) {
+        const remainingProviders = await remainingProvidersResponse.json();
+        if (remainingProviders.length > 0) {
+          const setDefaultResponse = await fetch(
+            `${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
+            {
+              method: "POST",
+            }
+          );
+          if (!setDefaultResponse.ok) {
+            console.error("Failed to set new default provider");
+          }
+        }
+      }
+    }
+
+    mutate(LLM_PROVIDERS_ADMIN_URL);
+    onClose();
   };
 
   return (

View File
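The delete flow above replaces a service-layer helper with direct `fetch` calls and, when the deleted provider was the default, promotes the first remaining provider. A standalone sketch of just that promotion decision — the types and function name here are assumptions mirroring the diff, not the project's actual modules:

```typescript
// Hypothetical trimmed shape mirroring LLMProviderView in the diff.
interface ProviderLite {
  id: number;
  is_default_provider: boolean | null;
}

// After deleting `deletedId`, decide which remaining provider (if any)
// should be promoted to default — the same rule handleDelete applies
// before POSTing to `${LLM_PROVIDERS_ADMIN_URL}/{id}/default`.
function nextDefaultAfterDelete(
  providers: ProviderLite[],
  deletedId: number
): number | null {
  const deleted = providers.find((p) => p.id === deletedId);
  // Only reassign when the deleted provider was the default.
  if (!deleted?.is_default_provider) return null;
  const remaining = providers.filter((p) => p.id !== deletedId);
  return remaining.length > 0 ? remaining[0].id : null;
}
```

Isolating the decision from the network calls keeps the reassignment rule testable without mocking `fetch`.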

@@ -6,7 +6,7 @@ import { toast } from "@/hooks/useToast";
 import {
   LLMProviderView,
   WellKnownLLMProviderDescriptor,
-} from "@/interfaces/llm";
+} from "../../interfaces";
 import { errorHandlingFetcher } from "@/lib/fetcher";
 import Modal from "@/refresh-components/Modal";
 import Text from "@/refresh-components/texts/Text";
@@ -14,8 +14,7 @@ import Button from "@/refresh-components/buttons/Button";
 import { SvgSettings } from "@opal/icons";
 import { Badge } from "@/components/ui/badge";
 import { cn } from "@/lib/utils";
-import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
-import { setDefaultLlmModel } from "@/lib/llmConfig/svc";
+import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
 
 export interface ProviderFormContext {
   onClose: () => void;
@@ -36,10 +35,6 @@ interface ProviderFormEntrypointWrapperProps {
   buttonMode?: boolean;
   /** Custom button text for buttonMode (defaults to "Add {providerName}") */
   buttonText?: string;
-  /** Controlled open state — when defined, the wrapper renders only a modal (no card/button UI) */
-  open?: boolean;
-  /** Callback when controlled modal requests close */
-  onOpenChange?: (open: boolean) => void;
 }
 
 export function ProviderFormEntrypointWrapper({
@@ -49,11 +44,8 @@ export function ProviderFormEntrypointWrapper({
   existingLlmProvider,
   buttonMode,
   buttonText,
-  open,
-  onOpenChange,
 }: ProviderFormEntrypointWrapperProps) {
   const [formIsVisible, setFormIsVisible] = useState(false);
-  const isControlled = open !== undefined;
 
   // Shared hooks
   const { mutate } = useSWRConfig();
@@ -62,45 +54,33 @@ export function ProviderFormEntrypointWrapper({
   const [isTesting, setIsTesting] = useState(false);
   const [testError, setTestError] = useState<string>("");
 
-  // Suppress SWR when controlled + closed to avoid unnecessary API calls
-  const swrKey =
-    providerEndpoint && !(isControlled && !open)
-      ? `/api/admin/llm/built-in/options/${providerEndpoint}`
-      : null;
-
   // Fetch model configurations for this provider
   const { data: wellKnownLLMProvider } = useSWR<WellKnownLLMProviderDescriptor>(
-    swrKey,
+    providerEndpoint
+      ? `/api/admin/llm/built-in/options/${providerEndpoint}`
+      : null,
     errorHandlingFetcher
   );
 
-  const onClose = () => {
-    if (isControlled) {
-      onOpenChange?.(false);
-    } else {
-      setFormIsVisible(false);
-    }
-  };
+  const onClose = () => setFormIsVisible(false);
 
   async function handleSetAsDefault(): Promise<void> {
     if (!existingLlmProvider) return;
 
-    const firstVisibleModel = existingLlmProvider.model_configurations.find(
-      (m) => m.is_visible
+    const response = await fetch(
+      `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
+      {
+        method: "POST",
+      }
     );
-    if (!firstVisibleModel) {
-      toast.error("No visible models available for this provider.");
+    if (!response.ok) {
+      const errorMsg = (await response.json()).detail;
+      toast.error(`Failed to set provider as default: ${errorMsg}`);
       return;
     }
 
-    try {
-      await setDefaultLlmModel(existingLlmProvider.id, firstVisibleModel.name);
-      await mutate(LLM_PROVIDERS_ADMIN_URL);
-      toast.success("Provider set as default successfully!");
-    } catch (e) {
-      const message = e instanceof Error ? e.message : "Unknown error";
-      toast.error(`Failed to set provider as default: ${message}`);
-    }
+    await mutate(LLM_PROVIDERS_ADMIN_URL);
+    toast.success("Provider set as default successfully!");
   }
 
   const context: ProviderFormContext = {
@@ -113,31 +93,6 @@ export function ProviderFormEntrypointWrapper({
     wellKnownLLMProvider,
   };
 
-  const defaultTitle = `${existingLlmProvider ? "Configure" : "Setup"} ${
-    existingLlmProvider?.name ? `"${existingLlmProvider.name}"` : providerName
-  }`;
-
-  function renderModal(isVisible: boolean, title?: string) {
-    if (!isVisible) return null;
-    return (
-      <Modal open onOpenChange={onClose}>
-        <Modal.Content>
-          <Modal.Header
-            icon={SvgSettings}
-            title={title ?? defaultTitle}
-            onClose={onClose}
-          />
-          <Modal.Body>{children(context)}</Modal.Body>
-        </Modal.Content>
-      </Modal>
-    );
-  }
-
-  // Controlled mode: render nothing when closed, render only modal when open
-  if (isControlled) {
-    return renderModal(!!open);
-  }
-
   // Button mode: simple button that opens a modal
   if (buttonMode && !existingLlmProvider) {
     return (
@@ -145,7 +100,19 @@ export function ProviderFormEntrypointWrapper({
       <Button action onClick={() => setFormIsVisible(true)}>
         {buttonText ?? `Add ${providerName}`}
       </Button>
-      {renderModal(formIsVisible, `Setup ${providerName}`)}
+      {formIsVisible && (
+        <Modal open onOpenChange={onClose}>
+          <Modal.Content>
+            <Modal.Header
+              icon={SvgSettings}
+              title={`Setup ${providerName}`}
+              onClose={onClose}
+            />
+            <Modal.Body>{children(context)}</Modal.Body>
+          </Modal.Content>
+        </Modal>
+      )}
     </>
   );
 }
@@ -168,18 +135,24 @@ export function ProviderFormEntrypointWrapper({
           <Text as="p" secondaryBody text03 className="italic">
             ({providerName})
           </Text>
-          <Text
-            as="p"
-            className={cn("text-action-link-05", "cursor-pointer")}
-            onClick={handleSetAsDefault}
-          >
-            Set as default
-          </Text>
+          {!existingLlmProvider.is_default_provider && (
+            <Text
+              as="p"
+              className={cn("text-action-link-05", "cursor-pointer")}
+              onClick={handleSetAsDefault}
+            >
+              Set as default
+            </Text>
+          )}
         </div>
         {existingLlmProvider && (
           <div className="my-auto ml-3">
-            <Badge variant="success">Enabled</Badge>
+            {existingLlmProvider.is_default_provider ? (
+              <Badge variant="agent">Default</Badge>
+            ) : (
+              <Badge variant="success">Enabled</Badge>
+            )}
           </div>
         )}
@@ -209,7 +182,22 @@ export function ProviderFormEntrypointWrapper({
         )}
       </div>
-      {renderModal(formIsVisible)}
+      {formIsVisible && (
+        <Modal open onOpenChange={onClose}>
+          <Modal.Content>
+            <Modal.Header
+              icon={SvgSettings}
+              title={`${existingLlmProvider ? "Configure" : "Setup"} ${
+                existingLlmProvider?.name
+                  ? `"${existingLlmProvider.name}"`
+                  : providerName
+              }`}
+              onClose={onClose}
+            />
+            <Modal.Body>{children(context)}</Modal.Body>
+          </Modal.Content>
+        </Modal>
+      )}
     </div>
   );
 }

View File
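With controlled mode removed, the wrapper above computes the SWR key inline: passing a `null` key disables the fetch entirely when no `providerEndpoint` is set, since SWR skips null keys. A standalone sketch of that key computation — the URL pattern comes from the diff, while the function name is an assumption for illustration:

```typescript
// Returns the SWR key for a provider's built-in options, or null to
// disable the fetch when no endpoint is configured (SWR skips null keys).
function builtInOptionsKey(providerEndpoint: string | null): string | null {
  return providerEndpoint
    ? `/api/admin/llm/built-in/options/${providerEndpoint}`
    : null;
}
```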

@@ -2,11 +2,8 @@ import {
   LLMProviderView,
   ModelConfiguration,
   WellKnownLLMProviderDescriptor,
-} from "@/interfaces/llm";
-import {
-  LLM_ADMIN_URL,
-  LLM_PROVIDERS_ADMIN_URL,
-} from "@/lib/llmConfig/constants";
+} from "../interfaces";
+import { LLM_PROVIDERS_ADMIN_URL } from "../constants";
 import { toast } from "@/hooks/useToast";
 import * as Yup from "yup";
 import isEqual from "lodash/isEqual";
@@ -18,7 +15,10 @@ export const buildDefaultInitialValues = (
   existingLlmProvider?: LLMProviderView,
   modelConfigurations?: ModelConfiguration[]
 ) => {
-  const defaultModelName = modelConfigurations?.[0]?.name ?? "";
+  const defaultModelName =
+    existingLlmProvider?.default_model_name ??
+    modelConfigurations?.[0]?.name ??
+    "";
 
   // Auto mode must be explicitly enabled by the user
   // Default to false for new providers, preserve existing value when editing
@@ -119,7 +119,6 @@ export const filterModelConfigurations = (
         is_visible: visibleModels.includes(modelConfiguration.name),
         max_input_tokens: modelConfiguration.max_input_tokens ?? null,
         supports_image_input: modelConfiguration.supports_image_input,
-        supports_reasoning: modelConfiguration.supports_reasoning,
         display_name: modelConfiguration.display_name,
       })
     )
@@ -142,7 +141,6 @@ export const getAutoModeModelConfigurations = (
       is_visible: modelConfiguration.is_visible,
       max_input_tokens: modelConfiguration.max_input_tokens ?? null,
       supports_image_input: modelConfiguration.supports_image_input,
-      supports_reasoning: modelConfiguration.supports_reasoning,
       display_name: modelConfiguration.display_name,
     })
   );
@@ -225,8 +223,6 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
     body: JSON.stringify({
       provider: providerName,
       ...finalValues,
-      model: finalDefaultModelName,
-      id: existingLlmProvider?.id,
     }),
   });
   setIsTesting(false);
@@ -251,7 +247,6 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
       body: JSON.stringify({
         provider: providerName,
         ...finalValues,
-        id: existingLlmProvider?.id,
       }),
     }
   );
@@ -267,16 +262,12 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
   if (shouldMarkAsDefault) {
     const newLlmProvider = (await response.json()) as LLMProviderView;
-    const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        provider_id: newLlmProvider.id,
-        model_name: finalDefaultModelName,
-      }),
-    });
+    const setDefaultResponse = await fetch(
+      `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
+      {
+        method: "POST",
+      }
+    );
     if (!setDefaultResponse.ok) {
       const errorMsg = (await setDefaultResponse.json()).detail;
       toast.error(`Failed to set provider as default: ${errorMsg}`);

View File
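`buildDefaultInitialValues` above now prefers the provider's saved default model before falling back to the first configured model, then an empty string. A minimal sketch of that nullish-coalescing chain — the function name and trimmed signature here are assumptions, not the project's exports:

```typescript
// Fallback chain for the initial default model name:
// saved default -> first configured model -> "".
function resolveDefaultModelName(
  existingDefault: string | undefined,
  modelNames: string[]
): string {
  return existingDefault ?? modelNames[0] ?? "";
}
```

Note that `??` only falls through on `null`/`undefined`, so a saved default that is an empty string would still win; the diff's real code has the same property.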

@@ -0,0 +1,44 @@
+import { LLMProviderName, LLMProviderView } from "../interfaces";
+import { AnthropicForm } from "./AnthropicForm";
+import { OpenAIForm } from "./OpenAIForm";
+import { OllamaForm } from "./OllamaForm";
+import { AzureForm } from "./AzureForm";
+import { VertexAIForm } from "./VertexAIForm";
+import { OpenRouterForm } from "./OpenRouterForm";
+import { CustomForm } from "./CustomForm";
+import { BedrockForm } from "./BedrockForm";
+
+export function detectIfRealOpenAIProvider(provider: LLMProviderView) {
+  return (
+    provider.provider === LLMProviderName.OPENAI &&
+    provider.api_key &&
+    !provider.api_base &&
+    Object.keys(provider.custom_config || {}).length === 0
+  );
+}
+
+export const getFormForExistingProvider = (provider: LLMProviderView) => {
+  switch (provider.provider) {
+    case LLMProviderName.OPENAI:
+      // "openai" as a provider name can be used for litellm proxy / any OpenAI-compatible provider
+      if (detectIfRealOpenAIProvider(provider)) {
+        return <OpenAIForm existingLlmProvider={provider} />;
+      } else {
+        return <CustomForm existingLlmProvider={provider} />;
+      }
+    case LLMProviderName.ANTHROPIC:
+      return <AnthropicForm existingLlmProvider={provider} />;
+    case LLMProviderName.OLLAMA_CHAT:
+      return <OllamaForm existingLlmProvider={provider} />;
+    case LLMProviderName.AZURE:
+      return <AzureForm existingLlmProvider={provider} />;
+    case LLMProviderName.VERTEX_AI:
+      return <VertexAIForm existingLlmProvider={provider} />;
+    case LLMProviderName.BEDROCK:
+      return <BedrockForm existingLlmProvider={provider} />;
+    case LLMProviderName.OPENROUTER:
+      return <OpenRouterForm existingLlmProvider={provider} />;
+    default:
+      return <CustomForm existingLlmProvider={provider} />;
+  }
+};

View File
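The new registry file above routes an `"openai"` provider to either the real OpenAI form or the generic custom form. The routing predicate is pure, so it can be restated and exercised standalone — the trimmed interface and function name below are assumptions, not the project's full `LLMProviderView`:

```typescript
// Trimmed shape for illustration; the real LLMProviderView has more fields.
interface ProviderShape {
  provider: string;
  api_key: string | null;
  api_base: string | null;
  custom_config: { [key: string]: string } | null;
}

// A provider named "openai" counts as "real" OpenAI only when it has an
// API key, no custom base URL, and no custom config — otherwise it is
// treated as an OpenAI-compatible proxy (e.g. litellm).
function isRealOpenAIProvider(p: ProviderShape): boolean {
  return (
    p.provider === "openai" &&
    Boolean(p.api_key) &&
    !p.api_base &&
    Object.keys(p.custom_config ?? {}).length === 0
  );
}
```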

@@ -13,8 +13,8 @@ export interface ModelConfiguration {
   name: string;
   is_visible: boolean;
   max_input_tokens: number | null;
-  supports_image_input: boolean;
-  supports_reasoning: boolean;
+  supports_image_input: boolean | null;
+  supports_reasoning?: boolean;
   display_name?: string;
   provider_display_name?: string;
   vendor?: string;
@@ -30,6 +30,7 @@ export interface SimpleKnownModel {
 export interface WellKnownLLMProviderDescriptor {
   name: string;
   known_models: ModelConfiguration[];
+  recommended_default_model: SimpleKnownModel | null;
 }
@@ -39,31 +40,44 @@ export interface LLMModelDescriptor {
   maxTokens: number;
 }
 
-export interface LLMProviderView {
-  id: number;
+export interface LLMProvider {
   name: string;
   provider: string;
   api_key: string | null;
   api_base: string | null;
   api_version: string | null;
   custom_config: { [key: string]: string } | null;
   default_model_name: string;
   is_public: boolean;
   is_auto_mode: boolean;
   groups: number[];
   personas: number[];
   deployment_name: string | null;
   default_vision_model: string | null;
   is_default_vision_provider: boolean | null;
   model_configurations: ModelConfiguration[];
 }
 
+export interface LLMProviderView extends LLMProvider {
+  id: number;
   is_default_provider: boolean | null;
 }
 
 export interface VisionProvider extends LLMProviderView {
   vision_models: string[];
 }
 
 export interface LLMProviderDescriptor {
   id: number;
   name: string;
   provider: string;
-  provider_display_name: string;
+  provider_display_name?: string;
   default_model_name: string;
   is_default_provider: boolean | null;
+  is_default_vision_provider?: boolean | null;
+  default_vision_model?: string | null;
+  is_public?: boolean;
+  groups?: number[];
+  personas?: number[];
   model_configurations: ModelConfiguration[];
 }
@@ -88,22 +102,9 @@ export interface BedrockModelResponse {
   supports_image_input: boolean;
 }
 
-export interface DefaultModel {
-  provider_id: number;
-  model_name: string;
-}
-
-export interface LLMProviderResponse<T> {
-  providers: T[];
-  default_text: DefaultModel | null;
-  default_vision: DefaultModel | null;
-}
-
 export interface LLMProviderFormProps {
   existingLlmProvider?: LLMProviderView;
   shouldMarkAsDefault?: boolean;
-  open?: boolean;
-  onOpenChange?: (open: boolean) => void;
 }
 
 // Param types for model fetching functions - use snake_case to match API structure

View File
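The interfaces hunk above splits the persisted view (`LLMProviderView`, which carries a server-assigned `id`) from the base payload (`LLMProvider`). A narrowing guard like the one below — an illustrative assumption, not code from the diff — is one way shared helpers can accept either shape and branch on whether the provider has been saved:

```typescript
// Trimmed stand-ins for the base payload and the persisted view.
interface LLMProviderBase {
  name: string;
  provider: string;
}

interface LLMProviderViewLike extends LLMProviderBase {
  id: number;
  is_default_provider: boolean | null;
}

// Type guard: a provider is "saved" iff the server has assigned it an id.
function isSavedProvider(
  p: LLMProviderBase | LLMProviderViewLike
): p is LLMProviderViewLike {
  return "id" in p;
}
```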

@@ -1,7 +1,14 @@
 "use client";
 
-import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
+import { AdminPageTitle } from "@/components/admin/Title";
+import { LLMConfiguration } from "./LLMConfiguration";
+import { SvgCpu } from "@opal/icons";
 
 export default function Page() {
-  return <LLMConfigurationPage />;
+  return (
+    <>
+      <AdminPageTitle title="LLM Setup" icon={SvgCpu} />
+      <LLMConfiguration />
+    </>
+  );
 }

Some files were not shown because too many files have changed in this diff.