mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-04 15:15:46 +00:00
Compare commits
23 Commits
content-re
...
worktree-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c6f3cebd8 | ||
|
|
25538f7b61 | ||
|
|
651f70c8c0 | ||
|
|
5e2072dbef | ||
|
|
0f92287880 | ||
|
|
491c3bf906 | ||
|
|
1d898e343f | ||
|
|
0b129e24ea | ||
|
|
267ac5ac60 | ||
|
|
bf80211eae | ||
|
|
135385e57b | ||
|
|
f06630bc1b | ||
|
|
4495df98cf | ||
|
|
0124937aa8 | ||
|
|
aec2d24706 | ||
|
|
16ebb55362 | ||
|
|
ab6c11319e | ||
|
|
05f5b96964 | ||
|
|
f525aa175b | ||
|
|
4ba6e5f735 | ||
|
|
992ad3b8d4 | ||
|
|
a6404f8b3e | ||
|
|
efc49c9f6b |
153
.cursor/skills/onyx-cli/SKILL.md
Normal file
153
.cursor/skills/onyx-cli/SKILL.md
Normal file
@@ -0,0 +1,153 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
The CLI is configured when `~/.config/onyx-cli/config.json` exists and contains an `api_key`. Check with:
|
||||
|
||||
```bash
|
||||
test -s ~/.config/onyx-cli/config.json && echo "configured" || echo "not configured"
|
||||
```
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: http://localhost:3000
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `http://localhost:3000`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `MessageDeltaEvent` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `StopEvent` | Stream complete |
|
||||
| `ErrorEvent` | Error with `error` message field |
|
||||
| `SearchStartEvent` | Onyx started searching documents |
|
||||
| `CitationEvent` | Source citation with `citation_number` and `document_id` |
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
@@ -114,10 +114,8 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
@@ -122,7 +122,7 @@ repos:
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
entry: bash -c "find . -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
|
||||
@@ -20,7 +20,6 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
@@ -118,38 +117,15 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
|
||||
def _seed_llms(
|
||||
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
|
||||
) -> None:
|
||||
if not llm_upsert_requests:
|
||||
return
|
||||
|
||||
logger.notice("Seeding LLMs")
|
||||
for request in llm_upsert_requests:
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
)
|
||||
if not default_provider:
|
||||
return
|
||||
|
||||
visible_configs = [
|
||||
mc for mc in default_provider.model_configurations if mc.is_visible
|
||||
]
|
||||
default_config = (
|
||||
visible_configs[0]
|
||||
if visible_configs
|
||||
else default_provider.model_configurations[0]
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=default_provider.id,
|
||||
model_name=default_config.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
if llm_upsert_requests:
|
||||
logger.notice("Seeding LLMs")
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
|
||||
@@ -109,12 +109,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
settings.ee_features_enabled = False
|
||||
elif metadata.used_seats > metadata.seats:
|
||||
# License is valid but seat limit exceeded
|
||||
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
|
||||
settings.seat_count = metadata.seats
|
||||
settings.used_seats = metadata.used_seats
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
|
||||
@@ -33,7 +33,6 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -303,17 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=request.name, db_session=db_session
|
||||
)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -331,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -366,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
_upsert(anthropic_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -397,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
_upsert(vertexai_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -435,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -202,6 +202,7 @@ def create_default_image_gen_config_from_api_key(
|
||||
api_key=api_key,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
default_model_name=model_name,
|
||||
deployment_name=None,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -213,29 +213,11 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider: LLMProviderModel | None = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_llm_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} not found"
|
||||
)
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
|
||||
if existing_llm_provider.name != llm_provider_upsert_request.name:
|
||||
raise ValueError(
|
||||
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
|
||||
)
|
||||
else:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM provider with name '{llm_provider_upsert_request.name}'"
|
||||
" already exists"
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -256,7 +238,11 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -320,6 +306,15 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -493,22 +488,6 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -625,13 +604,22 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
model_name,
|
||||
provider.default_model_name, # type: ignore[arg-type]
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -817,6 +805,12 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
|
||||
@@ -97,6 +97,7 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -135,6 +136,7 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -166,6 +168,7 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -55,12 +52,11 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -237,9 +233,12 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -269,7 +268,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.model,
|
||||
model=test_llm_request.default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -304,7 +303,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
) -> list[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -329,15 +328,7 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -353,44 +344,18 @@ def put_llm_provider(
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Check name constraints
|
||||
# TODO: Once port from name to id is complete, unique name will no longer be required
|
||||
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Renaming providers is not currently supported",
|
||||
)
|
||||
|
||||
found_provider = fetch_existing_llm_provider(
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if found_provider is not None and found_provider is not existing_provider:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} already exists"
|
||||
),
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"LLM Provider with name {llm_provider_upsert_request.name} and "
|
||||
f"id={llm_provider_upsert_request.id} does not exist"
|
||||
),
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -428,6 +393,22 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -457,8 +438,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -488,29 +469,28 @@ def delete_llm_provider(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/default")
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
def set_provider_as_default(
|
||||
default_model_request: DefaultModel,
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.post("/default-vision")
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
default_model: DefaultModel,
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=default_model.provider_id,
|
||||
vision_model=default_model.model_name,
|
||||
db_session=db_session,
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@@ -536,7 +516,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -565,13 +545,7 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return vision_provider_response
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -581,7 +555,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -618,15 +592,7 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return accessible_providers
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -669,7 +635,7 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
@@ -716,51 +682,7 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
# Get the default model and vision model for the persona
|
||||
# TODO: Port persona's over to use ID
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text = DefaultModel.from_model_config(default_text_model)
|
||||
default_vision = DefaultModel.from_model_config(default_vision_model)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider and can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -25,22 +21,50 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
# There is no logic that requires sending the providers default vision model name
|
||||
# We only send for the one that is actually the default
|
||||
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
|
||||
"""Find the default conversation model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns empty string.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
|
||||
return model_config.name
|
||||
return ""
|
||||
|
||||
|
||||
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
|
||||
"""Find the default vision model name for a provider.
|
||||
|
||||
Returns the model name if found, otherwise returns None.
|
||||
"""
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
for flow in model_config.llm_model_flows:
|
||||
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
|
||||
return model_config.name
|
||||
return None
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -56,10 +80,13 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -72,12 +99,24 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -91,17 +130,18 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -117,6 +157,8 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -138,6 +180,16 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = (
|
||||
default_model_name or llm_provider_model.default_model_name or ""
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -150,6 +202,10 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -369,38 +425,3 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ class ApplicationStatus(str, Enum):
|
||||
PAYMENT_REMINDER = "payment_reminder"
|
||||
GRACE_PERIOD = "grace_period"
|
||||
GATED_ACCESS = "gated_access"
|
||||
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -83,10 +82,6 @@ class Settings(BaseModel):
|
||||
# Default Assistant settings
|
||||
disable_default_assistant: bool | None = False
|
||||
|
||||
# Seat usage - populated by license enforcement when seat limit is exceeded
|
||||
seat_count: int | None = None
|
||||
used_seats: int | None = None
|
||||
|
||||
# OpenSearch migration
|
||||
opensearch_indexing_enabled: bool = False
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -255,18 +254,14 @@ def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
provider_name = "DevEnvPresetOpenAI"
|
||||
existing = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
id=existing.id if existing else None,
|
||||
name=provider_name,
|
||||
name="DevEnvPresetOpenAI",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -278,9 +273,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -28,6 +28,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -40,7 +41,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,6 +47,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -58,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -45,14 +44,15 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> LLMProviderView:
|
||||
) -> None:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
return upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -102,11 +102,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# This should complete without exception
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None, # New provider (not in DB)
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -146,11 +152,17 @@ class TestLLMConfigurationEndpoint:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -182,9 +194,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -192,12 +202,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -231,9 +246,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -241,12 +254,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=True - should use new key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -279,7 +297,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
provider = upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -287,6 +305,7 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -302,13 +321,18 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -344,11 +368,17 @@ class TestLLMConfigurationEndpoint:
|
||||
for model_name in test_models:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model=model_name,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -412,6 +442,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -421,7 +452,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -441,6 +472,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -467,11 +499,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -480,9 +512,6 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -495,7 +524,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -567,6 +596,7 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -575,7 +605,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -49,6 +49,7 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -90,14 +91,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -124,16 +125,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -158,16 +159,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -191,14 +190,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -224,16 +223,14 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -262,14 +259,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -300,6 +297,7 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -324,7 +322,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -332,11 +330,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -364,15 +362,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
_create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -401,7 +399,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -409,13 +407,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -440,17 +438,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -476,7 +474,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
provider = _create_test_provider(
|
||||
_create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -484,10 +482,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -532,8 +530,14 @@ def test_upload_with_custom_config_then_change(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -542,10 +546,11 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -564,9 +569,14 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -576,9 +586,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -606,13 +616,13 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not db_provider:
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
assert db_provider.custom_config == custom_config, (
|
||||
assert provider.custom_config == custom_config, (
|
||||
f"Expected custom_config {custom_config}, "
|
||||
f"but got {db_provider.custom_config}"
|
||||
f"but got {provider.custom_config}"
|
||||
)
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -632,10 +642,11 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
view = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -654,9 +665,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -695,7 +706,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
) -> None:
|
||||
"""LLM test should restore masked sensitive custom config values before invocation."""
|
||||
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
|
||||
provider_name = LlmProviderNames.VERTEX_AI.value
|
||||
provider = LlmProviderNames.VERTEX_AI.value
|
||||
default_model_name = "gemini-2.5-pro"
|
||||
original_custom_config = {
|
||||
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
|
||||
@@ -708,10 +719,11 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -730,9 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -15,11 +15,9 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_auto_mode_providers
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -137,6 +135,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -164,8 +163,13 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -234,6 +238,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -312,6 +317,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -320,13 +326,13 @@ class TestAutoModeSyncFeature:
|
||||
)
|
||||
|
||||
# Verify initial state: all models are visible
|
||||
initial_provider = fetch_existing_llm_provider(
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert initial_provider is not None
|
||||
assert initial_provider.is_auto_mode is False
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
|
||||
for mc in initial_provider.model_configurations:
|
||||
for mc in provider.model_configurations:
|
||||
assert (
|
||||
mc.is_visible is True
|
||||
), f"Initial model '{mc.name}' should be visible"
|
||||
@@ -338,12 +344,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=initial_provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -354,15 +360,15 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider_view = fetch_llm_provider_view(
|
||||
provider_name=provider_name, db_session=db_session
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider_view is not None
|
||||
assert provider_view.is_auto_mode is True
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
|
||||
# Build a map of model name -> visibility
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in provider_view.model_configurations
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# Models in auto mode config should be visible
|
||||
@@ -382,6 +388,9 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -423,12 +432,8 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -530,6 +535,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -543,7 +549,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -557,6 +563,7 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -577,7 +584,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
@@ -637,6 +644,7 @@ class TestAutoModeMissingFlows:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -693,364 +701,3 @@ class TestAutoModeMissingFlows:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
|
||||
class TestAutoModeTransitionsAndResync:
|
||||
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
|
||||
|
||||
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Disabling auto mode should preserve the current model list and
|
||||
prevent future syncs from altering visibility.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode — models synced from config.
|
||||
2. Update provider to manual mode (is_auto_mode=False).
|
||||
3. Verify all models remain with unchanged visibility.
|
||||
4. Call sync_auto_mode_models with a *different* config.
|
||||
5. Verify fetch_auto_mode_providers excludes this provider, so the
|
||||
periodic task would never call sync on it.
|
||||
"""
|
||||
initial_config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create in auto mode
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=initial_config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility_before = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
|
||||
|
||||
# Step 2: Switch to manual mode
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
is_auto_mode=False,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
is_creation=False,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 3: Models unchanged
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is False
|
||||
visibility_after = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility_after == visibility_before
|
||||
|
||||
# Step 4-5: Provider excluded from auto mode queries
|
||||
auto_providers = fetch_auto_mode_providers(db_session)
|
||||
auto_provider_ids = {p.id for p in auto_providers}
|
||||
assert provider.id not in auto_provider_ids
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_resync_adds_new_and_hides_removed_models(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the GitHub config changes between syncs, a subsequent sync
|
||||
should add newly listed models and hide models that were removed.
|
||||
|
||||
Steps:
|
||||
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
|
||||
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
|
||||
gpt-4-turbo added).
|
||||
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
|
||||
and visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create with config v1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Re-sync with config v2
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 3: Verify
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# gpt-4o: still in config -> visible
|
||||
assert visibility["gpt-4o"] is True
|
||||
# gpt-4o-mini: removed from config -> hidden (not deleted)
|
||||
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
|
||||
assert visibility["gpt-4o-mini"] is False
|
||||
# gpt-4-turbo: newly added -> visible
|
||||
assert visibility["gpt-4-turbo"] is True
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_is_idempotent(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Running sync twice with the same config should produce zero
|
||||
changes on the second call."""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
# First explicit sync (may report changes if creation already synced)
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Snapshot state after first sync
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
snapshot = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
|
||||
# Second sync — should be a no-op
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert (
|
||||
changes == 0
|
||||
), f"Expected 0 changes on idempotent re-sync, got {changes}"
|
||||
|
||||
# State should be identical
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
current = {
|
||||
mc.name: (mc.is_visible, mc.display_name)
|
||||
for mc in provider.model_configurations
|
||||
}
|
||||
assert current == snapshot
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_default_model_hidden_when_removed_from_config(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the current default model is removed from the config, sync
|
||||
should hide it. The default model flow row should still exist (it
|
||||
points at the ModelConfiguration), but the model is no longer visible.
|
||||
|
||||
Steps:
|
||||
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
|
||||
2. Set gpt-4o as the global default.
|
||||
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
|
||||
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
|
||||
fetch_default_llm_model still returns a result (the flow row persists).
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=[],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Set gpt-4o as global default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Step 3: Re-sync with config v2 (gpt-4o removed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0
|
||||
|
||||
# Step 4: Verify visibility
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
|
||||
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
|
||||
|
||||
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
|
||||
# but the model is hidden. fetch_default_llm_model filters on
|
||||
# is_visible=True, so it should NOT return gpt-4o.
|
||||
db_session.expire_all()
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert (
|
||||
default_after is None or default_after.name != "gpt-4o"
|
||||
), "Hidden model should not be returned as the default"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -64,6 +64,7 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -153,9 +154,7 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -42,6 +42,7 @@ def _create_llm_provider_and_model(
|
||||
name=provider_name,
|
||||
provider="openai",
|
||||
api_key="test-api-key",
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
|
||||
@@ -434,6 +434,7 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -447,7 +448,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
|
||||
@@ -4,12 +4,10 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -34,6 +32,7 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -66,7 +65,7 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -76,19 +75,9 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -115,7 +104,7 @@ class LLMProviderManager:
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**p) for p in response.json()["providers"]]
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
@@ -124,11 +113,7 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -141,30 +126,11 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and (
|
||||
default_model is None or default_model.model_name in model_names
|
||||
)
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
default_text = response.json().get("default_text")
|
||||
if default_text is None:
|
||||
return None
|
||||
return DefaultModel(**default_text)
|
||||
|
||||
@@ -128,7 +128,7 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str | None = None
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
|
||||
@@ -42,10 +42,12 @@ def _create_provider_with_api(
|
||||
llm_provider_data = {
|
||||
"name": name,
|
||||
"provider": provider_type,
|
||||
"default_model_name": default_model,
|
||||
"api_key": "test-api-key-for-auto-mode-testing",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"custom_config": None,
|
||||
"fast_default_model_name": default_model,
|
||||
"is_public": True,
|
||||
"is_auto_mode": is_auto_mode,
|
||||
"groups": [],
|
||||
@@ -70,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json()["providers"]:
|
||||
for provider in response.json():
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
@@ -217,6 +219,15 @@ def test_auto_mode_provider_gets_synced_from_github_config(
|
||||
"is_visible"
|
||||
], "Outdated model should not be visible after sync"
|
||||
|
||||
# Verify default model was set from GitHub config
|
||||
expected_default = (
|
||||
default_model["name"] if isinstance(default_model, dict) else default_model
|
||||
)
|
||||
assert synced_provider["default_model_name"] == expected_default, (
|
||||
f"Default model should be {expected_default}, "
|
||||
f"got {synced_provider['default_model_name']}"
|
||||
)
|
||||
|
||||
|
||||
def test_manual_mode_provider_not_affected_by_auto_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
@@ -262,3 +273,7 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
|
||||
f"Manual mode provider models should not change. "
|
||||
f"Initial: {initial_models}, Current: {current_models}"
|
||||
)
|
||||
|
||||
assert (
|
||||
updated_provider["default_model_name"] == custom_model
|
||||
), f"Manual mode default model should remain {custom_model}"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,21 +6,20 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
@@ -42,30 +41,24 @@ def _create_llm_provider(
|
||||
is_public: bool,
|
||||
is_default: bool,
|
||||
) -> LLMProviderModel:
|
||||
_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
provider = LLMProviderModel(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=default_model_name,
|
||||
deployment_name=None,
|
||||
is_public=is_public,
|
||||
# Use None instead of False to avoid unique constraint violation
|
||||
# The is_default_provider column has unique=True, so only one True and one False allowed
|
||||
is_default_provider=is_default if is_default else None,
|
||||
is_default_vision_provider=False,
|
||||
default_vision_model=None,
|
||||
)
|
||||
if is_default:
|
||||
update_default_provider(_provider.id, default_model_name, db_session)
|
||||
|
||||
provider = db_session.get(LLMProviderModel, _provider.id)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {name} not found")
|
||||
db_session.add(provider)
|
||||
db_session.flush()
|
||||
return provider
|
||||
|
||||
|
||||
@@ -277,6 +270,24 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
provider_name=restricted_provider.name,
|
||||
)
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(default_model_config)
|
||||
db_session.flush()
|
||||
db_session.add(
|
||||
LLMModelFlow(
|
||||
model_configuration_id=default_model_config.id,
|
||||
llm_model_flow_type=LLMModelFlowType.CHAT,
|
||||
is_default=True,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
access_group = UserGroup(name="persona-group")
|
||||
db_session.add(access_group)
|
||||
db_session.flush()
|
||||
@@ -310,19 +321,13 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
persona=persona,
|
||||
user=admin_model,
|
||||
)
|
||||
assert (
|
||||
allowed_llm.config.model_name
|
||||
== restricted_provider.model_configurations[0].name
|
||||
)
|
||||
assert allowed_llm.config.model_name == restricted_provider.default_model_name
|
||||
|
||||
fallback_llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=basic_model,
|
||||
)
|
||||
assert (
|
||||
fallback_llm.config.model_name
|
||||
== default_provider.model_configurations[0].name
|
||||
)
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
@@ -341,7 +346,6 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -361,7 +365,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -376,7 +380,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
@@ -392,7 +396,6 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
|
||||
name="default-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should include the restricted provider since basic_user can access the persona
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
|
||||
|
||||
# Should succeed (persona_id=0 refers to default persona, which is public)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should NOT include the restricted provider since basic_user is not in group2
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
|
||||
|
||||
# Should succeed - admins can access any persona
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should include the restricted provider
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
|
||||
# Should return the public provider
|
||||
assert len(providers) > 0
|
||||
|
||||
@@ -9,19 +9,6 @@ from redis.exceptions import RedisError
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
# Fields we assert on across all tests
|
||||
_ASSERT_FIELDS = {
|
||||
"application_status",
|
||||
"ee_features_enabled",
|
||||
"seat_count",
|
||||
"used_seats",
|
||||
}
|
||||
|
||||
|
||||
def _pick(settings: Settings) -> dict:
|
||||
"""Extract only the fields under test from a Settings object."""
|
||||
return settings.model_dump(include=_ASSERT_FIELDS)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_settings() -> Settings:
|
||||
@@ -40,17 +27,17 @@ class TestApplyLicenseStatusToSettings:
|
||||
def test_enforcement_disabled_enables_ee_features(
|
||||
self, base_settings: Settings
|
||||
) -> None:
|
||||
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled."""
|
||||
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled.
|
||||
|
||||
If we're running the EE apply function, EE code was loaded via
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so features should be on.
|
||||
"""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
assert base_settings.ee_features_enabled is False
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is True
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", True)
|
||||
@@ -59,60 +46,13 @@ class TestApplyLicenseStatusToSettings:
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
assert result.ee_features_enabled is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"license_status,used_seats,seats,expected",
|
||||
"license_status,expected_app_status,expected_ee_enabled",
|
||||
[
|
||||
(
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
3,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.GATED_ACCESS,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
ApplicationStatus.ACTIVE,
|
||||
3,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
ApplicationStatus.ACTIVE,
|
||||
10,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
3,
|
||||
10,
|
||||
{
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
},
|
||||
),
|
||||
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
|
||||
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
|
||||
],
|
||||
)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@@ -123,80 +63,25 @@ class TestApplyLicenseStatusToSettings:
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
license_status: ApplicationStatus,
|
||||
used_seats: int,
|
||||
seats: int,
|
||||
expected: dict,
|
||||
license_status: ApplicationStatus | None,
|
||||
expected_app_status: ApplicationStatus,
|
||||
expected_ee_enabled: bool,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Self-hosted: license status controls both application_status and ee_features_enabled."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = license_status
|
||||
mock_metadata.used_seats = used_seats
|
||||
mock_metadata.seats = seats
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
if license_status is None:
|
||||
mock_get_metadata.return_value = None
|
||||
else:
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = license_status
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == expected
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_seat_limit_exceeded_sets_status_and_counts(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Seat limit exceeded sets SEAT_LIMIT_EXCEEDED with counts, keeps EE enabled."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = ApplicationStatus.ACTIVE
|
||||
mock_metadata.used_seats = 15
|
||||
mock_metadata.seats = 10
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.SEAT_LIMIT_EXCEEDED,
|
||||
"ee_features_enabled": True,
|
||||
"seat_count": 10,
|
||||
"used_seats": 15,
|
||||
}
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_expired_license_takes_precedence_over_seat_limit(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Expired license (GATED_ACCESS) takes precedence over seat limit exceeded."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = ApplicationStatus.GATED_ACCESS
|
||||
mock_metadata.used_seats = 15
|
||||
mock_metadata.seats = 10
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.GATED_ACCESS,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
assert result.application_status == expected_app_status
|
||||
assert result.ee_features_enabled is expected_ee_enabled
|
||||
|
||||
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@@ -220,12 +105,8 @@ class TestApplyLicenseStatusToSettings:
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.GATED_ACCESS,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
assert result.application_status == ApplicationStatus.GATED_ACCESS
|
||||
assert result.ee_features_enabled is False
|
||||
|
||||
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@@ -249,12 +130,8 @@ class TestApplyLicenseStatusToSettings:
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is False
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@@ -273,12 +150,8 @@ class TestApplyLicenseStatusToSettings:
|
||||
mock_get_metadata.side_effect = RedisError("Connection failed")
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert _pick(result) == {
|
||||
"application_status": ApplicationStatus.ACTIVE,
|
||||
"ee_features_enabled": False,
|
||||
"seat_count": None,
|
||||
"used_seats": None,
|
||||
}
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is False
|
||||
|
||||
|
||||
class TestSettingsDefaultEEDisabled:
|
||||
|
||||
@@ -44,6 +44,7 @@ def _build_provider_view(
|
||||
id=1,
|
||||
name="test-provider",
|
||||
provider=provider,
|
||||
default_model_name="test-model",
|
||||
model_configurations=[
|
||||
ModelConfigurationView(
|
||||
name="test-model",
|
||||
@@ -61,6 +62,7 @@ def _build_provider_view(
|
||||
groups=[],
|
||||
personas=[],
|
||||
deployment_name=None,
|
||||
default_vision_model=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
3
cli/.gitignore
vendored
Normal file
3
cli/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
onyx-cli
|
||||
cli
|
||||
onyx.cli
|
||||
118
cli/README.md
Normal file
118
cli/README.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# Onyx CLI
|
||||
|
||||
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
Or with uv:
|
||||
|
||||
```shell
|
||||
uv pip install onyx-cli
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
Run the interactive setup:
|
||||
|
||||
```shell
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for your Onyx server URL and API key, tests the connection, and saves config to `~/.config/onyx-cli/config.json`.
|
||||
|
||||
Environment variables override config file values:
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Server base URL (default: `http://localhost:3000`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
## Usage
|
||||
|
||||
### Interactive chat (default)
|
||||
|
||||
```shell
|
||||
onyx-cli
|
||||
```
|
||||
|
||||
### One-shot question
|
||||
|
||||
```shell
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
onyx-cli ask --agent-id 5 "Summarize this topic"
|
||||
onyx-cli ask --json "Hello"
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--agent-id <int>` | Agent ID to use (overrides default) |
|
||||
| `--json` | Output raw NDJSON events instead of plain text |
|
||||
|
||||
### List agents
|
||||
|
||||
```shell
|
||||
onyx-cli agents
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `chat` | Launch the interactive chat TUI (default) |
|
||||
| `ask` | Ask a one-shot question (non-interactive) |
|
||||
| `agents` | List available agents |
|
||||
| `configure` | Configure server URL and API key |
|
||||
|
||||
## Slash Commands (in TUI)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show help message |
|
||||
| `/new` | Start a new chat session |
|
||||
| `/agent` | List and switch agents |
|
||||
| `/attach <path>` | Attach a file to next message |
|
||||
| `/sessions` | List recent chat sessions |
|
||||
| `/clear` | Clear the chat display |
|
||||
| `/configure` | Re-run connection setup |
|
||||
| `/connectors` | Open connectors in browser |
|
||||
| `/settings` | Open settings in browser |
|
||||
| `/quit` | Exit Onyx CLI |
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| `Enter` | Send message |
|
||||
| `Escape` | Cancel current generation |
|
||||
| `Ctrl+O` | Toggle source citations |
|
||||
| `Ctrl+D` | Quit (press twice) |
|
||||
| `Scroll` / `Shift+Up/Down` | Scroll chat history |
|
||||
| `Page Up` / `Page Down` | Scroll half page |
|
||||
|
||||
## Building from Source
|
||||
|
||||
Requires [Go 1.24+](https://go.dev/dl/).
|
||||
|
||||
```shell
|
||||
cd cli
|
||||
go build -o onyx-cli .
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
```shell
|
||||
# Run tests
|
||||
go test ./...
|
||||
|
||||
# Build
|
||||
go build -o onyx-cli .
|
||||
|
||||
# Lint
|
||||
staticcheck ./...
|
||||
```
|
||||
63
cli/cmd/agents.go
Normal file
63
cli/cmd/agents.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newAgentsCmd() *cobra.Command {
|
||||
var agentsJSON bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "agents",
|
||||
Short: "List available agents",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
agents, err := client.ListAgents()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list agents: %w", err)
|
||||
}
|
||||
|
||||
if agentsJSON {
|
||||
data, err := json.MarshalIndent(agents, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal agents: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(agents) == 0 {
|
||||
fmt.Println("No agents available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 4, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "ID\tNAME\tDESCRIPTION")
|
||||
for _, a := range agents {
|
||||
desc := a.Description
|
||||
if len(desc) > 60 {
|
||||
desc = desc[:57] + "..."
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "%d\t%s\t%s\n", a.ID, a.Name, desc)
|
||||
}
|
||||
_ = w.Flush()
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVar(&agentsJSON, "json", false, "Output agents as JSON")
|
||||
|
||||
return cmd
|
||||
}
|
||||
103
cli/cmd/ask.go
Normal file
103
cli/cmd/ask.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newAskCmd() *cobra.Command {
|
||||
var (
|
||||
askAgentID int
|
||||
askJSON bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "ask [question]",
|
||||
Short: "Ask a one-shot question (non-interactive)",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
question := args[0]
|
||||
agentID := cfg.DefaultAgentID
|
||||
if cmd.Flags().Changed("agent-id") {
|
||||
agentID = askAgentID
|
||||
}
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
parentID := -1
|
||||
ch := client.SendMessageStream(
|
||||
context.Background(),
|
||||
question,
|
||||
nil,
|
||||
agentID,
|
||||
&parentID,
|
||||
nil,
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
gotStop := false
|
||||
for event := range ch {
|
||||
if askJSON {
|
||||
wrapped := struct {
|
||||
Type string `json:"type"`
|
||||
Event models.StreamEvent `json:"event"`
|
||||
}{
|
||||
Type: event.EventType(),
|
||||
Event: event,
|
||||
}
|
||||
data, err := json.Marshal(wrapped)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling event: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
if _, ok := event.(models.ErrorEvent); ok {
|
||||
lastErr = fmt.Errorf("%s", event.(models.ErrorEvent).Error)
|
||||
}
|
||||
if _, ok := event.(models.StopEvent); ok {
|
||||
gotStop = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch e := event.(type) {
|
||||
case models.MessageDeltaEvent:
|
||||
fmt.Print(e.Content)
|
||||
case models.ErrorEvent:
|
||||
return fmt.Errorf("%s", e.Error)
|
||||
case models.StopEvent:
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
if !gotStop {
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return fmt.Errorf("stream ended unexpectedly")
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
|
||||
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
|
||||
// Suppress cobra's default error/usage on RunE errors
|
||||
return cmd
|
||||
}
|
||||
33
cli/cmd/chat.go
Normal file
33
cli/cmd/chat.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newChatCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "chat",
|
||||
Short: "Launch the interactive chat TUI (default)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
|
||||
// First-run: onboarding
|
||||
if !config.ConfigExists() || !cfg.IsConfigured() {
|
||||
result := onboarding.Run(&cfg)
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
cfg = *result
|
||||
}
|
||||
|
||||
m := tui.NewModel(cfg)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
|
||||
_, err := p.Run()
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
19
cli/cmd/configure.go
Normal file
19
cli/cmd/configure.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newConfigureCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "configure",
|
||||
Short: "Configure server URL and API key",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
onboarding.Run(&cfg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
40
cli/cmd/root.go
Normal file
40
cli/cmd/root.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Package cmd implements Cobra CLI commands for the Onyx CLI.
|
||||
package cmd
|
||||
|
||||
import "github.com/spf13/cobra"
|
||||
|
||||
// Version and Commit are set via ldflags at build time.
|
||||
var (
|
||||
Version string
|
||||
Commit string
|
||||
)
|
||||
|
||||
func fullVersion() string {
|
||||
if Commit != "" && Commit != "none" && len(Commit) > 7 {
|
||||
return Version + " (" + Commit[:7] + ")"
|
||||
}
|
||||
return Version
|
||||
}
|
||||
|
||||
// Execute creates and runs the root command.
|
||||
func Execute() error {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "onyx-cli",
|
||||
Short: "Terminal UI for chatting with Onyx",
|
||||
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
|
||||
Version: fullVersion(),
|
||||
}
|
||||
|
||||
// Register subcommands
|
||||
chatCmd := newChatCmd()
|
||||
rootCmd.AddCommand(chatCmd)
|
||||
rootCmd.AddCommand(newAskCmd())
|
||||
rootCmd.AddCommand(newAgentsCmd())
|
||||
rootCmd.AddCommand(newConfigureCmd())
|
||||
rootCmd.AddCommand(newValidateConfigCmd())
|
||||
|
||||
// Default command is chat
|
||||
rootCmd.RunE = chatCmd.RunE
|
||||
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
41
cli/cmd/validate.go
Normal file
41
cli/cmd/validate.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newValidateConfigCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "validate-config",
|
||||
Short: "Validate configuration and test server connection",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check config file
|
||||
if !config.ConfigExists() {
|
||||
return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
|
||||
// Check API key
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server: %s\n", cfg.ServerURL)
|
||||
|
||||
// Test connection
|
||||
client := api.NewClient(cfg)
|
||||
if err := client.TestConnection(); err != nil {
|
||||
return fmt.Errorf("connection failed: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
45
cli/go.mod
Normal file
45
cli/go.mod
Normal file
@@ -0,0 +1,45 @@
|
||||
module github.com/onyx-dot-app/onyx/cli
|
||||
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/bubbles v0.20.0
|
||||
github.com/charmbracelet/bubbletea v1.3.4
|
||||
github.com/charmbracelet/glamour v0.8.0
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/spf13/cobra v1.9.1
|
||||
golang.org/x/term v0.22.0
|
||||
golang.org/x/text v0.34.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.8.0 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.0 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yuin/goldmark v1.7.4 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.3 // indirect
|
||||
golang.org/x/net v0.27.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
)
|
||||
94
cli/go.sum
Normal file
94
cli/go.sum
Normal file
@@ -0,0 +1,94 @@
|
||||
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
|
||||
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
|
||||
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
|
||||
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
|
||||
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
|
||||
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
|
||||
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
|
||||
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
|
||||
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
|
||||
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
||||
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
|
||||
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
|
||||
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
|
||||
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
279
cli/internal/api/client.go
Normal file
279
cli/internal/api/client.go
Normal file
@@ -0,0 +1,279 @@
|
||||
// Package api provides the HTTP client for communicating with the Onyx server.
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// Client is the Onyx API client.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
httpClient *http.Client // default 30s timeout for quick requests
|
||||
longHTTPClient *http.Client // 5min timeout for streaming/uploads
|
||||
}
|
||||
|
||||
// NewClient creates a new API client from config.
|
||||
func NewClient(cfg config.OnyxCliConfig) *Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
|
||||
apiKey: cfg.APIKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
},
|
||||
longHTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig replaces the client's config.
|
||||
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
|
||||
c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
|
||||
c.apiKey = cfg.APIKey
|
||||
}
|
||||
|
||||
func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), method, c.baseURL+path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.apiKey != "" {
|
||||
bearer := "Bearer " + c.apiKey
|
||||
req.Header.Set("Authorization", bearer)
|
||||
req.Header.Set("X-Onyx-Authorization", bearer)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(method, path string, reqBody any, result any) error {
|
||||
var body io.Reader
|
||||
if reqBody != nil {
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := c.newRequest(method, path, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reqBody != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
return json.NewDecoder(resp.Body).Decode(result)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnection checks if the server is reachable and credentials are valid.
|
||||
// Returns nil on success, or an error with a descriptive message on failure.
|
||||
func (c *Client) TestConnection() error {
|
||||
// Step 1: Basic reachability
|
||||
req, err := c.newRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to %s — is the server running?", c.baseURL)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
serverHeader := strings.ToLower(resp.Header.Get("Server"))
|
||||
|
||||
if resp.StatusCode == 403 {
|
||||
if strings.Contains(serverHeader, "awselb") || strings.Contains(serverHeader, "amazons3") {
|
||||
return fmt.Errorf("blocked by AWS load balancer (HTTP 403 on all requests).\n Your IP address may not be in the ALB's security group or WAF allowlist")
|
||||
}
|
||||
return fmt.Errorf("HTTP 403 on base URL — the server is blocking all traffic.\n This is likely a firewall, WAF, or IP allowlist restriction")
|
||||
}
|
||||
|
||||
// Step 2: Authenticated check
|
||||
req2, err := c.newRequest("GET", "/api/me", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server reachable but API error: %w", err)
|
||||
}
|
||||
|
||||
resp2, err := c.longHTTPClient.Do(req2)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server reachable but API error: %w", err)
|
||||
}
|
||||
defer func() { _ = resp2.Body.Close() }()
|
||||
|
||||
if resp2.StatusCode == 200 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp2.Body, 300))
|
||||
body := string(bodyBytes)
|
||||
isHTML := strings.HasPrefix(strings.TrimSpace(body), "<")
|
||||
respServer := strings.ToLower(resp2.Header.Get("Server"))
|
||||
|
||||
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
|
||||
if isHTML || strings.Contains(respServer, "awselb") {
|
||||
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
|
||||
}
|
||||
if resp2.StatusCode == 401 {
|
||||
return fmt.Errorf("invalid API key or token.\n %s", body)
|
||||
}
|
||||
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
|
||||
}
|
||||
|
||||
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
|
||||
if body != "" {
|
||||
detail += fmt.Sprintf("\n Response: %s", body)
|
||||
}
|
||||
return fmt.Errorf("%s", detail)
|
||||
}
|
||||
|
||||
// ListAgents returns visible agents.
|
||||
func (c *Client) ListAgents() ([]models.AgentSummary, error) {
|
||||
var raw []models.AgentSummary
|
||||
if err := c.doJSON("GET", "/api/persona", nil, &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []models.AgentSummary
|
||||
for _, p := range raw {
|
||||
if p.IsVisible {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListChatSessions returns recent chat sessions.
|
||||
func (c *Client) ListChatSessions() ([]models.ChatSessionDetails, error) {
|
||||
var resp struct {
|
||||
Sessions []models.ChatSessionDetails `json:"sessions"`
|
||||
}
|
||||
if err := c.doJSON("GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Sessions, nil
|
||||
}
|
||||
|
||||
// GetChatSession returns full details for a session.
|
||||
func (c *Client) GetChatSession(sessionID string) (*models.ChatSessionDetailResponse, error) {
|
||||
var resp models.ChatSessionDetailResponse
|
||||
if err := c.doJSON("GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RenameChatSession renames a session. If name is empty, the backend auto-generates one.
|
||||
func (c *Client) RenameChatSession(sessionID string, name *string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
if name != nil {
|
||||
payload["name"] = *name
|
||||
}
|
||||
var resp struct {
|
||||
NewName string `json:"new_name"`
|
||||
}
|
||||
if err := c.doJSON("PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.NewName, nil
|
||||
}
|
||||
|
||||
// UploadFile uploads a file and returns a file descriptor.
|
||||
func (c *Client) UploadFile(filePath string) (*models.FileDescriptorPayload, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
part, err := writer.CreateFormFile("files", filepath.Base(filePath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = writer.Close()
|
||||
|
||||
req, err := c.newRequest("POST", "/api/user/projects/file/upload", &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := c.longHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(body)}
|
||||
}
|
||||
|
||||
var snapshot models.CategorizedFilesSnapshot
|
||||
if err := json.NewDecoder(resp.Body).Decode(&snapshot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(snapshot.UserFiles) == 0 {
|
||||
return nil, &OnyxAPIError{StatusCode: 400, Detail: "File upload returned no files"}
|
||||
}
|
||||
|
||||
uf := snapshot.UserFiles[0]
|
||||
return &models.FileDescriptorPayload{
|
||||
ID: uf.FileID,
|
||||
Type: uf.ChatFileType,
|
||||
Name: filepath.Base(filePath),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StopChatSession sends a stop signal for a streaming session (best-effort).
|
||||
func (c *Client) StopChatSession(sessionID string) {
|
||||
req, err := c.newRequest("POST", "/api/chat/stop-chat-session/"+sessionID, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
13
cli/internal/api/errors.go
Normal file
13
cli/internal/api/errors.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package api
|
||||
|
||||
import "fmt"
|
||||
|
||||
// OnyxAPIError is returned when an Onyx API call fails.
|
||||
type OnyxAPIError struct {
|
||||
StatusCode int
|
||||
Detail string
|
||||
}
|
||||
|
||||
func (e *OnyxAPIError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
|
||||
}
|
||||
136
cli/internal/api/stream.go
Normal file
136
cli/internal/api/stream.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/parser"
|
||||
)
|
||||
|
||||
// StreamEventMsg wraps a StreamEvent for Bubble Tea.
|
||||
type StreamEventMsg struct {
|
||||
Event models.StreamEvent
|
||||
}
|
||||
|
||||
// StreamDoneMsg signals the stream has ended.
|
||||
type StreamDoneMsg struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// SendMessageStream starts streaming a chat message response.
|
||||
// It reads NDJSON lines, parses them, and sends events on the returned channel.
|
||||
// The goroutine stops when ctx is cancelled or the stream ends.
|
||||
func (c *Client) SendMessageStream(
|
||||
ctx context.Context,
|
||||
message string,
|
||||
chatSessionID *string,
|
||||
agentID int,
|
||||
parentMessageID *int,
|
||||
fileDescriptors []models.FileDescriptorPayload,
|
||||
) <-chan models.StreamEvent {
|
||||
ch := make(chan models.StreamEvent, 64)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
payload := models.SendMessagePayload{
|
||||
Message: message,
|
||||
ParentMessageID: parentMessageID,
|
||||
FileDescriptors: fileDescriptors,
|
||||
Origin: "api",
|
||||
IncludeCitations: true,
|
||||
Stream: true,
|
||||
}
|
||||
if payload.FileDescriptors == nil {
|
||||
payload.FileDescriptors = []models.FileDescriptorPayload{}
|
||||
}
|
||||
|
||||
if chatSessionID != nil {
|
||||
payload.ChatSessionID = chatSessionID
|
||||
} else {
|
||||
payload.ChatSessionInfo = &models.ChatSessionCreationInfo{AgentID: agentID}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("marshal error: %v", err), IsRetryable: false}
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", nil)
|
||||
if err != nil {
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
|
||||
return
|
||||
}
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
req.ContentLength = int64(len(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
bearer := "Bearer " + c.apiKey
|
||||
req.Header.Set("Authorization", bearer)
|
||||
req.Header.Set("X-Onyx-Authorization", bearer)
|
||||
}
|
||||
|
||||
resp, err := c.longHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return // cancelled
|
||||
}
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("connection error: %v", err), IsRetryable: true}
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
var respBody [4096]byte
|
||||
n, _ := resp.Body.Read(respBody[:])
|
||||
ch <- models.ErrorEvent{
|
||||
Error: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody[:n])),
|
||||
IsRetryable: resp.StatusCode >= 500,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
event := parser.ParseStreamLine(scanner.Text())
|
||||
if event != nil {
|
||||
select {
|
||||
case ch <- event:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && ctx.Err() == nil {
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("stream read error: %v", err), IsRetryable: true}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// WaitForStreamEvent returns a tea.Cmd that reads one event from the channel.
|
||||
// On channel close, it returns StreamDoneMsg.
|
||||
func WaitForStreamEvent(ch <-chan models.StreamEvent) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
event, ok := <-ch
|
||||
if !ok {
|
||||
return StreamDoneMsg{}
|
||||
}
|
||||
return StreamEventMsg{Event: event}
|
||||
}
|
||||
}
|
||||
|
||||
101
cli/internal/config/config.go
Normal file
101
cli/internal/config/config.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvServerURL = "ONYX_SERVER_URL"
|
||||
EnvAPIKey = "ONYX_API_KEY"
|
||||
EnvAgentID = "ONYX_PERSONA_ID"
|
||||
)
|
||||
|
||||
// OnyxCliConfig holds the CLI configuration.
|
||||
type OnyxCliConfig struct {
|
||||
ServerURL string `json:"server_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
DefaultAgentID int `json:"default_persona_id"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a config with default values.
|
||||
func DefaultConfig() OnyxCliConfig {
|
||||
return OnyxCliConfig{
|
||||
ServerURL: "https://cloud.onyx.app",
|
||||
APIKey: "",
|
||||
DefaultAgentID: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// IsConfigured returns true if the config has an API key.
|
||||
func (c OnyxCliConfig) IsConfigured() bool {
|
||||
return c.APIKey != ""
|
||||
}
|
||||
|
||||
// configDir returns ~/.config/onyx-cli
|
||||
func configDir() string {
|
||||
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
|
||||
return filepath.Join(xdg, "onyx-cli")
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(".", ".config", "onyx-cli")
|
||||
}
|
||||
return filepath.Join(home, ".config", "onyx-cli")
|
||||
}
|
||||
|
||||
// ConfigFilePath returns the full path to the config file.
|
||||
func ConfigFilePath() string {
|
||||
return filepath.Join(configDir(), "config.json")
|
||||
}
|
||||
|
||||
// ConfigExists checks if the config file exists on disk.
|
||||
func ConfigExists() bool {
|
||||
_, err := os.Stat(ConfigFilePath())
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Load reads config from file and applies environment variable overrides.
|
||||
func Load() OnyxCliConfig {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
data, err := os.ReadFile(ConfigFilePath())
|
||||
if err == nil {
|
||||
if jsonErr := json.Unmarshal(data, &cfg); jsonErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", ConfigFilePath(), jsonErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Environment overrides
|
||||
if v := os.Getenv(EnvServerURL); v != "" {
|
||||
cfg.ServerURL = v
|
||||
}
|
||||
if v := os.Getenv(EnvAPIKey); v != "" {
|
||||
cfg.APIKey = v
|
||||
}
|
||||
if v := os.Getenv(EnvAgentID); v != "" {
|
||||
if id, err := strconv.Atoi(v); err == nil {
|
||||
cfg.DefaultAgentID = id
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Save writes the config to disk, creating parent directories if needed.
|
||||
func Save(cfg OnyxCliConfig) error {
|
||||
dir := configDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(ConfigFilePath(), data, 0o600)
|
||||
}
|
||||
215
cli/internal/config/config_test.go
Normal file
215
cli/internal/config/config_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func clearEnvVars(t *testing.T) {
|
||||
t.Helper()
|
||||
for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAgentID} {
|
||||
t.Setenv(key, "")
|
||||
if err := os.Unsetenv(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeConfig(t *testing.T, dir string, data []byte) {
|
||||
t.Helper()
|
||||
onyxDir := filepath.Join(dir, "onyx-cli")
|
||||
if err := os.MkdirAll(onyxDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(onyxDir, "config.json"), data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.ServerURL != "https://cloud.onyx.app" {
|
||||
t.Errorf("expected default server URL, got %s", cfg.ServerURL)
|
||||
}
|
||||
if cfg.APIKey != "" {
|
||||
t.Errorf("expected empty API key, got %s", cfg.APIKey)
|
||||
}
|
||||
if cfg.DefaultAgentID != 0 {
|
||||
t.Errorf("expected default agent ID 0, got %d", cfg.DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsConfigured(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.IsConfigured() {
|
||||
t.Error("empty config should not be configured")
|
||||
}
|
||||
cfg.APIKey = "some-key"
|
||||
if !cfg.IsConfigured() {
|
||||
t.Error("config with API key should be configured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
|
||||
cfg := Load()
|
||||
if cfg.ServerURL != "https://cloud.onyx.app" {
|
||||
t.Errorf("expected default URL, got %s", cfg.ServerURL)
|
||||
}
|
||||
if cfg.APIKey != "" {
|
||||
t.Errorf("expected empty key, got %s", cfg.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromFile(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"server_url": "https://my-onyx.example.com",
|
||||
"api_key": "test-key-123",
|
||||
"default_persona_id": 5,
|
||||
})
|
||||
writeConfig(t, dir, data)
|
||||
|
||||
cfg := Load()
|
||||
if cfg.ServerURL != "https://my-onyx.example.com" {
|
||||
t.Errorf("got %s", cfg.ServerURL)
|
||||
}
|
||||
if cfg.APIKey != "test-key-123" {
|
||||
t.Errorf("got %s", cfg.APIKey)
|
||||
}
|
||||
if cfg.DefaultAgentID != 5 {
|
||||
t.Errorf("got %d", cfg.DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCorruptFile(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
|
||||
writeConfig(t, dir, []byte("not valid json {{{"))
|
||||
|
||||
cfg := Load()
|
||||
if cfg.ServerURL != "https://cloud.onyx.app" {
|
||||
t.Errorf("expected default URL on corrupt file, got %s", cfg.ServerURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverrideServerURL(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
t.Setenv(EnvServerURL, "https://env-override.com")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.ServerURL != "https://env-override.com" {
|
||||
t.Errorf("got %s", cfg.ServerURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverrideAPIKey(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
t.Setenv(EnvAPIKey, "env-key")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.APIKey != "env-key" {
|
||||
t.Errorf("got %s", cfg.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverrideAgentID(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
t.Setenv(EnvAgentID, "42")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.DefaultAgentID != 42 {
|
||||
t.Errorf("got %d", cfg.DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverrideInvalidAgentID(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
t.Setenv(EnvAgentID, "not-a-number")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.DefaultAgentID != 0 {
|
||||
t.Errorf("got %d", cfg.DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverridesFileValues(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"server_url": "https://file-url.com",
|
||||
"api_key": "file-key",
|
||||
})
|
||||
writeConfig(t, dir, data)
|
||||
|
||||
t.Setenv(EnvServerURL, "https://env-url.com")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.ServerURL != "https://env-url.com" {
|
||||
t.Errorf("env should override file, got %s", cfg.ServerURL)
|
||||
}
|
||||
if cfg.APIKey != "file-key" {
|
||||
t.Errorf("file value should be kept, got %s", cfg.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndReload(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
|
||||
cfg := OnyxCliConfig{
|
||||
ServerURL: "https://saved.example.com",
|
||||
APIKey: "saved-key",
|
||||
DefaultAgentID: 10,
|
||||
}
|
||||
if err := Save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded := Load()
|
||||
if loaded.ServerURL != "https://saved.example.com" {
|
||||
t.Errorf("got %s", loaded.ServerURL)
|
||||
}
|
||||
if loaded.APIKey != "saved-key" {
|
||||
t.Errorf("got %s", loaded.APIKey)
|
||||
}
|
||||
if loaded.DefaultAgentID != 10 {
|
||||
t.Errorf("got %d", loaded.DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCreatesParentDirs(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
nested := filepath.Join(dir, "deep", "nested")
|
||||
t.Setenv("XDG_CONFIG_HOME", nested)
|
||||
|
||||
if err := Save(OnyxCliConfig{APIKey: "test"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !ConfigExists() {
|
||||
t.Error("config file should exist after save")
|
||||
}
|
||||
}
|
||||
193
cli/internal/models/events.go
Normal file
193
cli/internal/models/events.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package models
|
||||
|
||||
// StreamEvent is the interface for all parsed stream events.
|
||||
type StreamEvent interface {
|
||||
EventType() string
|
||||
}
|
||||
|
||||
// Event type constants matching the Python StreamEventType enum.
|
||||
const (
|
||||
EventSessionCreated = "session_created"
|
||||
EventMessageIDInfo = "message_id_info"
|
||||
EventStop = "stop"
|
||||
EventError = "error"
|
||||
EventMessageStart = "message_start"
|
||||
EventMessageDelta = "message_delta"
|
||||
EventSearchStart = "search_tool_start"
|
||||
EventSearchQueries = "search_tool_queries_delta"
|
||||
EventSearchDocuments = "search_tool_documents_delta"
|
||||
EventReasoningStart = "reasoning_start"
|
||||
EventReasoningDelta = "reasoning_delta"
|
||||
EventReasoningDone = "reasoning_done"
|
||||
EventCitationInfo = "citation_info"
|
||||
EventOpenURLStart = "open_url_start"
|
||||
EventImageGenStart = "image_generation_start"
|
||||
EventPythonToolStart = "python_tool_start"
|
||||
EventCustomToolStart = "custom_tool_start"
|
||||
EventFileReaderStart = "file_reader_start"
|
||||
EventDeepResearchPlan = "deep_research_plan_start"
|
||||
EventDeepResearchDelta = "deep_research_plan_delta"
|
||||
EventResearchAgentStart = "research_agent_start"
|
||||
EventIntermediateReport = "intermediate_report_start"
|
||||
EventIntermediateReportDt = "intermediate_report_delta"
|
||||
EventUnknown = "unknown"
|
||||
)
|
||||
|
||||
// SessionCreatedEvent is emitted when a new chat session is created.
|
||||
type SessionCreatedEvent struct {
|
||||
ChatSessionID string
|
||||
}
|
||||
|
||||
func (e SessionCreatedEvent) EventType() string { return EventSessionCreated }
|
||||
|
||||
// MessageIDEvent carries the user and agent message IDs.
|
||||
type MessageIDEvent struct {
|
||||
UserMessageID *int
|
||||
ReservedAgentMessageID int
|
||||
}
|
||||
|
||||
func (e MessageIDEvent) EventType() string { return EventMessageIDInfo }
|
||||
|
||||
// StopEvent signals the end of a stream.
|
||||
type StopEvent struct {
|
||||
Placement *Placement
|
||||
StopReason *string
|
||||
}
|
||||
|
||||
func (e StopEvent) EventType() string { return EventStop }
|
||||
|
||||
// ErrorEvent signals an error.
|
||||
type ErrorEvent struct {
|
||||
Placement *Placement
|
||||
Error string
|
||||
StackTrace *string
|
||||
IsRetryable bool
|
||||
}
|
||||
|
||||
func (e ErrorEvent) EventType() string { return EventError }
|
||||
|
||||
// MessageStartEvent signals the beginning of an agent message.
|
||||
type MessageStartEvent struct {
|
||||
Placement *Placement
|
||||
Documents []SearchDoc
|
||||
}
|
||||
|
||||
func (e MessageStartEvent) EventType() string { return EventMessageStart }
|
||||
|
||||
// MessageDeltaEvent carries a token of agent content.
|
||||
type MessageDeltaEvent struct {
|
||||
Placement *Placement
|
||||
Content string
|
||||
}
|
||||
|
||||
func (e MessageDeltaEvent) EventType() string { return EventMessageDelta }
|
||||
|
||||
// SearchStartEvent signals the beginning of a search.
|
||||
type SearchStartEvent struct {
|
||||
Placement *Placement
|
||||
IsInternetSearch bool
|
||||
}
|
||||
|
||||
func (e SearchStartEvent) EventType() string { return EventSearchStart }
|
||||
|
||||
// SearchQueriesEvent carries search queries.
|
||||
type SearchQueriesEvent struct {
|
||||
Placement *Placement
|
||||
Queries []string
|
||||
}
|
||||
|
||||
func (e SearchQueriesEvent) EventType() string { return EventSearchQueries }
|
||||
|
||||
// SearchDocumentsEvent carries found documents.
|
||||
type SearchDocumentsEvent struct {
|
||||
Placement *Placement
|
||||
Documents []SearchDoc
|
||||
}
|
||||
|
||||
func (e SearchDocumentsEvent) EventType() string { return EventSearchDocuments }
|
||||
|
||||
// ReasoningStartEvent signals the beginning of a reasoning block.
|
||||
type ReasoningStartEvent struct {
|
||||
Placement *Placement
|
||||
}
|
||||
|
||||
func (e ReasoningStartEvent) EventType() string { return EventReasoningStart }
|
||||
|
||||
// ReasoningDeltaEvent carries reasoning text.
|
||||
type ReasoningDeltaEvent struct {
|
||||
Placement *Placement
|
||||
Reasoning string
|
||||
}
|
||||
|
||||
func (e ReasoningDeltaEvent) EventType() string { return EventReasoningDelta }
|
||||
|
||||
// ReasoningDoneEvent signals the end of reasoning.
|
||||
type ReasoningDoneEvent struct {
|
||||
Placement *Placement
|
||||
}
|
||||
|
||||
func (e ReasoningDoneEvent) EventType() string { return EventReasoningDone }
|
||||
|
||||
// CitationEvent carries citation info.
|
||||
type CitationEvent struct {
|
||||
Placement *Placement
|
||||
CitationNumber int
|
||||
DocumentID string
|
||||
}
|
||||
|
||||
func (e CitationEvent) EventType() string { return EventCitationInfo }
|
||||
|
||||
// ToolStartEvent signals the start of a tool usage.
|
||||
type ToolStartEvent struct {
|
||||
Placement *Placement
|
||||
Type string // The specific event type (e.g. "open_url_start")
|
||||
ToolName string
|
||||
}
|
||||
|
||||
func (e ToolStartEvent) EventType() string { return e.Type }
|
||||
|
||||
// DeepResearchPlanStartEvent signals the start of a deep research plan.
|
||||
type DeepResearchPlanStartEvent struct {
|
||||
Placement *Placement
|
||||
}
|
||||
|
||||
func (e DeepResearchPlanStartEvent) EventType() string { return EventDeepResearchPlan }
|
||||
|
||||
// DeepResearchPlanDeltaEvent carries deep research plan content.
|
||||
type DeepResearchPlanDeltaEvent struct {
|
||||
Placement *Placement
|
||||
Content string
|
||||
}
|
||||
|
||||
func (e DeepResearchPlanDeltaEvent) EventType() string { return EventDeepResearchDelta }
|
||||
|
||||
// ResearchAgentStartEvent signals a research sub-task.
|
||||
type ResearchAgentStartEvent struct {
|
||||
Placement *Placement
|
||||
ResearchTask string
|
||||
}
|
||||
|
||||
func (e ResearchAgentStartEvent) EventType() string { return EventResearchAgentStart }
|
||||
|
||||
// IntermediateReportStartEvent signals the start of an intermediate report.
|
||||
type IntermediateReportStartEvent struct {
|
||||
Placement *Placement
|
||||
}
|
||||
|
||||
func (e IntermediateReportStartEvent) EventType() string { return EventIntermediateReport }
|
||||
|
||||
// IntermediateReportDeltaEvent carries intermediate report content.
|
||||
type IntermediateReportDeltaEvent struct {
|
||||
Placement *Placement
|
||||
Content string
|
||||
}
|
||||
|
||||
func (e IntermediateReportDeltaEvent) EventType() string { return EventIntermediateReportDt }
|
||||
|
||||
// UnknownEvent is a catch-all for unrecognized stream data.
|
||||
type UnknownEvent struct {
|
||||
Placement *Placement
|
||||
RawData map[string]any
|
||||
}
|
||||
|
||||
func (e UnknownEvent) EventType() string { return EventUnknown }
|
||||
112
cli/internal/models/models.go
Normal file
112
cli/internal/models/models.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package models defines API request/response types for the Onyx CLI.
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// AgentSummary represents an agent from the API.
|
||||
type AgentSummary struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
IsDefaultPersona bool `json:"is_default_persona"`
|
||||
IsVisible bool `json:"is_visible"`
|
||||
}
|
||||
|
||||
// ChatSessionSummary is a brief session listing.
|
||||
type ChatSessionSummary struct {
|
||||
ID string `json:"id"`
|
||||
Name *string `json:"name"`
|
||||
AgentID *int `json:"persona_id"`
|
||||
Created time.Time `json:"time_created"`
|
||||
}
|
||||
|
||||
// ChatSessionDetails is a session with timestamps as strings.
|
||||
type ChatSessionDetails struct {
|
||||
ID string `json:"id"`
|
||||
Name *string `json:"name"`
|
||||
AgentID *int `json:"persona_id"`
|
||||
Created string `json:"time_created"`
|
||||
Updated string `json:"time_updated"`
|
||||
}
|
||||
|
||||
// ChatMessageDetail is a single message in a session.
|
||||
type ChatMessageDetail struct {
|
||||
MessageID int `json:"message_id"`
|
||||
ParentMessage *int `json:"parent_message"`
|
||||
LatestChildMessage *int `json:"latest_child_message"`
|
||||
Message string `json:"message"`
|
||||
MessageType string `json:"message_type"`
|
||||
TimeSent string `json:"time_sent"`
|
||||
Error *string `json:"error"`
|
||||
}
|
||||
|
||||
// ChatSessionDetailResponse is the full session detail from the API.
|
||||
type ChatSessionDetailResponse struct {
|
||||
ChatSessionID string `json:"chat_session_id"`
|
||||
Description *string `json:"description"`
|
||||
AgentID *int `json:"persona_id"`
|
||||
AgentName *string `json:"persona_name"`
|
||||
Messages []ChatMessageDetail `json:"messages"`
|
||||
}
|
||||
|
||||
// ChatFileType represents a file type for uploads.
|
||||
type ChatFileType string
|
||||
|
||||
const (
|
||||
ChatFileImage ChatFileType = "image"
|
||||
ChatFileDoc ChatFileType = "document"
|
||||
ChatFilePlainText ChatFileType = "plain_text"
|
||||
ChatFileCSV ChatFileType = "csv"
|
||||
)
|
||||
|
||||
// FileDescriptorPayload is a file descriptor for send-message requests.
|
||||
type FileDescriptorPayload struct {
|
||||
ID string `json:"id"`
|
||||
Type ChatFileType `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// UserFileSnapshot represents an uploaded file.
|
||||
type UserFileSnapshot struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
FileID string `json:"file_id"`
|
||||
ChatFileType ChatFileType `json:"chat_file_type"`
|
||||
}
|
||||
|
||||
// CategorizedFilesSnapshot is the response from file upload.
|
||||
type CategorizedFilesSnapshot struct {
|
||||
UserFiles []UserFileSnapshot `json:"user_files"`
|
||||
}
|
||||
|
||||
// ChatSessionCreationInfo is included when creating a new session inline.
|
||||
type ChatSessionCreationInfo struct {
|
||||
AgentID int `json:"persona_id"`
|
||||
}
|
||||
|
||||
// SendMessagePayload is the request body for POST /api/chat/send-chat-message.
|
||||
type SendMessagePayload struct {
|
||||
Message string `json:"message"`
|
||||
ChatSessionID *string `json:"chat_session_id,omitempty"`
|
||||
ChatSessionInfo *ChatSessionCreationInfo `json:"chat_session_info,omitempty"`
|
||||
ParentMessageID *int `json:"parent_message_id"`
|
||||
FileDescriptors []FileDescriptorPayload `json:"file_descriptors"`
|
||||
Origin string `json:"origin"`
|
||||
IncludeCitations bool `json:"include_citations"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// SearchDoc represents a document found during search.
|
||||
type SearchDoc struct {
|
||||
DocumentID string `json:"document_id"`
|
||||
SemanticIdentifier string `json:"semantic_identifier"`
|
||||
Link *string `json:"link"`
|
||||
SourceType string `json:"source_type"`
|
||||
}
|
||||
|
||||
// Placement indicates where a stream event belongs in the conversation.
|
||||
type Placement struct {
|
||||
TurnIndex int `json:"turn_index"`
|
||||
TabIndex int `json:"tab_index"`
|
||||
SubTurnIndex *int `json:"sub_turn_index"`
|
||||
}
|
||||
169
cli/internal/onboarding/onboarding.go
Normal file
169
cli/internal/onboarding/onboarding.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Package onboarding handles the first-run setup flow for Onyx CLI.
|
||||
package onboarding
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/util"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Aliases for shared styles.
|
||||
var (
|
||||
boldStyle = util.BoldStyle
|
||||
dimStyle = util.DimStyle
|
||||
greenStyle = util.GreenStyle
|
||||
redStyle = util.RedStyle
|
||||
yellowStyle = util.YellowStyle
|
||||
)
|
||||
|
||||
func getTermSize() (int, int) {
|
||||
w, h, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil {
|
||||
return 80, 24
|
||||
}
|
||||
return w, h
|
||||
}
|
||||
|
||||
// Run executes the interactive onboarding flow.
|
||||
// Returns the validated config, or nil if the user cancels.
|
||||
func Run(existing *config.OnyxCliConfig) *config.OnyxCliConfig {
|
||||
cfg := config.DefaultConfig()
|
||||
if existing != nil {
|
||||
cfg = *existing
|
||||
}
|
||||
|
||||
w, h := getTermSize()
|
||||
fmt.Print(tui.RenderSplashOnboarding(w, h))
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println(" Welcome to " + boldStyle.Render("Onyx CLI") + ".")
|
||||
fmt.Println()
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
// Server URL
|
||||
serverURL := prompt(reader, " Onyx server URL", cfg.ServerURL)
|
||||
if serverURL == "" {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasPrefix(serverURL, "http://") && !strings.HasPrefix(serverURL, "https://") {
|
||||
fmt.Println(" " + redStyle.Render("Server URL must start with http:// or https://"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// API Key
|
||||
fmt.Println()
|
||||
fmt.Println(" " + dimStyle.Render("Need an API key? Press Enter to open the admin panel in your browser,"))
|
||||
fmt.Println(" " + dimStyle.Render("or paste your key below."))
|
||||
fmt.Println()
|
||||
|
||||
apiKey := promptSecret(" API key", cfg.APIKey)
|
||||
|
||||
if apiKey == "" {
|
||||
// Open browser to API key page
|
||||
url := strings.TrimRight(serverURL, "/") + "/app/settings/accounts-access"
|
||||
fmt.Printf("\n Opening %s ...\n", url)
|
||||
util.OpenBrowser(url)
|
||||
fmt.Println(" " + dimStyle.Render("Copy your API key, then paste it here."))
|
||||
fmt.Println()
|
||||
|
||||
apiKey = promptSecret(" API key", "")
|
||||
if apiKey == "" {
|
||||
fmt.Println("\n " + redStyle.Render("No API key provided. Exiting."))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Test connection
|
||||
cfg = config.OnyxCliConfig{
|
||||
ServerURL: serverURL,
|
||||
APIKey: apiKey,
|
||||
DefaultAgentID: cfg.DefaultAgentID,
|
||||
}
|
||||
|
||||
fmt.Println("\n " + yellowStyle.Render("Testing connection..."))
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
if err := client.TestConnection(); err != nil {
|
||||
fmt.Println(" " + redStyle.Render("Connection failed.") + " " + err.Error())
|
||||
fmt.Println()
|
||||
fmt.Println(" " + dimStyle.Render("Run ") + boldStyle.Render("onyx-cli configure") + dimStyle.Render(" to try again."))
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := config.Save(cfg); err != nil {
|
||||
fmt.Println(" " + redStyle.Render("Could not save config: "+err.Error()))
|
||||
return nil
|
||||
}
|
||||
fmt.Println(" " + greenStyle.Render("Connected and authenticated."))
|
||||
fmt.Println()
|
||||
printQuickStart()
|
||||
return &cfg
|
||||
}
|
||||
|
||||
func promptSecret(label, defaultVal string) string {
|
||||
if defaultVal != "" {
|
||||
fmt.Printf("%s %s: ", label, dimStyle.Render("[hidden]"))
|
||||
} else {
|
||||
fmt.Printf("%s: ", label)
|
||||
}
|
||||
|
||||
password, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println() // ReadPassword doesn't echo a newline
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
line := strings.TrimSpace(string(password))
|
||||
if line == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
func prompt(reader *bufio.Reader, label, defaultVal string) string {
|
||||
if defaultVal != "" {
|
||||
fmt.Printf("%s %s: ", label, dimStyle.Render("["+defaultVal+"]"))
|
||||
} else {
|
||||
fmt.Printf("%s: ", label)
|
||||
}
|
||||
|
||||
line, err := reader.ReadString('\n')
|
||||
// ReadString may return partial data along with an error (e.g. EOF without newline)
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" {
|
||||
return line
|
||||
}
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func printQuickStart() {
|
||||
fmt.Println(" " + boldStyle.Render("Quick start"))
|
||||
fmt.Println()
|
||||
fmt.Println(" Just type to chat with your Onyx agent.")
|
||||
fmt.Println()
|
||||
|
||||
rows := [][2]string{
|
||||
{"/help", "Show all commands"},
|
||||
{"/attach", "Attach a file"},
|
||||
{"/agent", "Switch agent"},
|
||||
{"/new", "New conversation"},
|
||||
{"/sessions", "Browse previous chats"},
|
||||
{"Esc", "Cancel generation"},
|
||||
{"Ctrl+D", "Quit"},
|
||||
}
|
||||
for _, r := range rows {
|
||||
fmt.Printf(" %-12s %s\n", boldStyle.Render(r[0]), dimStyle.Render(r[1]))
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
248
cli/internal/parser/parser.go
Normal file
248
cli/internal/parser/parser.go
Normal file
@@ -0,0 +1,248 @@
|
||||
// Package parser handles NDJSON stream parsing for Onyx chat responses.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// ParseStreamLine parses a single NDJSON line into a typed StreamEvent.
|
||||
// Returns nil for empty lines or unparseable content.
|
||||
func ParseStreamLine(line string) models.StreamEvent {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &data); err != nil {
|
||||
return models.ErrorEvent{Error: fmt.Sprintf("malformed stream data: %v", err), IsRetryable: false}
|
||||
}
|
||||
|
||||
// Case 1: CreateChatSessionID
|
||||
if _, ok := data["chat_session_id"]; ok {
|
||||
if _, hasPlacement := data["placement"]; !hasPlacement {
|
||||
sid, _ := data["chat_session_id"].(string)
|
||||
return models.SessionCreatedEvent{ChatSessionID: sid}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: MessageResponseIDInfo
|
||||
if _, ok := data["reserved_assistant_message_id"]; ok {
|
||||
reservedID := jsonInt(data["reserved_assistant_message_id"])
|
||||
var userMsgID *int
|
||||
if v, ok := data["user_message_id"]; ok && v != nil {
|
||||
id := jsonInt(v)
|
||||
userMsgID = &id
|
||||
}
|
||||
return models.MessageIDEvent{
|
||||
UserMessageID: userMsgID,
|
||||
ReservedAgentMessageID: reservedID,
|
||||
}
|
||||
}
|
||||
|
||||
// Case 3: StreamingError (top-level error without placement)
|
||||
if _, ok := data["error"]; ok {
|
||||
if _, hasPlacement := data["placement"]; !hasPlacement {
|
||||
errStr, _ := data["error"].(string)
|
||||
var stackTrace *string
|
||||
if st, ok := data["stack_trace"].(string); ok {
|
||||
stackTrace = &st
|
||||
}
|
||||
isRetryable := true
|
||||
if v, ok := data["is_retryable"].(bool); ok {
|
||||
isRetryable = v
|
||||
}
|
||||
return models.ErrorEvent{
|
||||
Error: errStr,
|
||||
StackTrace: stackTrace,
|
||||
IsRetryable: isRetryable,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 4: Packet with placement + obj
|
||||
if rawPlacement, ok := data["placement"]; ok {
|
||||
if rawObj, ok := data["obj"]; ok {
|
||||
placement := parsePlacement(rawPlacement)
|
||||
obj, _ := rawObj.(map[string]any)
|
||||
if obj == nil {
|
||||
return models.UnknownEvent{Placement: placement, RawData: data}
|
||||
}
|
||||
return parsePacketObj(obj, placement)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback
|
||||
return models.UnknownEvent{RawData: data}
|
||||
}
|
||||
|
||||
func parsePlacement(raw interface{}) *models.Placement {
|
||||
m, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
p := &models.Placement{
|
||||
TurnIndex: jsonInt(m["turn_index"]),
|
||||
TabIndex: jsonInt(m["tab_index"]),
|
||||
}
|
||||
if v, ok := m["sub_turn_index"]; ok && v != nil {
|
||||
st := jsonInt(v)
|
||||
p.SubTurnIndex = &st
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func parsePacketObj(obj map[string]any, placement *models.Placement) models.StreamEvent {
|
||||
objType, _ := obj["type"].(string)
|
||||
|
||||
switch objType {
|
||||
case "stop":
|
||||
var reason *string
|
||||
if r, ok := obj["stop_reason"].(string); ok {
|
||||
reason = &r
|
||||
}
|
||||
return models.StopEvent{Placement: placement, StopReason: reason}
|
||||
|
||||
case "error":
|
||||
errMsg := "Unknown error"
|
||||
if e, ok := obj["exception"]; ok {
|
||||
errMsg = toString(e)
|
||||
}
|
||||
return models.ErrorEvent{Placement: placement, Error: errMsg, IsRetryable: true}
|
||||
|
||||
case "message_start":
|
||||
var docs []models.SearchDoc
|
||||
if rawDocs, ok := obj["final_documents"].([]any); ok {
|
||||
docs = parseSearchDocs(rawDocs)
|
||||
}
|
||||
return models.MessageStartEvent{Placement: placement, Documents: docs}
|
||||
|
||||
case "message_delta":
|
||||
content, _ := obj["content"].(string)
|
||||
return models.MessageDeltaEvent{Placement: placement, Content: content}
|
||||
|
||||
case "search_tool_start":
|
||||
isInternet, _ := obj["is_internet_search"].(bool)
|
||||
return models.SearchStartEvent{Placement: placement, IsInternetSearch: isInternet}
|
||||
|
||||
case "search_tool_queries_delta":
|
||||
var queries []string
|
||||
if raw, ok := obj["queries"].([]any); ok {
|
||||
for _, q := range raw {
|
||||
if s, ok := q.(string); ok {
|
||||
queries = append(queries, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return models.SearchQueriesEvent{Placement: placement, Queries: queries}
|
||||
|
||||
case "search_tool_documents_delta":
|
||||
var docs []models.SearchDoc
|
||||
if rawDocs, ok := obj["documents"].([]any); ok {
|
||||
docs = parseSearchDocs(rawDocs)
|
||||
}
|
||||
return models.SearchDocumentsEvent{Placement: placement, Documents: docs}
|
||||
|
||||
case "reasoning_start":
|
||||
return models.ReasoningStartEvent{Placement: placement}
|
||||
|
||||
case "reasoning_delta":
|
||||
reasoning, _ := obj["reasoning"].(string)
|
||||
return models.ReasoningDeltaEvent{Placement: placement, Reasoning: reasoning}
|
||||
|
||||
case "reasoning_done":
|
||||
return models.ReasoningDoneEvent{Placement: placement}
|
||||
|
||||
case "citation_info":
|
||||
return models.CitationEvent{
|
||||
Placement: placement,
|
||||
CitationNumber: jsonInt(obj["citation_number"]),
|
||||
DocumentID: jsonString(obj["document_id"]),
|
||||
}
|
||||
|
||||
case "open_url_start", "image_generation_start", "python_tool_start", "file_reader_start":
|
||||
toolName := strings.ReplaceAll(strings.TrimSuffix(objType, "_start"), "_", " ")
|
||||
toolName = cases.Title(language.English).String(toolName)
|
||||
return models.ToolStartEvent{Placement: placement, Type: objType, ToolName: toolName}
|
||||
|
||||
case "custom_tool_start":
|
||||
toolName := jsonString(obj["tool_name"])
|
||||
if toolName == "" {
|
||||
toolName = "Custom Tool"
|
||||
}
|
||||
return models.ToolStartEvent{Placement: placement, Type: models.EventCustomToolStart, ToolName: toolName}
|
||||
|
||||
case "deep_research_plan_start":
|
||||
return models.DeepResearchPlanStartEvent{Placement: placement}
|
||||
|
||||
case "deep_research_plan_delta":
|
||||
content, _ := obj["content"].(string)
|
||||
return models.DeepResearchPlanDeltaEvent{Placement: placement, Content: content}
|
||||
|
||||
case "research_agent_start":
|
||||
task, _ := obj["research_task"].(string)
|
||||
return models.ResearchAgentStartEvent{Placement: placement, ResearchTask: task}
|
||||
|
||||
case "intermediate_report_start":
|
||||
return models.IntermediateReportStartEvent{Placement: placement}
|
||||
|
||||
case "intermediate_report_delta":
|
||||
content, _ := obj["content"].(string)
|
||||
return models.IntermediateReportDeltaEvent{Placement: placement, Content: content}
|
||||
|
||||
default:
|
||||
return models.UnknownEvent{Placement: placement, RawData: obj}
|
||||
}
|
||||
}
|
||||
|
||||
func parseSearchDocs(raw []any) []models.SearchDoc {
|
||||
var docs []models.SearchDoc
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
doc := models.SearchDoc{
|
||||
DocumentID: jsonString(m["document_id"]),
|
||||
SemanticIdentifier: jsonString(m["semantic_identifier"]),
|
||||
SourceType: jsonString(m["source_type"]),
|
||||
}
|
||||
if link, ok := m["link"].(string); ok {
|
||||
doc.Link = &link
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
return docs
|
||||
}
|
||||
|
||||
func jsonInt(v any) int {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func jsonString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func toString(v any) string {
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return s
|
||||
default:
|
||||
b, _ := json.Marshal(v)
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
419
cli/internal/parser/parser_test.go
Normal file
419
cli/internal/parser/parser_test.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
func TestEmptyLineReturnsNil(t *testing.T) {
|
||||
for _, line := range []string{"", " ", "\n"} {
|
||||
if ParseStreamLine(line) != nil {
|
||||
t.Errorf("expected nil for %q", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidJSONReturnsErrorEvent(t *testing.T) {
|
||||
for _, line := range []string{"not json", "{broken"} {
|
||||
event := ParseStreamLine(line)
|
||||
if event == nil {
|
||||
t.Errorf("expected ErrorEvent for %q, got nil", line)
|
||||
continue
|
||||
}
|
||||
if _, ok := event.(models.ErrorEvent); !ok {
|
||||
t.Errorf("expected ErrorEvent for %q, got %T", line, event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCreated(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{
|
||||
"chat_session_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
})
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SessionCreatedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SessionCreatedEvent, got %T", event)
|
||||
}
|
||||
if e.ChatSessionID != "550e8400-e29b-41d4-a716-446655440000" {
|
||||
t.Errorf("got %s", e.ChatSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageIDInfo(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{
|
||||
"user_message_id": 1,
|
||||
"reserved_assistant_message_id": 2,
|
||||
})
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageIDEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageIDEvent, got %T", event)
|
||||
}
|
||||
if e.UserMessageID == nil || *e.UserMessageID != 1 {
|
||||
t.Errorf("expected user_message_id=1")
|
||||
}
|
||||
if e.ReservedAgentMessageID != 2 {
|
||||
t.Errorf("got %d", e.ReservedAgentMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageIDInfoNullUserID(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{
|
||||
"user_message_id": nil,
|
||||
"reserved_assistant_message_id": 5,
|
||||
})
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageIDEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageIDEvent, got %T", event)
|
||||
}
|
||||
if e.UserMessageID != nil {
|
||||
t.Error("expected nil user_message_id")
|
||||
}
|
||||
if e.ReservedAgentMessageID != 5 {
|
||||
t.Errorf("got %d", e.ReservedAgentMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopLevelError(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{
|
||||
"error": "Rate limit exceeded",
|
||||
"stack_trace": "...",
|
||||
"is_retryable": true,
|
||||
})
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ErrorEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ErrorEvent, got %T", event)
|
||||
}
|
||||
if e.Error != "Rate limit exceeded" {
|
||||
t.Errorf("got %s", e.Error)
|
||||
}
|
||||
if e.StackTrace == nil || *e.StackTrace != "..." {
|
||||
t.Error("expected stack_trace")
|
||||
}
|
||||
if !e.IsRetryable {
|
||||
t.Error("expected retryable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopLevelErrorMinimal(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{
|
||||
"error": "Something broke",
|
||||
})
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ErrorEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ErrorEvent, got %T", event)
|
||||
}
|
||||
if e.Error != "Something broke" {
|
||||
t.Errorf("got %s", e.Error)
|
||||
}
|
||||
if !e.IsRetryable {
|
||||
t.Error("expected default retryable=true")
|
||||
}
|
||||
}
|
||||
|
||||
func makePacket(obj map[string]interface{}, turnIndex, tabIndex int) string {
|
||||
return mustJSON(map[string]interface{}{
|
||||
"placement": map[string]interface{}{"turn_index": turnIndex, "tab_index": tabIndex},
|
||||
"obj": obj,
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopPacket(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "stop", "stop_reason": "completed"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.StopEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected StopEvent, got %T", event)
|
||||
}
|
||||
if e.StopReason == nil || *e.StopReason != "completed" {
|
||||
t.Error("expected stop_reason=completed")
|
||||
}
|
||||
if e.Placement == nil || e.Placement.TurnIndex != 0 {
|
||||
t.Error("expected placement")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopPacketNoReason(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "stop"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.StopEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected StopEvent, got %T", event)
|
||||
}
|
||||
if e.StopReason != nil {
|
||||
t.Error("expected nil stop_reason")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "message_start"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
_, ok := event.(models.MessageStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageStartEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageStartWithDocuments(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"final_documents": []interface{}{
|
||||
map[string]interface{}{"document_id": "doc1", "semantic_identifier": "Doc 1"},
|
||||
},
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageStartEvent, got %T", event)
|
||||
}
|
||||
if len(e.Documents) != 1 || e.Documents[0].DocumentID != "doc1" {
|
||||
t.Error("expected 1 document with id doc1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "message_delta", "content": "Hello"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "Hello" {
|
||||
t.Errorf("got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeltaEmpty(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "message_delta", "content": ""}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "" {
|
||||
t.Errorf("expected empty, got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchToolStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "search_tool_start", "is_internet_search": true,
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SearchStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SearchStartEvent, got %T", event)
|
||||
}
|
||||
if !e.IsInternetSearch {
|
||||
t.Error("expected internet search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchToolQueries(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "search_tool_queries_delta",
|
||||
"queries": []interface{}{"query 1", "query 2"},
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SearchQueriesEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SearchQueriesEvent, got %T", event)
|
||||
}
|
||||
if len(e.Queries) != 2 || e.Queries[0] != "query 1" {
|
||||
t.Error("unexpected queries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchToolDocuments(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "search_tool_documents_delta",
|
||||
"documents": []interface{}{
|
||||
map[string]interface{}{"document_id": "d1", "semantic_identifier": "First Doc", "link": "http://example.com"},
|
||||
map[string]interface{}{"document_id": "d2", "semantic_identifier": "Second Doc"},
|
||||
},
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.SearchDocumentsEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected SearchDocumentsEvent, got %T", event)
|
||||
}
|
||||
if len(e.Documents) != 2 {
|
||||
t.Errorf("expected 2 docs, got %d", len(e.Documents))
|
||||
}
|
||||
if e.Documents[0].Link == nil || *e.Documents[0].Link != "http://example.com" {
|
||||
t.Error("expected link on first doc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "reasoning_start"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.ReasoningStartEvent); !ok {
|
||||
t.Fatalf("expected ReasoningStartEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "reasoning_delta", "reasoning": "Let me think...",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ReasoningDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ReasoningDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Reasoning != "Let me think..." {
|
||||
t.Errorf("got %s", e.Reasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningDone(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "reasoning_done"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.ReasoningDoneEvent); !ok {
|
||||
t.Fatalf("expected ReasoningDoneEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationInfo(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "citation_info", "citation_number": 1, "document_id": "doc_abc",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.CitationEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected CitationEvent, got %T", event)
|
||||
}
|
||||
if e.CitationNumber != 1 || e.DocumentID != "doc_abc" {
|
||||
t.Errorf("got %d, %s", e.CitationNumber, e.DocumentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenURLStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "open_url_start"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ToolStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolStartEvent, got %T", event)
|
||||
}
|
||||
if e.Type != "open_url_start" {
|
||||
t.Errorf("got type %s", e.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPythonToolStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "python_tool_start", "code": "print('hi')",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ToolStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolStartEvent, got %T", event)
|
||||
}
|
||||
if e.ToolName != "Python Tool" {
|
||||
t.Errorf("got %s", e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomToolStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "custom_tool_start", "tool_name": "MyTool",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ToolStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolStartEvent, got %T", event)
|
||||
}
|
||||
if e.ToolName != "MyTool" {
|
||||
t.Errorf("got %s", e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepResearchPlanDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "deep_research_plan_delta", "content": "Step 1: ...",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.DeepResearchPlanDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected DeepResearchPlanDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "Step 1: ..." {
|
||||
t.Errorf("got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResearchAgentStart(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "research_agent_start", "research_task": "Find info about X",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.ResearchAgentStartEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ResearchAgentStartEvent, got %T", event)
|
||||
}
|
||||
if e.ResearchTask != "Find info about X" {
|
||||
t.Errorf("got %s", e.ResearchTask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntermediateReportDelta(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "intermediate_report_delta", "content": "Report text",
|
||||
}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.IntermediateReportDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected IntermediateReportDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Content != "Report text" {
|
||||
t.Errorf("got %s", e.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownPacketType(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{"type": "section_end"}, 0, 0)
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.UnknownEvent); !ok {
|
||||
t.Fatalf("expected UnknownEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownTopLevel(t *testing.T) {
|
||||
line := mustJSON(map[string]interface{}{"some_unknown_field": "value"})
|
||||
event := ParseStreamLine(line)
|
||||
if _, ok := event.(models.UnknownEvent); !ok {
|
||||
t.Fatalf("expected UnknownEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlacementPreserved(t *testing.T) {
|
||||
line := makePacket(map[string]interface{}{
|
||||
"type": "message_delta", "content": "x",
|
||||
}, 3, 1)
|
||||
event := ParseStreamLine(line)
|
||||
e, ok := event.(models.MessageDeltaEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected MessageDeltaEvent, got %T", event)
|
||||
}
|
||||
if e.Placement == nil {
|
||||
t.Fatal("expected placement")
|
||||
}
|
||||
if e.Placement.TurnIndex != 3 || e.Placement.TabIndex != 1 {
|
||||
t.Errorf("got turn=%d tab=%d", e.Placement.TurnIndex, e.Placement.TabIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func mustJSON(v interface{}) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
630
cli/internal/tui/app.go
Normal file
630
cli/internal/tui/app.go
Normal file
@@ -0,0 +1,630 @@
|
||||
// Package tui implements the Bubble Tea TUI for Onyx CLI.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// Model is the root Bubble Tea model.
|
||||
type Model struct {
|
||||
config config.OnyxCliConfig
|
||||
client *api.Client
|
||||
|
||||
viewport *viewport
|
||||
input inputModel
|
||||
status statusBar
|
||||
|
||||
width int
|
||||
height int
|
||||
|
||||
// Chat state
|
||||
chatSessionID *string
|
||||
agentID int
|
||||
agentName string
|
||||
agents []models.AgentSummary
|
||||
parentMessageID *int
|
||||
isStreaming bool
|
||||
streamCancel context.CancelFunc
|
||||
streamCh <-chan models.StreamEvent
|
||||
citations map[int]string
|
||||
attachedFiles []models.FileDescriptorPayload
|
||||
needsRename bool
|
||||
agentStarted bool
|
||||
|
||||
// Quit state
|
||||
quitPending bool
|
||||
splashShown bool
|
||||
initInputReady bool // true once terminal init responses have passed
|
||||
}
|
||||
|
||||
// NewModel creates a new TUI model.
|
||||
func NewModel(cfg config.OnyxCliConfig) Model {
|
||||
client := api.NewClient(cfg)
|
||||
parentID := -1
|
||||
|
||||
return Model{
|
||||
config: cfg,
|
||||
client: client,
|
||||
viewport: newViewport(80),
|
||||
input: newInputModel(),
|
||||
status: newStatusBar(),
|
||||
agentID: cfg.DefaultAgentID,
|
||||
agentName: "Default",
|
||||
parentMessageID: &parentID,
|
||||
citations: make(map[int]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the model.
|
||||
func (m Model) Init() tea.Cmd {
|
||||
return loadAgentsCmd(m.client)
|
||||
}
|
||||
|
||||
// Update handles messages.
|
||||
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Filter out terminal query responses (OSC 11 background color, cursor
|
||||
// position reports, etc.) that arrive as key events with raw escape content.
|
||||
// These arrive split across multiple key events, so we use a brief window
|
||||
// after startup to swallow them all.
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok && !m.initInputReady {
|
||||
// During init, drop ALL key events — they're terminal query responses
|
||||
_ = keyMsg
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.viewport.setWidth(msg.Width)
|
||||
m.status.setWidth(msg.Width)
|
||||
m.input.textInput.Width = msg.Width - 4
|
||||
if !m.splashShown {
|
||||
m.splashShown = true
|
||||
// bottomHeight = sep + input + sep + status = 4 (approx)
|
||||
viewportHeight := msg.Height - 4
|
||||
if viewportHeight < 1 {
|
||||
viewportHeight = msg.Height
|
||||
}
|
||||
m.viewport.addSplash(viewportHeight)
|
||||
// Delay input focus to let terminal query responses flush
|
||||
return m, tea.Tick(100*time.Millisecond, func(time.Time) tea.Msg {
|
||||
return inputReadyMsg{}
|
||||
})
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.MouseMsg:
|
||||
switch msg.Button {
|
||||
case tea.MouseButtonWheelUp:
|
||||
m.viewport.scrollUp(3)
|
||||
return m, nil
|
||||
case tea.MouseButtonWheelDown:
|
||||
m.viewport.scrollDown(3)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case tea.KeyMsg:
|
||||
return m.handleKey(msg)
|
||||
|
||||
case submitMsg:
|
||||
return m.handleSubmit(msg.text)
|
||||
|
||||
case fileDropMsg:
|
||||
return m.handleFileDrop(msg.path)
|
||||
|
||||
case InitDoneMsg:
|
||||
return m.handleInitDone(msg)
|
||||
|
||||
case api.StreamEventMsg:
|
||||
return m.handleStreamEvent(msg)
|
||||
|
||||
case api.StreamDoneMsg:
|
||||
return m.handleStreamDone(msg)
|
||||
|
||||
case AgentsLoadedMsg:
|
||||
return m.handleAgentsLoaded(msg)
|
||||
|
||||
case SessionsLoadedMsg:
|
||||
return m.handleSessionsLoaded(msg)
|
||||
|
||||
case SessionResumedMsg:
|
||||
return m.handleSessionResumed(msg)
|
||||
|
||||
case FileUploadedMsg:
|
||||
return m.handleFileUploaded(msg)
|
||||
|
||||
case inputReadyMsg:
|
||||
m.initInputReady = true
|
||||
m.input.textInput.Focus()
|
||||
m.input.textInput.SetValue("")
|
||||
return m, m.input.textInput.Cursor.BlinkCmd()
|
||||
|
||||
case resetQuitMsg:
|
||||
m.quitPending = false
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Only forward messages to the text input after it's been focused
|
||||
if m.splashShown {
|
||||
var cmd tea.Cmd
|
||||
m.input, cmd = m.input.update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// View renders the UI.
|
||||
// viewportHeight returns the number of visible chat rows, accounting for the
|
||||
// dynamic bottom area (separator, menu, file badges, input, status bar).
|
||||
func (m Model) viewportHeight() int {
|
||||
menuHeight := 0
|
||||
if m.input.menuVisible {
|
||||
menuHeight = len(m.input.menuItems)
|
||||
}
|
||||
fileHeight := 0
|
||||
if len(m.input.attachedFiles) > 0 {
|
||||
fileHeight = 1
|
||||
}
|
||||
h := m.height - (1 + menuHeight + fileHeight + 1 + 1 + 1)
|
||||
if h < 1 {
|
||||
return 1
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (m Model) View() string {
|
||||
if m.width == 0 || m.height == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
separator := lipgloss.NewStyle().Foreground(separatorColor).Render(
|
||||
strings.Repeat("─", m.width),
|
||||
)
|
||||
|
||||
menuView := m.input.viewMenu(m.width)
|
||||
viewportHeight := m.viewportHeight()
|
||||
|
||||
var parts []string
|
||||
parts = append(parts, m.viewport.view(viewportHeight))
|
||||
parts = append(parts, separator)
|
||||
if menuView != "" {
|
||||
parts = append(parts, menuView)
|
||||
}
|
||||
parts = append(parts, m.input.viewInput())
|
||||
parts = append(parts, separator)
|
||||
parts = append(parts, m.status.view())
|
||||
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// handleKey processes keyboard input.
|
||||
func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.Type {
|
||||
case tea.KeyEscape:
|
||||
// Cancel streaming or close menu
|
||||
if m.input.menuVisible {
|
||||
m.input.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
if m.isStreaming {
|
||||
return m.cancelStream()
|
||||
}
|
||||
// Dismiss picker
|
||||
if m.viewport.pickerActive {
|
||||
m.viewport.pickerActive = false
|
||||
return m, nil
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyCtrlD:
|
||||
// If streaming, cancel first; require a fresh Ctrl+D pair to quit
|
||||
if m.isStreaming {
|
||||
return m.cancelStream()
|
||||
}
|
||||
if m.quitPending {
|
||||
return m, tea.Quit
|
||||
}
|
||||
m.quitPending = true
|
||||
m.viewport.addInfo("Press Ctrl+D again to quit.")
|
||||
return m, tea.Tick(2*time.Second, func(t time.Time) tea.Msg {
|
||||
return resetQuitMsg{}
|
||||
})
|
||||
|
||||
case tea.KeyCtrlO:
|
||||
m.viewport.showSources = !m.viewport.showSources
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
// If picker is active, handle selection
|
||||
if m.viewport.pickerActive && len(m.viewport.pickerItems) > 0 {
|
||||
item := m.viewport.pickerItems[m.viewport.pickerIndex]
|
||||
m.viewport.pickerActive = false
|
||||
switch m.viewport.pickerType {
|
||||
case pickerSession:
|
||||
return cmdResume(m, item.id)
|
||||
case pickerAgent:
|
||||
return cmdSelectAgent(m, item.id)
|
||||
}
|
||||
}
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.viewport.pickerActive {
|
||||
if m.viewport.pickerIndex > 0 {
|
||||
m.viewport.pickerIndex--
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.viewport.pickerActive {
|
||||
if m.viewport.pickerIndex < len(m.viewport.pickerItems)-1 {
|
||||
m.viewport.pickerIndex++
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.viewport.scrollUp(m.viewportHeight() / 2)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.viewport.scrollDown(m.viewportHeight() / 2)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyShiftUp:
|
||||
m.viewport.scrollUp(3)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyShiftDown:
|
||||
m.viewport.scrollDown(3)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Pass to input
|
||||
var cmd tea.Cmd
|
||||
m.input, cmd = m.input.update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m Model) handleSubmit(text string) (tea.Model, tea.Cmd) {
|
||||
if strings.HasPrefix(text, "/") {
|
||||
return handleSlashCommand(m, text)
|
||||
}
|
||||
return m.sendMessage(text)
|
||||
}
|
||||
|
||||
func (m Model) handleFileDrop(path string) (tea.Model, tea.Cmd) {
|
||||
return cmdAttach(m, path)
|
||||
}
|
||||
|
||||
func (m Model) cancelStream() (Model, tea.Cmd) {
|
||||
if m.streamCancel != nil {
|
||||
m.streamCancel()
|
||||
}
|
||||
if m.chatSessionID != nil {
|
||||
sid := *m.chatSessionID
|
||||
go m.client.StopChatSession(sid)
|
||||
}
|
||||
m, cmd := m.finishStream(nil)
|
||||
m.viewport.addInfo("Generation stopped.")
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m Model) sendMessage(message string) (Model, tea.Cmd) {
|
||||
if m.isStreaming {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.viewport.addUserMessage(message)
|
||||
m.viewport.startAgent()
|
||||
|
||||
// Prepare file descriptors
|
||||
fileDescs := make([]models.FileDescriptorPayload, len(m.attachedFiles))
|
||||
copy(fileDescs, m.attachedFiles)
|
||||
m.attachedFiles = nil
|
||||
m.input.clearFiles()
|
||||
|
||||
m.isStreaming = true
|
||||
m.agentStarted = false
|
||||
m.citations = make(map[int]string)
|
||||
m.status.setStreaming(true)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.streamCancel = cancel
|
||||
|
||||
ch := m.client.SendMessageStream(
|
||||
ctx,
|
||||
message,
|
||||
m.chatSessionID,
|
||||
m.agentID,
|
||||
m.parentMessageID,
|
||||
fileDescs,
|
||||
)
|
||||
m.streamCh = ch
|
||||
|
||||
return m, api.WaitForStreamEvent(ch)
|
||||
}
|
||||
|
||||
func (m Model) handleStreamEvent(msg api.StreamEventMsg) (tea.Model, tea.Cmd) {
|
||||
// Ignore stale events after cancellation
|
||||
if !m.isStreaming {
|
||||
return m, nil
|
||||
}
|
||||
if msg.Event == nil {
|
||||
return m, api.WaitForStreamEvent(m.streamCh)
|
||||
}
|
||||
|
||||
switch e := msg.Event.(type) {
|
||||
case models.SessionCreatedEvent:
|
||||
m.chatSessionID = &e.ChatSessionID
|
||||
m.needsRename = true
|
||||
m.status.setSession(e.ChatSessionID)
|
||||
|
||||
case models.MessageIDEvent:
|
||||
m.parentMessageID = &e.ReservedAgentMessageID
|
||||
|
||||
case models.MessageStartEvent:
|
||||
m.agentStarted = true
|
||||
|
||||
case models.MessageDeltaEvent:
|
||||
m.agentStarted = true
|
||||
m.viewport.appendToken(e.Content)
|
||||
|
||||
case models.SearchStartEvent:
|
||||
if e.IsInternetSearch {
|
||||
m.viewport.addInfo("Web search…")
|
||||
} else {
|
||||
m.viewport.addInfo("Searching…")
|
||||
}
|
||||
|
||||
case models.SearchQueriesEvent:
|
||||
if len(e.Queries) > 0 {
|
||||
queries := e.Queries
|
||||
if len(queries) > 3 {
|
||||
queries = queries[:3]
|
||||
}
|
||||
parts := make([]string, len(queries))
|
||||
for i, q := range queries {
|
||||
parts[i] = "\"" + q + "\""
|
||||
}
|
||||
m.viewport.addInfo("Searching: " + strings.Join(parts, ", "))
|
||||
}
|
||||
|
||||
case models.SearchDocumentsEvent:
|
||||
count := len(e.Documents)
|
||||
suffix := "s"
|
||||
if count == 1 {
|
||||
suffix = ""
|
||||
}
|
||||
m.viewport.addInfo("Found " + strconv.Itoa(count) + " document" + suffix)
|
||||
|
||||
case models.ReasoningStartEvent:
|
||||
m.viewport.addInfo("Thinking…")
|
||||
|
||||
case models.ReasoningDeltaEvent:
|
||||
// We don't display reasoning text, just the indicator
|
||||
|
||||
case models.ReasoningDoneEvent:
|
||||
// No-op
|
||||
|
||||
case models.CitationEvent:
|
||||
m.citations[e.CitationNumber] = e.DocumentID
|
||||
|
||||
case models.ToolStartEvent:
|
||||
m.viewport.addInfo("Using " + e.ToolName + "…")
|
||||
|
||||
case models.ResearchAgentStartEvent:
|
||||
m.viewport.addInfo("Researching: " + e.ResearchTask)
|
||||
|
||||
case models.DeepResearchPlanDeltaEvent:
|
||||
m.viewport.appendToken(e.Content)
|
||||
|
||||
case models.IntermediateReportDeltaEvent:
|
||||
m.viewport.appendToken(e.Content)
|
||||
|
||||
case models.StopEvent:
|
||||
return m.finishStream(nil)
|
||||
|
||||
case models.ErrorEvent:
|
||||
m.viewport.addError(e.Error)
|
||||
return m.finishStream(nil)
|
||||
}
|
||||
|
||||
return m, api.WaitForStreamEvent(m.streamCh)
|
||||
}
|
||||
|
||||
func (m Model) handleStreamDone(msg api.StreamDoneMsg) (tea.Model, tea.Cmd) {
|
||||
// Ignore if already cancelled
|
||||
if !m.isStreaming {
|
||||
return m, nil
|
||||
}
|
||||
return m.finishStream(msg.Err)
|
||||
}
|
||||
|
||||
func (m Model) finishStream(err error) (Model, tea.Cmd) {
|
||||
if m.agentStarted {
|
||||
m.viewport.finishAgent()
|
||||
if len(m.citations) > 0 {
|
||||
m.viewport.addCitations(m.citations)
|
||||
}
|
||||
}
|
||||
m.isStreaming = false
|
||||
m.agentStarted = false
|
||||
m.status.setStreaming(false)
|
||||
m.streamCancel = nil
|
||||
m.streamCh = nil
|
||||
|
||||
// Auto-rename new sessions
|
||||
if m.needsRename && m.chatSessionID != nil {
|
||||
m.needsRename = false
|
||||
sessionID := *m.chatSessionID
|
||||
client := m.client
|
||||
go func() {
|
||||
_, _ = client.RenameChatSession(sessionID, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Model) handleInitDone(msg InitDoneMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addWarning("Could not load agents. Using default.")
|
||||
} else {
|
||||
m.agents = msg.Agents
|
||||
for _, p := range m.agents {
|
||||
if p.ID == m.agentID {
|
||||
m.agentName = p.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.status.setServer(m.config.ServerURL)
|
||||
m.status.setAgent(m.agentName)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Model) handleAgentsLoaded(msg AgentsLoadedMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addError("Could not load agents: " + msg.Err.Error())
|
||||
return m, nil
|
||||
}
|
||||
m.agents = msg.Agents
|
||||
if len(m.agents) == 0 {
|
||||
m.viewport.addInfo("No agents available.")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.viewport.addInfo("Select an agent (Enter to select, Esc to cancel):")
|
||||
|
||||
var items []pickerItem
|
||||
for _, p := range m.agents {
|
||||
label := fmt.Sprintf("%d: %s", p.ID, p.Name)
|
||||
if p.ID == m.agentID {
|
||||
label += " *"
|
||||
}
|
||||
desc := p.Description
|
||||
if len(desc) > 50 {
|
||||
desc = desc[:50] + "..."
|
||||
}
|
||||
if desc != "" {
|
||||
label += " - " + desc
|
||||
}
|
||||
items = append(items, pickerItem{
|
||||
id: strconv.Itoa(p.ID),
|
||||
label: label,
|
||||
})
|
||||
}
|
||||
m.viewport.showPicker(pickerAgent, items)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Model) handleSessionsLoaded(msg SessionsLoadedMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addError("Could not load sessions: " + msg.Err.Error())
|
||||
return m, nil
|
||||
}
|
||||
if len(msg.Sessions) == 0 {
|
||||
m.viewport.addInfo("No previous sessions found.")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.viewport.addInfo("Select a session to resume (Enter to select, Esc to cancel):")
|
||||
|
||||
var items []pickerItem
|
||||
for i, s := range msg.Sessions {
|
||||
if i >= 15 {
|
||||
break
|
||||
}
|
||||
name := "Untitled"
|
||||
if s.Name != nil && *s.Name != "" {
|
||||
name = *s.Name
|
||||
}
|
||||
sid := s.ID
|
||||
if len(sid) > 8 {
|
||||
sid = sid[:8]
|
||||
}
|
||||
items = append(items, pickerItem{
|
||||
id: s.ID,
|
||||
label: sid + " " + name + " (" + s.Created + ")",
|
||||
})
|
||||
}
|
||||
m.viewport.showPicker(pickerSession, items)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Model) handleSessionResumed(msg SessionResumedMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addError("Could not load session: " + msg.Err.Error())
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Cancel any in-progress stream before replacing the session
|
||||
if m.isStreaming {
|
||||
m, _ = m.cancelStream()
|
||||
}
|
||||
|
||||
detail := msg.Detail
|
||||
m.chatSessionID = &detail.ChatSessionID
|
||||
m.viewport.clearDisplay()
|
||||
m.status.setSession(detail.ChatSessionID)
|
||||
|
||||
if detail.AgentName != nil {
|
||||
m.agentName = *detail.AgentName
|
||||
m.status.setAgent(*detail.AgentName)
|
||||
}
|
||||
if detail.AgentID != nil {
|
||||
m.agentID = *detail.AgentID
|
||||
}
|
||||
|
||||
// Replay messages
|
||||
for _, chatMsg := range detail.Messages {
|
||||
switch chatMsg.MessageType {
|
||||
case "user":
|
||||
m.viewport.addUserMessage(chatMsg.Message)
|
||||
case "assistant":
|
||||
m.viewport.startAgent()
|
||||
m.viewport.appendToken(chatMsg.Message)
|
||||
m.viewport.finishAgent()
|
||||
}
|
||||
}
|
||||
|
||||
// Set parent to last message
|
||||
if len(detail.Messages) > 0 {
|
||||
lastID := detail.Messages[len(detail.Messages)-1].MessageID
|
||||
m.parentMessageID = &lastID
|
||||
}
|
||||
|
||||
desc := "Untitled"
|
||||
if detail.Description != nil && *detail.Description != "" {
|
||||
desc = *detail.Description
|
||||
}
|
||||
m.viewport.addInfo("Resumed session: " + desc)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Model) handleFileUploaded(msg FileUploadedMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Err != nil {
|
||||
m.viewport.addError("Upload failed: " + msg.Err.Error())
|
||||
return m, nil
|
||||
}
|
||||
m.attachedFiles = append(m.attachedFiles, *msg.Descriptor)
|
||||
m.input.addFile(msg.FileName)
|
||||
m.viewport.addInfo("Attached: " + msg.FileName)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type inputReadyMsg struct{}
|
||||
type resetQuitMsg struct{}
|
||||
|
||||
202
cli/internal/tui/commands.go
Normal file
202
cli/internal/tui/commands.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/util"
|
||||
)
|
||||
|
||||
// handleSlashCommand dispatches slash commands and returns updated model + cmd.
|
||||
func handleSlashCommand(m Model, text string) (Model, tea.Cmd) {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
command := strings.ToLower(parts[0])
|
||||
arg := ""
|
||||
if len(parts) > 1 {
|
||||
arg = parts[1]
|
||||
}
|
||||
|
||||
switch command {
|
||||
case "/help":
|
||||
m.viewport.addInfo(helpText)
|
||||
return m, nil
|
||||
|
||||
case "/new":
|
||||
return cmdNew(m)
|
||||
|
||||
case "/agent":
|
||||
if arg != "" {
|
||||
return cmdSelectAgent(m, arg)
|
||||
}
|
||||
return cmdShowAgents(m)
|
||||
|
||||
case "/attach":
|
||||
return cmdAttach(m, arg)
|
||||
|
||||
case "/sessions", "/resume":
|
||||
if strings.TrimSpace(arg) != "" {
|
||||
return cmdResume(m, arg)
|
||||
}
|
||||
return cmdSessions(m)
|
||||
|
||||
case "/configure":
|
||||
m.viewport.addInfo("Run 'onyx-cli configure' to change connection settings.")
|
||||
return m, nil
|
||||
|
||||
case "/clear":
|
||||
return cmdNew(m)
|
||||
|
||||
case "/connectors":
|
||||
url := m.config.ServerURL + "/admin/indexing/status"
|
||||
if util.OpenBrowser(url) {
|
||||
m.viewport.addInfo("Opened " + url + " in browser")
|
||||
} else {
|
||||
m.viewport.addWarning("Failed to open browser. Visit: " + url)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "/settings":
|
||||
url := m.config.ServerURL + "/app/settings/general"
|
||||
if util.OpenBrowser(url) {
|
||||
m.viewport.addInfo("Opened " + url + " in browser")
|
||||
} else {
|
||||
m.viewport.addWarning("Failed to open browser. Visit: " + url)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "/quit":
|
||||
return m, tea.Quit
|
||||
|
||||
default:
|
||||
m.viewport.addWarning(fmt.Sprintf("Unknown command: %s. Type /help for available commands.", command))
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
func cmdNew(m Model) (Model, tea.Cmd) {
|
||||
m.chatSessionID = nil
|
||||
parentID := -1
|
||||
m.parentMessageID = &parentID
|
||||
m.needsRename = false
|
||||
m.citations = nil
|
||||
m.viewport.clearAll()
|
||||
// Re-add splash as a scrollable entry
|
||||
viewportHeight := m.viewportHeight()
|
||||
if viewportHeight < 1 {
|
||||
viewportHeight = m.height
|
||||
}
|
||||
m.viewport.addSplash(viewportHeight)
|
||||
m.status.setSession("")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func cmdShowAgents(m Model) (Model, tea.Cmd) {
|
||||
m.viewport.addInfo("Loading agents...")
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
agents, err := client.ListAgents()
|
||||
return AgentsLoadedMsg{Agents: agents, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdSelectAgent(m Model, idStr string) (Model, tea.Cmd) {
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(idStr))
|
||||
if err != nil {
|
||||
m.viewport.addWarning("Invalid agent ID. Use a number.")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var target *models.AgentSummary
|
||||
for i := range m.agents {
|
||||
if m.agents[i].ID == pid {
|
||||
target = &m.agents[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
m.viewport.addWarning(fmt.Sprintf("Agent %d not found. Use /agent to see available agents.", pid))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.agentID = target.ID
|
||||
m.agentName = target.Name
|
||||
m.status.setAgent(target.Name)
|
||||
m.viewport.addInfo("Switched to agent: " + target.Name)
|
||||
|
||||
// Save preference
|
||||
m.config.DefaultAgentID = target.ID
|
||||
_ = config.Save(m.config)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func cmdAttach(m Model, pathStr string) (Model, tea.Cmd) {
|
||||
if pathStr == "" {
|
||||
m.viewport.addWarning("Usage: /attach <file_path>")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.viewport.addInfo("Uploading " + pathStr + "...")
|
||||
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
fd, err := client.UploadFile(pathStr)
|
||||
if err != nil {
|
||||
return FileUploadedMsg{Err: err, FileName: pathStr}
|
||||
}
|
||||
return FileUploadedMsg{Descriptor: fd, FileName: pathStr}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdSessions(m Model) (Model, tea.Cmd) {
|
||||
m.viewport.addInfo("Loading sessions...")
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
sessions, err := client.ListChatSessions()
|
||||
return SessionsLoadedMsg{Sessions: sessions, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdResume(m Model, sessionIDStr string) (Model, tea.Cmd) {
|
||||
client := m.client
|
||||
return m, func() tea.Msg {
|
||||
// Try to find session by prefix match
|
||||
sessions, err := client.ListChatSessions()
|
||||
if err != nil {
|
||||
return SessionResumedMsg{Err: err}
|
||||
}
|
||||
|
||||
var targetID string
|
||||
for _, s := range sessions {
|
||||
if strings.HasPrefix(s.ID, sessionIDStr) {
|
||||
targetID = s.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetID == "" {
|
||||
// Try as full UUID
|
||||
targetID = sessionIDStr
|
||||
}
|
||||
|
||||
detail, err := client.GetChatSession(targetID)
|
||||
if err != nil {
|
||||
return SessionResumedMsg{Err: fmt.Errorf("session not found: %s", sessionIDStr)}
|
||||
}
|
||||
return SessionResumedMsg{Detail: detail}
|
||||
}
|
||||
}
|
||||
|
||||
// loadAgentsCmd returns a tea.Cmd that loads agents from the API.
|
||||
func loadAgentsCmd(client *api.Client) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
agents, err := client.ListAgents()
|
||||
return InitDoneMsg{Agents: agents, Err: err}
|
||||
}
|
||||
}
|
||||
24
cli/internal/tui/help.go
Normal file
24
cli/internal/tui/help.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package tui
|
||||
|
||||
const helpText = `Onyx CLI Commands
|
||||
|
||||
/help Show this help message
|
||||
/new Start a new chat session
|
||||
/agent List and switch agents
|
||||
/attach <path> Attach a file to next message
|
||||
/sessions Browse and resume previous sessions
|
||||
/clear Clear the chat display
|
||||
/configure Re-run connection setup
|
||||
/connectors Open connectors page in browser
|
||||
/settings Open Onyx settings in browser
|
||||
/quit Exit Onyx CLI
|
||||
|
||||
Keyboard Shortcuts
|
||||
|
||||
Enter Send message
|
||||
Escape Cancel current generation
|
||||
Ctrl+O Toggle source citations
|
||||
Ctrl+D Quit (press twice)
|
||||
Scroll Up/Down Mouse wheel or Shift+Up/Down
|
||||
Page Up/Down Scroll half page
|
||||
`
|
||||
242
cli/internal/tui/input.go
Normal file
242
cli/internal/tui/input.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// slashCommand defines a slash command with its description.
|
||||
type slashCommand struct {
|
||||
command string
|
||||
description string
|
||||
}
|
||||
|
||||
var slashCommands = []slashCommand{
|
||||
{"/help", "Show help message"},
|
||||
{"/new", "Start a new chat session"},
|
||||
{"/agent", "List and switch agents"},
|
||||
{"/attach", "Attach a file to next message"},
|
||||
{"/sessions", "Browse and resume previous sessions"},
|
||||
{"/clear", "Clear the chat display"},
|
||||
{"/configure", "Re-run connection setup"},
|
||||
{"/connectors", "Open connectors in browser"},
|
||||
{"/settings", "Open settings in browser"},
|
||||
{"/quit", "Exit Onyx CLI"},
|
||||
}
|
||||
|
||||
// Commands that take arguments (filled in with trailing space on Tab/Enter).
|
||||
var argCommands = map[string]bool{
|
||||
"/attach": true,
|
||||
}
|
||||
|
||||
// inputModel manages the text input and slash command menu.
|
||||
type inputModel struct {
|
||||
textInput textinput.Model
|
||||
menuVisible bool
|
||||
menuItems []slashCommand
|
||||
menuIndex int
|
||||
attachedFiles []string
|
||||
}
|
||||
|
||||
func newInputModel() inputModel {
|
||||
ti := textinput.New()
|
||||
ti.Prompt = "" // We render our own prompt in viewInput()
|
||||
ti.Placeholder = "Send a message…"
|
||||
ti.CharLimit = 10000
|
||||
// Don't focus here — focus after first WindowSizeMsg to avoid
|
||||
// capturing terminal init escape sequences as input.
|
||||
|
||||
return inputModel{
|
||||
textInput: ti,
|
||||
}
|
||||
}
|
||||
|
||||
func (m inputModel) update(msg tea.Msg) (inputModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
return m.handleKey(msg)
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.textInput, cmd = m.textInput.Update(msg)
|
||||
m = m.updateMenu()
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m inputModel) handleKey(msg tea.KeyMsg) (inputModel, tea.Cmd) {
|
||||
switch msg.Type {
|
||||
case tea.KeyUp:
|
||||
if m.menuVisible && m.menuIndex > 0 {
|
||||
m.menuIndex--
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyDown:
|
||||
if m.menuVisible && m.menuIndex < len(m.menuItems)-1 {
|
||||
m.menuIndex++
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyTab:
|
||||
if m.menuVisible && len(m.menuItems) > 0 {
|
||||
cmd := m.menuItems[m.menuIndex].command
|
||||
if argCommands[cmd] {
|
||||
m.textInput.SetValue(cmd + " ")
|
||||
m.textInput.SetCursor(len(cmd) + 1)
|
||||
} else {
|
||||
m.textInput.SetValue(cmd)
|
||||
m.textInput.SetCursor(len(cmd))
|
||||
}
|
||||
m.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyEnter:
|
||||
if m.menuVisible && len(m.menuItems) > 0 {
|
||||
cmd := m.menuItems[m.menuIndex].command
|
||||
if argCommands[cmd] {
|
||||
m.textInput.SetValue(cmd + " ")
|
||||
m.textInput.SetCursor(len(cmd) + 1)
|
||||
m.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
// Execute immediately
|
||||
m.textInput.SetValue("")
|
||||
m.menuVisible = false
|
||||
return m, func() tea.Msg { return submitMsg{text: cmd} }
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(m.textInput.Value())
|
||||
if text == "" {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Check for file path (drag-and-drop)
|
||||
if dropped := detectFileDrop(text); dropped != "" {
|
||||
m.textInput.SetValue("")
|
||||
return m, func() tea.Msg { return fileDropMsg{path: dropped} }
|
||||
}
|
||||
|
||||
m.textInput.SetValue("")
|
||||
m.menuVisible = false
|
||||
return m, func() tea.Msg { return submitMsg{text: text} }
|
||||
|
||||
case tea.KeyEscape:
|
||||
if m.menuVisible {
|
||||
m.menuVisible = false
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.textInput, cmd = m.textInput.Update(msg)
|
||||
m = m.updateMenu()
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m inputModel) updateMenu() inputModel {
|
||||
val := strings.TrimSpace(m.textInput.Value())
|
||||
if strings.HasPrefix(val, "/") && !strings.Contains(val, " ") {
|
||||
needle := strings.ToLower(val)
|
||||
var filtered []slashCommand
|
||||
for _, sc := range slashCommands {
|
||||
if strings.HasPrefix(sc.command, needle) {
|
||||
filtered = append(filtered, sc)
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
m.menuVisible = true
|
||||
m.menuItems = filtered
|
||||
if m.menuIndex >= len(filtered) {
|
||||
m.menuIndex = 0
|
||||
}
|
||||
} else {
|
||||
m.menuVisible = false
|
||||
}
|
||||
} else {
|
||||
m.menuVisible = false
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *inputModel) addFile(name string) {
|
||||
m.attachedFiles = append(m.attachedFiles, name)
|
||||
}
|
||||
|
||||
func (m *inputModel) clearFiles() {
|
||||
m.attachedFiles = nil
|
||||
}
|
||||
|
||||
// submitMsg is sent when user submits text.
|
||||
type submitMsg struct {
|
||||
text string
|
||||
}
|
||||
|
||||
// fileDropMsg is sent when a file path is detected.
|
||||
type fileDropMsg struct {
|
||||
path string
|
||||
}
|
||||
|
||||
// detectFileDrop checks if the text looks like a file path.
|
||||
func detectFileDrop(text string) string {
|
||||
cleaned := strings.Trim(text, "'\"")
|
||||
if cleaned == "" {
|
||||
return ""
|
||||
}
|
||||
// Only treat as a file drop if it looks explicitly path-like
|
||||
if !strings.HasPrefix(cleaned, "/") && !strings.HasPrefix(cleaned, "~") &&
|
||||
!strings.HasPrefix(cleaned, "./") && !strings.HasPrefix(cleaned, "../") {
|
||||
return ""
|
||||
}
|
||||
// Expand ~ to home dir
|
||||
if strings.HasPrefix(cleaned, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
cleaned = filepath.Join(home, cleaned[1:])
|
||||
}
|
||||
}
|
||||
abs, err := filepath.Abs(cleaned)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if info.IsDir() {
|
||||
return ""
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
// viewMenu renders the slash command menu.
|
||||
func (m inputModel) viewMenu(width int) string {
|
||||
if !m.menuVisible || len(m.menuItems) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
for i, item := range m.menuItems {
|
||||
prefix := " "
|
||||
if i == m.menuIndex {
|
||||
prefix = "> "
|
||||
}
|
||||
line := prefix + item.command + " " + statusMsgStyle.Render(item.description)
|
||||
lines = append(lines, line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// viewInput renders the input line with prompt and optional file badges.
|
||||
func (m inputModel) viewInput() string {
|
||||
var parts []string
|
||||
|
||||
if len(m.attachedFiles) > 0 {
|
||||
badges := strings.Join(m.attachedFiles, "] [")
|
||||
parts = append(parts, statusMsgStyle.Render("Attached: ["+badges+"]"))
|
||||
}
|
||||
|
||||
parts = append(parts, inputPrompt+m.textInput.View())
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
46
cli/internal/tui/messages.go
Normal file
46
cli/internal/tui/messages.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
)
|
||||
|
||||
// InitDoneMsg signals that async initialization is complete.
|
||||
type InitDoneMsg struct {
|
||||
Agents []models.AgentSummary
|
||||
Err error
|
||||
}
|
||||
|
||||
// SessionsLoadedMsg carries loaded chat sessions.
|
||||
type SessionsLoadedMsg struct {
|
||||
Sessions []models.ChatSessionDetails
|
||||
Err error
|
||||
}
|
||||
|
||||
// SessionResumedMsg carries a loaded session detail.
|
||||
type SessionResumedMsg struct {
|
||||
Detail *models.ChatSessionDetailResponse
|
||||
Err error
|
||||
}
|
||||
|
||||
// FileUploadedMsg carries an uploaded file descriptor.
|
||||
type FileUploadedMsg struct {
|
||||
Descriptor *models.FileDescriptorPayload
|
||||
FileName string
|
||||
Err error
|
||||
}
|
||||
|
||||
// AgentsLoadedMsg carries freshly fetched agents from the API.
|
||||
type AgentsLoadedMsg struct {
|
||||
Agents []models.AgentSummary
|
||||
Err error
|
||||
}
|
||||
|
||||
// InfoMsg is a simple informational message for display.
|
||||
type InfoMsg struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
// ErrorMsg wraps an error for display.
|
||||
type ErrorMsg struct {
|
||||
Err error
|
||||
}
|
||||
79
cli/internal/tui/splash.go
Normal file
79
cli/internal/tui/splash.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const onyxLogo = ` ██████╗ ███╗ ██╗██╗ ██╗██╗ ██╗
|
||||
██╔═══██╗████╗ ██║╚██╗ ██╔╝╚██╗██╔╝
|
||||
██║ ██║██╔██╗ ██║ ╚████╔╝ ╚███╔╝
|
||||
██║ ██║██║╚██╗██║ ╚██╔╝ ██╔██╗
|
||||
╚██████╔╝██║ ╚████║ ██║ ██╔╝ ██╗
|
||||
╚═════╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝`
|
||||
|
||||
const tagline = "Your terminal interface for Onyx"
|
||||
const splashHint = "Type a message to begin · /help for commands"
|
||||
|
||||
// renderSplash renders the splash screen centered for the given dimensions.
|
||||
func renderSplash(width, height int) string {
|
||||
// Render the logo as a single block (don't center individual lines)
|
||||
logo := splashStyle.Render(onyxLogo)
|
||||
|
||||
// Center tagline and hint relative to the logo block width
|
||||
logoWidth := lipgloss.Width(logo)
|
||||
tag := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
|
||||
taglineStyle.Render(tagline),
|
||||
)
|
||||
hint := lipgloss.NewStyle().Width(logoWidth).Align(lipgloss.Center).Render(
|
||||
hintStyle.Render(splashHint),
|
||||
)
|
||||
|
||||
block := lipgloss.JoinVertical(lipgloss.Left, logo, "", tag, hint)
|
||||
|
||||
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, block)
|
||||
}
|
||||
|
||||
// RenderSplashOnboarding renders splash for the terminal onboarding screen.
|
||||
func RenderSplashOnboarding(width, height int) string {
|
||||
// Render the logo as a styled block, then center it as a unit
|
||||
styledLogo := splashStyle.Render(onyxLogo)
|
||||
logoWidth := lipgloss.Width(styledLogo)
|
||||
logoLines := strings.Split(styledLogo, "\n")
|
||||
|
||||
logoHeight := len(logoLines)
|
||||
contentHeight := logoHeight + 2 // logo + blank + tagline
|
||||
topPad := (height - contentHeight) / 2
|
||||
if topPad < 1 {
|
||||
topPad = 1
|
||||
}
|
||||
|
||||
// Center the entire logo block horizontally
|
||||
blockPad := (width - logoWidth) / 2
|
||||
if blockPad < 0 {
|
||||
blockPad = 0
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < topPad; i++ {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
for _, line := range logoLines {
|
||||
b.WriteString(strings.Repeat(" ", blockPad))
|
||||
b.WriteString(line)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
b.WriteByte('\n')
|
||||
tagPad := (width - len(tagline)) / 2
|
||||
if tagPad < 0 {
|
||||
tagPad = 0
|
||||
}
|
||||
b.WriteString(strings.Repeat(" ", tagPad))
|
||||
b.WriteString(taglineStyle.Render(tagline))
|
||||
b.WriteByte('\n')
|
||||
|
||||
return b.String()
|
||||
}
|
||||
60
cli/internal/tui/statusbar.go
Normal file
60
cli/internal/tui/statusbar.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// statusBar manages the footer status display.
|
||||
type statusBar struct {
|
||||
agentName string
|
||||
serverURL string
|
||||
sessionID string
|
||||
streaming bool
|
||||
width int
|
||||
}
|
||||
|
||||
func newStatusBar() statusBar {
|
||||
return statusBar{
|
||||
agentName: "Default",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusBar) setAgent(name string) { s.agentName = name }
|
||||
func (s *statusBar) setServer(url string) { s.serverURL = url }
|
||||
func (s *statusBar) setSession(id string) {
|
||||
if len(id) > 8 {
|
||||
id = id[:8]
|
||||
}
|
||||
s.sessionID = id
|
||||
}
|
||||
func (s *statusBar) setStreaming(v bool) { s.streaming = v }
|
||||
func (s *statusBar) setWidth(w int) { s.width = w }
|
||||
|
||||
func (s statusBar) view() string {
|
||||
var leftParts []string
|
||||
if s.serverURL != "" {
|
||||
leftParts = append(leftParts, s.serverURL)
|
||||
}
|
||||
name := s.agentName
|
||||
if name == "" {
|
||||
name = "Default"
|
||||
}
|
||||
leftParts = append(leftParts, name)
|
||||
left := statusBarStyle.Render(strings.Join(leftParts, " · "))
|
||||
|
||||
right := "Ctrl+D to quit"
|
||||
if s.streaming {
|
||||
right = "Esc to cancel"
|
||||
}
|
||||
rightRendered := statusBarStyle.Render(right)
|
||||
|
||||
// Fill space between left and right
|
||||
gap := s.width - lipgloss.Width(left) - lipgloss.Width(rightRendered)
|
||||
if gap < 1 {
|
||||
gap = 1
|
||||
}
|
||||
|
||||
return left + strings.Repeat(" ", gap) + rightRendered
|
||||
}
|
||||
29
cli/internal/tui/styles.go
Normal file
29
cli/internal/tui/styles.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tui
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
var (
|
||||
// Colors
|
||||
accentColor = lipgloss.Color("#6c8ebf")
|
||||
dimColor = lipgloss.Color("#555577")
|
||||
errorColor = lipgloss.Color("#ff5555")
|
||||
splashColor = lipgloss.Color("#7C6AEF")
|
||||
separatorColor = lipgloss.Color("#333355")
|
||||
citationColor = lipgloss.Color("#666688")
|
||||
|
||||
// Styles
|
||||
userPrefixStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
agentDot = lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("◉")
|
||||
infoStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#b0b0cc"))
|
||||
dimInfoStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
statusMsgStyle = dimInfoStyle // used for slash menu descriptions, file badges
|
||||
errorStyle = lipgloss.NewStyle().Foreground(errorColor).Bold(true)
|
||||
warnStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
|
||||
citationStyle = lipgloss.NewStyle().Foreground(citationColor)
|
||||
statusBarStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
inputPrompt = lipgloss.NewStyle().Foreground(accentColor).Render("❯ ")
|
||||
|
||||
splashStyle = lipgloss.NewStyle().Foreground(splashColor).Bold(true)
|
||||
taglineStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#A0A0A0"))
|
||||
hintStyle = lipgloss.NewStyle().Foreground(dimColor)
|
||||
)
|
||||
419
cli/internal/tui/viewport.go
Normal file
419
cli/internal/tui/viewport.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/glamour/styles"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// entryKind is the type of chat entry.
|
||||
type entryKind int
|
||||
|
||||
const (
|
||||
entryUser entryKind = iota
|
||||
entryAgent
|
||||
entryInfo
|
||||
entryError
|
||||
entryCitation
|
||||
)
|
||||
|
||||
// chatEntry is a single rendered entry in the chat history.
|
||||
type chatEntry struct {
|
||||
kind entryKind
|
||||
content string // raw content (for agent: the markdown source)
|
||||
rendered string // pre-rendered output
|
||||
citations []string // citation lines (for citation entries)
|
||||
}
|
||||
|
||||
// pickerKind distinguishes what the picker is selecting.
|
||||
type pickerKind int
|
||||
|
||||
const (
|
||||
pickerSession pickerKind = iota
|
||||
pickerAgent
|
||||
)
|
||||
|
||||
// pickerItem is a selectable item in the picker.
|
||||
type pickerItem struct {
|
||||
id string
|
||||
label string
|
||||
}
|
||||
|
||||
// viewport manages the chat display.
|
||||
type viewport struct {
|
||||
entries []chatEntry
|
||||
width int
|
||||
streaming bool
|
||||
streamBuf string
|
||||
showSources bool
|
||||
renderer *glamour.TermRenderer
|
||||
pickerItems []pickerItem
|
||||
pickerActive bool
|
||||
pickerIndex int
|
||||
pickerType pickerKind
|
||||
scrollOffset int // lines scrolled up from bottom (0 = pinned to bottom)
|
||||
lastMaxScroll int // cached from last render for clamping in scrollUp
|
||||
}
|
||||
|
||||
// newMarkdownRenderer creates a Glamour renderer with zero left margin.
|
||||
func newMarkdownRenderer(width int) *glamour.TermRenderer {
|
||||
style := styles.DarkStyleConfig
|
||||
zero := uint(0)
|
||||
style.Document.Margin = &zero
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(style),
|
||||
glamour.WithWordWrap(width-4),
|
||||
)
|
||||
return r
|
||||
}
|
||||
|
||||
func newViewport(width int) *viewport {
|
||||
return &viewport{
|
||||
width: width,
|
||||
renderer: newMarkdownRenderer(width),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *viewport) addSplash(height int) {
|
||||
splash := renderSplash(v.width, height)
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryInfo,
|
||||
rendered: splash,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) setWidth(w int) {
|
||||
v.width = w
|
||||
v.renderer = newMarkdownRenderer(w)
|
||||
}
|
||||
|
||||
func (v *viewport) addUserMessage(msg string) {
|
||||
rendered := "\n" + userPrefixStyle.Render("❯ ") + msg
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryUser,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) startAgent() {
|
||||
v.streaming = true
|
||||
v.streamBuf = ""
|
||||
// Add a blank-line spacer entry before the agent message
|
||||
v.entries = append(v.entries, chatEntry{kind: entryInfo, rendered: ""})
|
||||
}
|
||||
|
||||
func (v *viewport) appendToken(token string) {
|
||||
v.streamBuf += token
|
||||
// Only auto-scroll when already pinned to bottom; preserve user's scroll position
|
||||
if v.scrollOffset == 0 {
|
||||
v.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (v *viewport) finishAgent() {
|
||||
if v.streamBuf == "" {
|
||||
v.streaming = false
|
||||
// Remove the blank spacer entry added by startAgent()
|
||||
if len(v.entries) > 0 && v.entries[len(v.entries)-1].kind == entryInfo && v.entries[len(v.entries)-1].rendered == "" {
|
||||
v.entries = v.entries[:len(v.entries)-1]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Render markdown with Glamour (zero left margin style)
|
||||
rendered := v.renderMarkdown(v.streamBuf)
|
||||
rendered = strings.TrimLeft(rendered, "\n")
|
||||
rendered = strings.TrimRight(rendered, "\n")
|
||||
lines := strings.Split(rendered, "\n")
|
||||
// Prefix first line with dot, indent continuation lines
|
||||
if len(lines) > 0 {
|
||||
lines[0] = agentDot + " " + lines[0]
|
||||
for i := 1; i < len(lines); i++ {
|
||||
lines[i] = " " + lines[i]
|
||||
}
|
||||
}
|
||||
rendered = strings.Join(lines, "\n")
|
||||
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryAgent,
|
||||
content: v.streamBuf,
|
||||
rendered: rendered,
|
||||
})
|
||||
v.streaming = false
|
||||
v.streamBuf = ""
|
||||
}
|
||||
|
||||
func (v *viewport) renderMarkdown(md string) string {
|
||||
if v.renderer == nil {
|
||||
return md
|
||||
}
|
||||
out, err := v.renderer.Render(md)
|
||||
if err != nil {
|
||||
return md
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (v *viewport) addInfo(msg string) {
|
||||
rendered := infoStyle.Render("● " + msg)
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryInfo,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) addWarning(msg string) {
|
||||
rendered := warnStyle.Render("● " + msg)
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryError,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) addError(msg string) {
|
||||
rendered := errorStyle.Render("● Error: ") + msg
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryError,
|
||||
content: msg,
|
||||
rendered: rendered,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) addCitations(citations map[int]string) {
|
||||
if len(citations) == 0 {
|
||||
return
|
||||
}
|
||||
keys := make([]int, 0, len(citations))
|
||||
for k := range citations {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Ints(keys)
|
||||
var parts []string
|
||||
for _, num := range keys {
|
||||
parts = append(parts, fmt.Sprintf("[%d] %s", num, citations[num]))
|
||||
}
|
||||
text := fmt.Sprintf("Sources (%d): %s", len(citations), strings.Join(parts, " "))
|
||||
var citLines []string
|
||||
citLines = append(citLines, text)
|
||||
|
||||
v.entries = append(v.entries, chatEntry{
|
||||
kind: entryCitation,
|
||||
content: text,
|
||||
rendered: citationStyle.Render("● "+text),
|
||||
citations: citLines,
|
||||
})
|
||||
}
|
||||
|
||||
func (v *viewport) showPicker(kind pickerKind, items []pickerItem) {
|
||||
v.pickerItems = items
|
||||
v.pickerType = kind
|
||||
v.pickerActive = true
|
||||
v.pickerIndex = 0
|
||||
}
|
||||
|
||||
func (v *viewport) scrollUp(n int) {
|
||||
v.scrollOffset += n
|
||||
if v.scrollOffset > v.lastMaxScroll {
|
||||
v.scrollOffset = v.lastMaxScroll
|
||||
}
|
||||
}
|
||||
|
||||
func (v *viewport) scrollDown(n int) {
|
||||
v.scrollOffset -= n
|
||||
if v.scrollOffset < 0 {
|
||||
v.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (v *viewport) clearAll() {
|
||||
v.entries = nil
|
||||
v.streaming = false
|
||||
v.streamBuf = ""
|
||||
v.pickerItems = nil
|
||||
v.pickerActive = false
|
||||
v.scrollOffset = 0
|
||||
}
|
||||
|
||||
func (v *viewport) clearDisplay() {
|
||||
v.entries = nil
|
||||
v.scrollOffset = 0
|
||||
v.streaming = false
|
||||
v.streamBuf = ""
|
||||
}
|
||||
|
||||
// pickerTitle returns a title for the current picker kind.
|
||||
func (v *viewport) pickerTitle() string {
|
||||
switch v.pickerType {
|
||||
case pickerAgent:
|
||||
return "Select Agent"
|
||||
case pickerSession:
|
||||
return "Resume Session"
|
||||
default:
|
||||
return "Select"
|
||||
}
|
||||
}
|
||||
|
||||
// renderPicker renders the picker as a bordered overlay.
|
||||
func (v *viewport) renderPicker(width, height int) string {
|
||||
title := v.pickerTitle()
|
||||
|
||||
// Determine picker dimensions
|
||||
maxItems := len(v.pickerItems)
|
||||
panelWidth := width - 4
|
||||
if panelWidth < 30 {
|
||||
panelWidth = 30
|
||||
}
|
||||
if panelWidth > 70 {
|
||||
panelWidth = 70
|
||||
}
|
||||
innerWidth := panelWidth - 4 // border + padding
|
||||
|
||||
// Visible window of items (scroll if too many)
|
||||
maxVisible := height - 6 // room for border, title, hint
|
||||
if maxVisible < 3 {
|
||||
maxVisible = 3
|
||||
}
|
||||
if maxVisible > maxItems {
|
||||
maxVisible = maxItems
|
||||
}
|
||||
|
||||
// Calculate scroll window around current index
|
||||
startIdx := 0
|
||||
if v.pickerIndex >= maxVisible {
|
||||
startIdx = v.pickerIndex - maxVisible + 1
|
||||
}
|
||||
endIdx := startIdx + maxVisible
|
||||
if endIdx > maxItems {
|
||||
endIdx = maxItems
|
||||
startIdx = endIdx - maxVisible
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
}
|
||||
|
||||
var itemLines []string
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
item := v.pickerItems[i]
|
||||
label := item.label
|
||||
labelRunes := []rune(label)
|
||||
if len(labelRunes) > innerWidth-4 {
|
||||
label = string(labelRunes[:innerWidth-7]) + "..."
|
||||
}
|
||||
if i == v.pickerIndex {
|
||||
line := lipgloss.NewStyle().Foreground(accentColor).Bold(true).Render("> " + label)
|
||||
itemLines = append(itemLines, line)
|
||||
} else {
|
||||
itemLines = append(itemLines, " "+label)
|
||||
}
|
||||
}
|
||||
|
||||
hint := lipgloss.NewStyle().Foreground(dimColor).Render("↑↓ navigate • enter select • esc cancel")
|
||||
|
||||
body := strings.Join(itemLines, "\n") + "\n\n" + hint
|
||||
|
||||
panel := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accentColor).
|
||||
Padding(1, 2).
|
||||
Width(panelWidth).
|
||||
Render(body)
|
||||
|
||||
titleRendered := lipgloss.NewStyle().
|
||||
Foreground(accentColor).
|
||||
Bold(true).
|
||||
Render(" " + title + " ")
|
||||
|
||||
// Build top border manually to avoid ANSI-corrupted rune slicing.
|
||||
// panelWidth+2 accounts for the left and right border characters.
|
||||
borderColor := lipgloss.NewStyle().Foreground(accentColor)
|
||||
titleWidth := lipgloss.Width(titleRendered)
|
||||
rightDashes := panelWidth + 2 - 3 - titleWidth // total - "╭─" - "╮" - title
|
||||
if rightDashes < 0 {
|
||||
rightDashes = 0
|
||||
}
|
||||
topBorder := borderColor.Render("╭─") + titleRendered +
|
||||
borderColor.Render(strings.Repeat("─", rightDashes)+"╮")
|
||||
|
||||
panelLines := strings.Split(panel, "\n")
|
||||
if len(panelLines) > 0 {
|
||||
panelLines[0] = topBorder
|
||||
}
|
||||
panel = strings.Join(panelLines, "\n")
|
||||
|
||||
// Center the panel in the viewport
|
||||
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, panel)
|
||||
}
|
||||
|
||||
// view renders the full viewport content.
|
||||
func (v *viewport) view(height int) string {
|
||||
// If picker is active, render it as an overlay
|
||||
if v.pickerActive && len(v.pickerItems) > 0 {
|
||||
return v.renderPicker(v.width, height)
|
||||
}
|
||||
|
||||
var lines []string
|
||||
|
||||
for _, e := range v.entries {
|
||||
if e.kind == entryCitation && !v.showSources {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, e.rendered)
|
||||
}
|
||||
|
||||
// Streaming buffer (plain text, not markdown)
|
||||
if v.streaming && v.streamBuf != "" {
|
||||
bufLines := strings.Split(v.streamBuf, "\n")
|
||||
if len(bufLines) > 0 {
|
||||
bufLines[0] = agentDot + " " + bufLines[0]
|
||||
for i := 1; i < len(bufLines); i++ {
|
||||
bufLines[i] = " " + bufLines[i]
|
||||
}
|
||||
}
|
||||
lines = append(lines, strings.Join(bufLines, "\n"))
|
||||
} else if v.streaming {
|
||||
lines = append(lines, agentDot+" ")
|
||||
}
|
||||
|
||||
content := strings.Join(lines, "\n")
|
||||
contentLines := strings.Split(content, "\n")
|
||||
total := len(contentLines)
|
||||
|
||||
// Cache max scroll for clamping in scrollUp()
|
||||
maxScroll := total - height
|
||||
if maxScroll < 0 {
|
||||
maxScroll = 0
|
||||
}
|
||||
v.lastMaxScroll = maxScroll
|
||||
scrollOffset := v.scrollOffset
|
||||
if scrollOffset > maxScroll {
|
||||
scrollOffset = maxScroll
|
||||
}
|
||||
|
||||
if total <= height {
|
||||
// Content fits — pad with empty lines at top to push content down
|
||||
padding := make([]string, height-total)
|
||||
for i := range padding {
|
||||
padding[i] = ""
|
||||
}
|
||||
contentLines = append(padding, contentLines...)
|
||||
} else {
|
||||
// Show a window: end is (total - scrollOffset), start is (end - height)
|
||||
end := total - scrollOffset
|
||||
start := end - height
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
contentLines = contentLines[start:end]
|
||||
}
|
||||
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
264
cli/internal/tui/viewport_test.go
Normal file
264
cli/internal/tui/viewport_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stripANSI removes ANSI escape sequences for test comparisons.
|
||||
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
func stripANSI(s string) string {
|
||||
return ansiRegex.ReplaceAllString(s, "")
|
||||
}
|
||||
|
||||
func TestAddUserMessage(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addUserMessage("hello world")
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryUser {
|
||||
t.Errorf("expected entryUser, got %d", e.kind)
|
||||
}
|
||||
if e.content != "hello world" {
|
||||
t.Errorf("expected content 'hello world', got %q", e.content)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "❯") {
|
||||
t.Errorf("expected rendered to contain ❯, got %q", plain)
|
||||
}
|
||||
if !strings.Contains(plain, "hello world") {
|
||||
t.Errorf("expected rendered to contain message text, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAndFinishAgent(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
|
||||
if !v.streaming {
|
||||
t.Error("expected streaming to be true after startAgent")
|
||||
}
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 spacer entry, got %d", len(v.entries))
|
||||
}
|
||||
if v.entries[0].rendered != "" {
|
||||
t.Errorf("expected empty spacer, got %q", v.entries[0].rendered)
|
||||
}
|
||||
|
||||
v.appendToken("Hello ")
|
||||
v.appendToken("world")
|
||||
|
||||
if v.streamBuf != "Hello world" {
|
||||
t.Errorf("expected streamBuf 'Hello world', got %q", v.streamBuf)
|
||||
}
|
||||
|
||||
v.finishAgent()
|
||||
|
||||
if v.streaming {
|
||||
t.Error("expected streaming to be false after finishAgent")
|
||||
}
|
||||
if v.streamBuf != "" {
|
||||
t.Errorf("expected empty streamBuf after finish, got %q", v.streamBuf)
|
||||
}
|
||||
if len(v.entries) != 2 {
|
||||
t.Fatalf("expected 2 entries (spacer + agent), got %d", len(v.entries))
|
||||
}
|
||||
|
||||
e := v.entries[1]
|
||||
if e.kind != entryAgent {
|
||||
t.Errorf("expected entryAgent, got %d", e.kind)
|
||||
}
|
||||
if e.content != "Hello world" {
|
||||
t.Errorf("expected content 'Hello world', got %q", e.content)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "Hello world") {
|
||||
t.Errorf("expected rendered to contain message text, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishAgentNoPadding(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
v.appendToken("Test message")
|
||||
v.finishAgent()
|
||||
|
||||
e := v.entries[1]
|
||||
// First line should not start with plain spaces (ANSI codes are OK)
|
||||
plain := stripANSI(e.rendered)
|
||||
lines := strings.Split(plain, "\n")
|
||||
if strings.HasPrefix(lines[0], " ") {
|
||||
t.Errorf("first line should not start with spaces, got %q", lines[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishAgentMultiline(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
v.appendToken("Line one\n\nLine three")
|
||||
v.finishAgent()
|
||||
|
||||
e := v.entries[1]
|
||||
plain := stripANSI(e.rendered)
|
||||
// Glamour may merge or reformat lines; just check content is present
|
||||
if !strings.Contains(plain, "Line one") {
|
||||
t.Errorf("expected 'Line one' in rendered, got %q", plain)
|
||||
}
|
||||
if !strings.Contains(plain, "Line three") {
|
||||
t.Errorf("expected 'Line three' in rendered, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishAgentEmpty(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.startAgent()
|
||||
v.finishAgent()
|
||||
|
||||
if v.streaming {
|
||||
t.Error("expected streaming to be false")
|
||||
}
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected 0 entries (spacer removed), got %d", len(v.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddInfo(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addInfo("test info")
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryInfo {
|
||||
t.Errorf("expected entryInfo, got %d", e.kind)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if strings.HasPrefix(plain, " ") {
|
||||
t.Errorf("info should not have leading spaces, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddError(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addError("something broke")
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryError {
|
||||
t.Errorf("expected entryError, got %d", e.kind)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "something broke") {
|
||||
t.Errorf("expected error message in rendered, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCitations(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addCitations(map[int]string{1: "doc-a", 2: "doc-b"})
|
||||
|
||||
if len(v.entries) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(v.entries))
|
||||
}
|
||||
e := v.entries[0]
|
||||
if e.kind != entryCitation {
|
||||
t.Errorf("expected entryCitation, got %d", e.kind)
|
||||
}
|
||||
plain := stripANSI(e.rendered)
|
||||
if !strings.Contains(plain, "Sources (2)") {
|
||||
t.Errorf("expected sources count in rendered, got %q", plain)
|
||||
}
|
||||
if strings.HasPrefix(plain, " ") {
|
||||
t.Errorf("citation should not have leading spaces, got %q", plain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCitationsEmpty(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addCitations(map[int]string{})
|
||||
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected no entries for empty citations, got %d", len(v.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationVisibility(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addInfo("hello")
|
||||
v.addCitations(map[int]string{1: "doc"})
|
||||
|
||||
v.showSources = false
|
||||
view := v.view(20)
|
||||
plain := stripANSI(view)
|
||||
if strings.Contains(plain, "Sources") {
|
||||
t.Error("expected citations hidden when showSources=false")
|
||||
}
|
||||
|
||||
v.showSources = true
|
||||
view = v.view(20)
|
||||
plain = stripANSI(view)
|
||||
if !strings.Contains(plain, "Sources") {
|
||||
t.Error("expected citations visible when showSources=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearAll(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addUserMessage("test")
|
||||
v.startAgent()
|
||||
v.appendToken("response")
|
||||
|
||||
v.clearAll()
|
||||
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected no entries after clearAll, got %d", len(v.entries))
|
||||
}
|
||||
if v.streaming {
|
||||
t.Error("expected streaming=false after clearAll")
|
||||
}
|
||||
if v.streamBuf != "" {
|
||||
t.Errorf("expected empty streamBuf after clearAll, got %q", v.streamBuf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearDisplay(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addUserMessage("test")
|
||||
v.clearDisplay()
|
||||
|
||||
if len(v.entries) != 0 {
|
||||
t.Errorf("expected no entries after clearDisplay, got %d", len(v.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewPadsShortContent(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
v.addInfo("hello")
|
||||
|
||||
view := v.view(10)
|
||||
lines := strings.Split(view, "\n")
|
||||
if len(lines) != 10 {
|
||||
t.Errorf("expected 10 lines (padded), got %d", len(lines))
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewTruncatesTallContent(t *testing.T) {
|
||||
v := newViewport(80)
|
||||
for i := 0; i < 20; i++ {
|
||||
v.addInfo("line")
|
||||
}
|
||||
|
||||
view := v.view(5)
|
||||
lines := strings.Split(view, "\n")
|
||||
if len(lines) != 5 {
|
||||
t.Errorf("expected 5 lines (truncated), got %d", len(lines))
|
||||
}
|
||||
}
|
||||
29
cli/internal/util/browser.go
Normal file
29
cli/internal/util/browser.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Package util provides shared utility functions.
|
||||
package util
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// OpenBrowser opens the given URL in the user's default browser.
|
||||
// Returns true if the browser was launched successfully.
|
||||
func OpenBrowser(url string) bool {
|
||||
var cmd *exec.Cmd
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "linux":
|
||||
cmd = exec.Command("xdg-open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
}
|
||||
if cmd != nil {
|
||||
if err := cmd.Start(); err == nil {
|
||||
// Reap the child process to avoid zombies.
|
||||
go func() { _ = cmd.Wait() }()
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
13
cli/internal/util/styles.go
Normal file
13
cli/internal/util/styles.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Package util provides shared utilities for the Onyx CLI.
|
||||
package util
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
// Shared text styles used across the CLI.
|
||||
var (
|
||||
BoldStyle = lipgloss.NewStyle().Bold(true)
|
||||
DimStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#555577"))
|
||||
GreenStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#00cc66")).Bold(true)
|
||||
RedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ff5555")).Bold(true)
|
||||
YellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ffcc00"))
|
||||
)
|
||||
23
cli/main.go
Normal file
23
cli/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/cmd"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "none"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cmd.Version = version
|
||||
cmd.Commit = commit
|
||||
|
||||
if err := cmd.Execute(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -42,7 +42,7 @@ import SvgStar from "@opal/icons/star";
|
||||
|
||||
## Usage inside Content
|
||||
|
||||
Tag can be rendered as an accessory inside `Content`'s ContentMd via the `tag` prop:
|
||||
Tag can be rendered as an accessory inside `Content`'s LabelLayout via the `tag` prop:
|
||||
|
||||
```tsx
|
||||
import { Content } from "@opal/layouts";
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* Hoverable — item transitions */
|
||||
.hoverable-item {
|
||||
transition: opacity 150ms ease-in-out;
|
||||
transition: opacity 200ms ease-in-out;
|
||||
}
|
||||
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@opal/components/buttons/Button/components";
|
||||
import type { SizeVariant } from "@opal/shared";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type ContentLgSizePreset = "headline" | "section";
|
||||
|
||||
interface ContentLgPresetConfig {
|
||||
/** Icon width/height (CSS value). */
|
||||
iconSize: string;
|
||||
/** Tailwind padding class for the icon container. */
|
||||
iconContainerPadding: string;
|
||||
/** Gap between icon container and content (CSS value). */
|
||||
gap: string;
|
||||
/** Tailwind font class for the title. */
|
||||
titleFont: string;
|
||||
/** Title line-height — also used as icon container min-height (CSS value). */
|
||||
lineHeight: string;
|
||||
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
|
||||
editButtonSize: SizeVariant;
|
||||
/** Tailwind padding class for the edit button container. */
|
||||
editButtonPadding: string;
|
||||
}
|
||||
|
||||
interface ContentLgProps {
|
||||
/** Optional icon component. */
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
|
||||
/** Called when the user commits an edit. */
|
||||
onTitleChange?: (newTitle: string) => void;
|
||||
|
||||
/** Size preset. Default: `"headline"`. */
|
||||
sizePreset?: ContentLgSizePreset;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Presets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CONTENT_LG_PRESETS: Record<ContentLgSizePreset, ContentLgPresetConfig> = {
|
||||
headline: {
|
||||
iconSize: "2rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
gap: "0.25rem",
|
||||
titleFont: "font-heading-h2",
|
||||
lineHeight: "2.25rem",
|
||||
editButtonSize: "md",
|
||||
editButtonPadding: "p-1",
|
||||
},
|
||||
section: {
|
||||
iconSize: "1.25rem",
|
||||
iconContainerPadding: "p-1",
|
||||
gap: "0rem",
|
||||
titleFont: "font-heading-h3-muted",
|
||||
lineHeight: "1.75rem",
|
||||
editButtonSize: "sm",
|
||||
editButtonPadding: "p-0.5",
|
||||
},
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ContentLg
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ContentLg({
|
||||
sizePreset = "headline",
|
||||
icon: Icon,
|
||||
title,
|
||||
description,
|
||||
editable,
|
||||
onTitleChange,
|
||||
}: ContentLgProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_LG_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(title);
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="opal-content-lg" style={{ gap: config.gap }}>
|
||||
{Icon && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-lg-icon-container shrink-0",
|
||||
config.iconContainerPadding
|
||||
)}
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className="opal-content-lg-icon"
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="opal-content-lg-body">
|
||||
<div className="opal-content-lg-title-row">
|
||||
{editing ? (
|
||||
<div className="opal-content-lg-input-sizer">
|
||||
<span
|
||||
className={cn("opal-content-lg-input-mirror", config.titleFont)}
|
||||
>
|
||||
{editValue || "\u00A0"}
|
||||
</span>
|
||||
<input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"opal-content-lg-input",
|
||||
config.titleFont,
|
||||
"text-text-04"
|
||||
)}
|
||||
value={editValue}
|
||||
onChange={(e) => setEditValue(e.target.value)}
|
||||
size={1}
|
||||
autoFocus
|
||||
onFocus={(e) => e.currentTarget.select()}
|
||||
onBlur={commit}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(title);
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
style={{ height: config.lineHeight }}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-content-lg-title",
|
||||
config.titleFont,
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
|
||||
{editable && !editing && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-lg-edit-button",
|
||||
config.editButtonPadding
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
icon={SvgEdit}
|
||||
prominence="internal"
|
||||
size={config.editButtonSize}
|
||||
tooltip="Edit"
|
||||
tooltipSide="right"
|
||||
onClick={startEditing}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
<div className="opal-content-lg-description font-secondary-body text-text-03">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { ContentLg, type ContentLgProps, type ContentLgSizePreset };
|
||||
@@ -1,279 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@opal/components/buttons/Button/components";
|
||||
import { Tag, type TagProps } from "@opal/components/Tag/components";
|
||||
import type { SizeVariant } from "@opal/shared";
|
||||
import SvgAlertCircle from "@opal/icons/alert-circle";
|
||||
import SvgAlertTriangle from "@opal/icons/alert-triangle";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import SvgXOctagon from "@opal/icons/x-octagon";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type ContentMdSizePreset = "main-content" | "main-ui" | "secondary";
|
||||
|
||||
type ContentMdAuxIcon = "info-gray" | "info-blue" | "warning" | "error";
|
||||
|
||||
interface ContentMdPresetConfig {
|
||||
iconSize: string;
|
||||
iconContainerPadding: string;
|
||||
iconColorClass: string;
|
||||
titleFont: string;
|
||||
lineHeight: string;
|
||||
gap: string;
|
||||
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
|
||||
editButtonSize: SizeVariant;
|
||||
editButtonPadding: string;
|
||||
optionalFont: string;
|
||||
/** Aux icon size = lineHeight − 2 × p-0.5. */
|
||||
auxIconSize: string;
|
||||
}
|
||||
|
||||
interface ContentMdProps {
|
||||
/** Optional icon component. */
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
|
||||
/** Optional description text below the title. */
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
|
||||
/** Called when the user commits an edit. */
|
||||
onTitleChange?: (newTitle: string) => void;
|
||||
|
||||
/** When `true`, renders "(Optional)" beside the title. */
|
||||
optional?: boolean;
|
||||
|
||||
/** Auxiliary status icon rendered beside the title. */
|
||||
auxIcon?: ContentMdAuxIcon;
|
||||
|
||||
/** Tag rendered beside the title. */
|
||||
tag?: TagProps;
|
||||
|
||||
/** Size preset. Default: `"main-ui"`. */
|
||||
sizePreset?: ContentMdSizePreset;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Presets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
"main-content": {
|
||||
iconSize: "1rem",
|
||||
iconContainerPadding: "p-1",
|
||||
iconColorClass: "text-text-04",
|
||||
titleFont: "font-main-content-emphasis",
|
||||
lineHeight: "1.5rem",
|
||||
gap: "0.125rem",
|
||||
editButtonSize: "sm",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-content-muted",
|
||||
auxIconSize: "1.25rem",
|
||||
},
|
||||
"main-ui": {
|
||||
iconSize: "1rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
iconColorClass: "text-text-03",
|
||||
titleFont: "font-main-ui-action",
|
||||
lineHeight: "1.25rem",
|
||||
gap: "0.25rem",
|
||||
editButtonSize: "xs",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-ui-muted",
|
||||
auxIconSize: "1rem",
|
||||
},
|
||||
secondary: {
|
||||
iconSize: "0.75rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
iconColorClass: "text-text-04",
|
||||
titleFont: "font-secondary-action",
|
||||
lineHeight: "1rem",
|
||||
gap: "0.125rem",
|
||||
editButtonSize: "2xs",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-secondary-action",
|
||||
auxIconSize: "0.75rem",
|
||||
},
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ContentMd
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const AUX_ICON_CONFIG: Record<
|
||||
ContentMdAuxIcon,
|
||||
{ icon: IconFunctionComponent; colorClass: string }
|
||||
> = {
|
||||
"info-gray": { icon: SvgAlertCircle, colorClass: "text-text-02" },
|
||||
"info-blue": { icon: SvgAlertCircle, colorClass: "text-status-info-05" },
|
||||
warning: { icon: SvgAlertTriangle, colorClass: "text-status-warning-05" },
|
||||
error: { icon: SvgXOctagon, colorClass: "text-status-error-05" },
|
||||
};
|
||||
|
||||
function ContentMd({
|
||||
icon: Icon,
|
||||
title,
|
||||
description,
|
||||
editable,
|
||||
onTitleChange,
|
||||
optional,
|
||||
auxIcon,
|
||||
tag,
|
||||
sizePreset = "main-ui",
|
||||
}: ContentMdProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_MD_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(title);
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="opal-content-md" style={{ gap: config.gap }}>
|
||||
{Icon && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-md-icon-container shrink-0",
|
||||
config.iconContainerPadding
|
||||
)}
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className={cn("opal-content-md-icon", config.iconColorClass)}
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="opal-content-md-body">
|
||||
<div className="opal-content-md-title-row">
|
||||
{editing ? (
|
||||
<div className="opal-content-md-input-sizer">
|
||||
<span
|
||||
className={cn("opal-content-md-input-mirror", config.titleFont)}
|
||||
>
|
||||
{editValue || "\u00A0"}
|
||||
</span>
|
||||
<input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"opal-content-md-input",
|
||||
config.titleFont,
|
||||
"text-text-04"
|
||||
)}
|
||||
value={editValue}
|
||||
onChange={(e) => setEditValue(e.target.value)}
|
||||
size={1}
|
||||
autoFocus
|
||||
onFocus={(e) => e.currentTarget.select()}
|
||||
onBlur={commit}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(title);
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
style={{ height: config.lineHeight }}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-content-md-title",
|
||||
config.titleFont,
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
|
||||
{optional && (
|
||||
<span
|
||||
className={cn(config.optionalFont, "text-text-03 shrink-0")}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
(Optional)
|
||||
</span>
|
||||
)}
|
||||
|
||||
{auxIcon &&
|
||||
(() => {
|
||||
const { icon: AuxIcon, colorClass } = AUX_ICON_CONFIG[auxIcon];
|
||||
return (
|
||||
<div
|
||||
className="opal-content-md-aux-icon shrink-0 p-0.5"
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
<AuxIcon
|
||||
className={colorClass}
|
||||
style={{
|
||||
width: config.auxIconSize,
|
||||
height: config.auxIconSize,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
{tag && <Tag {...tag} />}
|
||||
|
||||
{editable && !editing && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-md-edit-button",
|
||||
config.editButtonPadding
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
icon={SvgEdit}
|
||||
prominence="internal"
|
||||
size={config.editButtonSize}
|
||||
tooltip="Edit"
|
||||
tooltipSide="right"
|
||||
onClick={startEditing}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
<div className="opal-content-md-description font-secondary-body text-text-03">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export {
|
||||
ContentMd,
|
||||
type ContentMdProps,
|
||||
type ContentMdSizePreset,
|
||||
type ContentMdAuxIcon,
|
||||
};
|
||||
@@ -1,129 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type ContentSmSizePreset = "main-content" | "main-ui" | "secondary";
|
||||
type ContentSmOrientation = "vertical" | "inline" | "reverse";
|
||||
type ContentSmProminence = "default" | "muted";
|
||||
|
||||
interface ContentSmPresetConfig {
|
||||
/** Icon width/height (CSS value). */
|
||||
iconSize: string;
|
||||
/** Tailwind padding class for the icon container. */
|
||||
iconContainerPadding: string;
|
||||
/** Tailwind font class for the title. */
|
||||
titleFont: string;
|
||||
/** Title line-height — also used as icon container min-height (CSS value). */
|
||||
lineHeight: string;
|
||||
/** Gap between icon container and title (CSS value). */
|
||||
gap: string;
|
||||
}
|
||||
|
||||
/** Props for {@link ContentSm}. Does not support editing or descriptions. */
|
||||
interface ContentSmProps {
|
||||
/** Optional icon component. */
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text (read-only — editing is not supported). */
|
||||
title: string;
|
||||
|
||||
/** Size preset. Default: `"main-ui"`. */
|
||||
sizePreset?: ContentSmSizePreset;
|
||||
|
||||
/** Layout orientation. Default: `"inline"`. */
|
||||
orientation?: ContentSmOrientation;
|
||||
|
||||
/** Title prominence. Default: `"default"`. */
|
||||
prominence?: ContentSmProminence;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Presets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CONTENT_SM_PRESETS: Record<ContentSmSizePreset, ContentSmPresetConfig> = {
|
||||
"main-content": {
|
||||
iconSize: "1rem",
|
||||
iconContainerPadding: "p-1",
|
||||
titleFont: "font-main-content-body",
|
||||
lineHeight: "1.5rem",
|
||||
gap: "0.125rem",
|
||||
},
|
||||
"main-ui": {
|
||||
iconSize: "1rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
titleFont: "font-main-ui-action",
|
||||
lineHeight: "1.25rem",
|
||||
gap: "0.25rem",
|
||||
},
|
||||
secondary: {
|
||||
iconSize: "0.75rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
titleFont: "font-secondary-action",
|
||||
lineHeight: "1rem",
|
||||
gap: "0.125rem",
|
||||
},
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ContentSm
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ContentSm({
|
||||
icon: Icon,
|
||||
title,
|
||||
sizePreset = "main-ui",
|
||||
orientation = "inline",
|
||||
prominence = "default",
|
||||
}: ContentSmProps) {
|
||||
const config = CONTENT_SM_PRESETS[sizePreset];
|
||||
const titleColorClass =
|
||||
prominence === "muted" ? "text-text-03" : "text-text-04";
|
||||
|
||||
return (
|
||||
<div
|
||||
className="opal-content-sm"
|
||||
data-orientation={orientation}
|
||||
style={{ gap: config.gap }}
|
||||
>
|
||||
{Icon && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-sm-icon-container shrink-0",
|
||||
config.iconContainerPadding
|
||||
)}
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className="opal-content-sm-icon text-text-03"
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<span
|
||||
className={cn(
|
||||
"opal-content-sm-title",
|
||||
config.titleFont,
|
||||
titleColorClass
|
||||
)}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export {
|
||||
ContentSm,
|
||||
type ContentSmProps,
|
||||
type ContentSmSizePreset,
|
||||
type ContentSmOrientation,
|
||||
type ContentSmProminence,
|
||||
};
|
||||
@@ -1,258 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@opal/components/buttons/Button/components";
|
||||
import type { SizeVariant } from "@opal/shared";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type ContentXlSizePreset = "headline" | "section";
|
||||
|
||||
interface ContentXlPresetConfig {
|
||||
/** Icon width/height (CSS value). */
|
||||
iconSize: string;
|
||||
/** Tailwind padding class for the icon container. */
|
||||
iconContainerPadding: string;
|
||||
/** More-icon-1 width/height (CSS value). */
|
||||
moreIcon1Size: string;
|
||||
/** Tailwind padding class for the more-icon-1 container. */
|
||||
moreIcon1ContainerPadding: string;
|
||||
/** More-icon-2 width/height (CSS value). */
|
||||
moreIcon2Size: string;
|
||||
/** Tailwind padding class for the more-icon-2 container. */
|
||||
moreIcon2ContainerPadding: string;
|
||||
/** Tailwind font class for the title. */
|
||||
titleFont: string;
|
||||
/** Title line-height — also used as icon container min-height (CSS value). */
|
||||
lineHeight: string;
|
||||
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
|
||||
editButtonSize: SizeVariant;
|
||||
/** Tailwind padding class for the edit button container. */
|
||||
editButtonPadding: string;
|
||||
}
|
||||
|
||||
interface ContentXlProps {
|
||||
/** Optional icon component. */
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
|
||||
/** Called when the user commits an edit. */
|
||||
onTitleChange?: (newTitle: string) => void;
|
||||
|
||||
/** Size preset. Default: `"headline"`. */
|
||||
sizePreset?: ContentXlSizePreset;
|
||||
|
||||
/** Optional secondary icon rendered in the icon row. */
|
||||
moreIcon1?: IconFunctionComponent;
|
||||
|
||||
/** Optional tertiary icon rendered in the icon row. */
|
||||
moreIcon2?: IconFunctionComponent;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Presets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const CONTENT_XL_PRESETS: Record<ContentXlSizePreset, ContentXlPresetConfig> = {
|
||||
headline: {
|
||||
iconSize: "2rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
moreIcon1Size: "1rem",
|
||||
moreIcon1ContainerPadding: "p-0.5",
|
||||
moreIcon2Size: "2rem",
|
||||
moreIcon2ContainerPadding: "p-0.5",
|
||||
titleFont: "font-heading-h2",
|
||||
lineHeight: "2.25rem",
|
||||
editButtonSize: "md",
|
||||
editButtonPadding: "p-1",
|
||||
},
|
||||
section: {
|
||||
iconSize: "1.5rem",
|
||||
iconContainerPadding: "p-0.5",
|
||||
moreIcon1Size: "0.75rem",
|
||||
moreIcon1ContainerPadding: "p-0.5",
|
||||
moreIcon2Size: "1.5rem",
|
||||
moreIcon2ContainerPadding: "p-0.5",
|
||||
titleFont: "font-heading-h3",
|
||||
lineHeight: "1.75rem",
|
||||
editButtonSize: "sm",
|
||||
editButtonPadding: "p-0.5",
|
||||
},
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ContentXl
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ContentXl({
|
||||
sizePreset = "headline",
|
||||
icon: Icon,
|
||||
title,
|
||||
description,
|
||||
editable,
|
||||
onTitleChange,
|
||||
moreIcon1: MoreIcon1,
|
||||
moreIcon2: MoreIcon2,
|
||||
}: ContentXlProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_XL_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(title);
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="opal-content-xl">
|
||||
{(Icon || MoreIcon1 || MoreIcon2) && (
|
||||
<div className="opal-content-xl-icon-row">
|
||||
{Icon && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-xl-icon-container shrink-0",
|
||||
config.iconContainerPadding
|
||||
)}
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className="opal-content-xl-icon"
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{MoreIcon1 && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-xl-more-icon-container shrink-0",
|
||||
config.moreIcon1ContainerPadding
|
||||
)}
|
||||
>
|
||||
<MoreIcon1
|
||||
className="opal-content-xl-icon"
|
||||
style={{
|
||||
width: config.moreIcon1Size,
|
||||
height: config.moreIcon1Size,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{MoreIcon2 && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-xl-more-icon-container shrink-0",
|
||||
config.moreIcon2ContainerPadding
|
||||
)}
|
||||
>
|
||||
<MoreIcon2
|
||||
className="opal-content-xl-icon"
|
||||
style={{
|
||||
width: config.moreIcon2Size,
|
||||
height: config.moreIcon2Size,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="opal-content-xl-body">
|
||||
<div className="opal-content-xl-title-row">
|
||||
{editing ? (
|
||||
<div className="opal-content-xl-input-sizer">
|
||||
<span
|
||||
className={cn("opal-content-xl-input-mirror", config.titleFont)}
|
||||
>
|
||||
{editValue || "\u00A0"}
|
||||
</span>
|
||||
<input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"opal-content-xl-input",
|
||||
config.titleFont,
|
||||
"text-text-04"
|
||||
)}
|
||||
value={editValue}
|
||||
onChange={(e) => setEditValue(e.target.value)}
|
||||
size={1}
|
||||
autoFocus
|
||||
onFocus={(e) => e.currentTarget.select()}
|
||||
onBlur={commit}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(title);
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
style={{ height: config.lineHeight }}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-content-xl-title",
|
||||
config.titleFont,
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
|
||||
{editable && !editing && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-xl-edit-button",
|
||||
config.editButtonPadding
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
icon={SvgEdit}
|
||||
prominence="internal"
|
||||
size={config.editButtonSize}
|
||||
tooltip="Edit"
|
||||
tooltipSide="right"
|
||||
onClick={startEditing}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
<div className="opal-content-xl-description font-secondary-body text-text-03">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { ContentXl, type ContentXlProps, type ContentXlSizePreset };
|
||||
@@ -8,21 +8,14 @@ A two-axis layout component for displaying icon + title + description rows. Rout
|
||||
|
||||
### `sizePreset` — controls sizing (icon, padding, gap, font)
|
||||
|
||||
#### ContentXl presets (variant="heading")
|
||||
|
||||
| Preset | Icon | Icon padding | moreIcon1 | mI1 padding | moreIcon2 | mI2 padding | Title font | Line-height |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| `headline` | 2rem (32px) | `p-0.5` (2px) | 1rem (16px) | `p-0.5` (2px) | 2rem (32px) | `p-0.5` (2px) | `font-heading-h2` | 2.25rem (36px) |
|
||||
| `section` | 1.5rem (24px) | `p-0.5` (2px) | 0.75rem (12px) | `p-0.5` (2px) | 1.5rem (24px) | `p-0.5` (2px) | `font-heading-h3` | 1.75rem (28px) |
|
||||
|
||||
#### ContentLg presets (variant="section")
|
||||
#### HeadingLayout presets
|
||||
|
||||
| Preset | Icon | Icon padding | Gap | Title font | Line-height |
|
||||
|---|---|---|---|---|---|
|
||||
| `headline` | 2rem (32px) | `p-0.5` (2px) | 0.25rem (4px) | `font-heading-h2` | 2.25rem (36px) |
|
||||
| `section` | 1.25rem (20px) | `p-1` (4px) | 0rem | `font-heading-h3-muted` | 1.75rem (28px) |
|
||||
| `section` | 1.25rem (20px) | `p-1` (4px) | 0rem | `font-heading-h3` | 1.75rem (28px) |
|
||||
|
||||
#### ContentMd presets
|
||||
#### LabelLayout presets
|
||||
|
||||
| Preset | Icon | Icon padding | Icon color | Gap | Title font | Line-height |
|
||||
|---|---|---|---|---|---|---|
|
||||
@@ -36,18 +29,18 @@ A two-axis layout component for displaying icon + title + description rows. Rout
|
||||
|
||||
| variant | Description |
|
||||
|---|---|
|
||||
| `heading` | Icon on **top** (flex-col) — ContentXl |
|
||||
| `section` | Icon **inline** (flex-row) — ContentLg or ContentMd |
|
||||
| `body` | Body text layout — ContentSm |
|
||||
| `heading` | Icon on **top** (flex-col) — HeadingLayout |
|
||||
| `section` | Icon **inline** (flex-row) — HeadingLayout or LabelLayout |
|
||||
| `body` | Body text layout — BodyLayout (future) |
|
||||
|
||||
### Valid Combinations -> Internal Routing
|
||||
|
||||
| sizePreset | variant | Routes to |
|
||||
|---|---|---|
|
||||
| `headline` / `section` | `heading` | **ContentXl** (icon on top) |
|
||||
| `headline` / `section` | `section` | **ContentLg** (icon inline) |
|
||||
| `main-content` / `main-ui` / `secondary` | `section` | **ContentMd** |
|
||||
| `main-content` / `main-ui` / `secondary` | `body` | **ContentSm** |
|
||||
| `headline` / `section` | `heading` | **HeadingLayout** (icon on top) |
|
||||
| `headline` / `section` | `section` | **HeadingLayout** (icon inline) |
|
||||
| `main-content` / `main-ui` / `secondary` | `section` | **LabelLayout** |
|
||||
| `main-content` / `main-ui` / `secondary` | `body` | BodyLayout (future) |
|
||||
|
||||
Invalid combinations (e.g. `sizePreset="headline" + variant="body"`) are excluded at the type level.
|
||||
|
||||
@@ -62,20 +55,14 @@ Invalid combinations (e.g. `sizePreset="headline" + variant="body"`) are exclude
|
||||
| `description` | `string` | — | Optional description below the title |
|
||||
| `editable` | `boolean` | `false` | Enable inline editing of the title |
|
||||
| `onTitleChange` | `(newTitle: string) => void` | — | Called when user commits an edit |
|
||||
| `moreIcon1` | `IconFunctionComponent` | — | Secondary icon in icon row (ContentXl only) |
|
||||
| `moreIcon2` | `IconFunctionComponent` | — | Tertiary icon in icon row (ContentXl only) |
|
||||
|
||||
## Internal Layouts
|
||||
|
||||
### ContentXl
|
||||
### HeadingLayout
|
||||
|
||||
For `headline` / `section` presets with `variant="heading"`. Icon row on top (flex-col), supports `moreIcon1` and `moreIcon2` in the icon row. Description is always `font-secondary-body text-text-03`.
|
||||
For `headline` / `section` presets. Supports `variant="heading"` (icon on top) and `variant="section"` (icon inline). Description is always `font-secondary-body text-text-03`.
|
||||
|
||||
### ContentLg
|
||||
|
||||
For `headline` / `section` presets with `variant="section"`. Always inline (flex-row). Description is always `font-secondary-body text-text-03`.
|
||||
|
||||
### ContentMd
|
||||
### LabelLayout
|
||||
|
||||
For `main-content` / `main-ui` / `secondary` presets. Always inline. Both `icon` and `description` are optional. Description is always `font-secondary-body text-text-03`.
|
||||
|
||||
@@ -85,7 +72,7 @@ For `main-content` / `main-ui` / `secondary` presets. Always inline. Both `icon`
|
||||
import { Content } from "@opal/layouts";
|
||||
import SvgSearch from "@opal/icons/search";
|
||||
|
||||
// ContentXl — headline, icon on top
|
||||
// HeadingLayout — headline, icon on top
|
||||
<Content
|
||||
icon={SvgSearch}
|
||||
sizePreset="headline"
|
||||
@@ -94,17 +81,7 @@ import SvgSearch from "@opal/icons/search";
|
||||
description="Configure your agent's behavior"
|
||||
/>
|
||||
|
||||
// ContentXl — with more icons
|
||||
<Content
|
||||
icon={SvgSearch}
|
||||
sizePreset="headline"
|
||||
variant="heading"
|
||||
title="Agent Settings"
|
||||
moreIcon1={SvgStar}
|
||||
moreIcon2={SvgLock}
|
||||
/>
|
||||
|
||||
// ContentLg — section, icon inline
|
||||
// HeadingLayout — section, icon inline
|
||||
<Content
|
||||
icon={SvgSearch}
|
||||
sizePreset="section"
|
||||
@@ -113,7 +90,7 @@ import SvgSearch from "@opal/icons/search";
|
||||
description="Connected integrations"
|
||||
/>
|
||||
|
||||
// ContentMd — with icon and description
|
||||
// LabelLayout — with icon and description
|
||||
<Content
|
||||
icon={SvgSearch}
|
||||
sizePreset="main-ui"
|
||||
@@ -121,7 +98,7 @@ import SvgSearch from "@opal/icons/search";
|
||||
description="Agent system prompt"
|
||||
/>
|
||||
|
||||
// ContentMd — title only (no icon, no description)
|
||||
// LabelLayout — title only (no icon, no description)
|
||||
<Content
|
||||
sizePreset="main-content"
|
||||
title="Featured Agent"
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
import "@opal/layouts/Content/styles.css";
|
||||
import {
|
||||
ContentSm,
|
||||
type ContentSmOrientation,
|
||||
type ContentSmProminence,
|
||||
} from "@opal/layouts/Content/ContentSm";
|
||||
BodyLayout,
|
||||
type BodyOrientation,
|
||||
type BodyProminence,
|
||||
} from "@opal/layouts/Content/BodyLayout";
|
||||
import {
|
||||
ContentXl,
|
||||
type ContentXlProps,
|
||||
} from "@opal/layouts/Content/ContentXl";
|
||||
HeadingLayout,
|
||||
type HeadingLayoutProps,
|
||||
} from "@opal/layouts/Content/HeadingLayout";
|
||||
import {
|
||||
ContentLg,
|
||||
type ContentLgProps,
|
||||
} from "@opal/layouts/Content/ContentLg";
|
||||
import {
|
||||
ContentMd,
|
||||
type ContentMdProps,
|
||||
} from "@opal/layouts/Content/ContentMd";
|
||||
LabelLayout,
|
||||
type LabelLayoutProps,
|
||||
} from "@opal/layouts/Content/LabelLayout";
|
||||
import type { TagProps } from "@opal/components/Tag/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { widthVariants, type WidthVariant } from "@opal/shared";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared types
|
||||
@@ -65,25 +62,14 @@ interface ContentBaseProps {
|
||||
// Discriminated union: valid sizePreset × variant combinations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type XlContentProps = ContentBaseProps & {
|
||||
type HeadingContentProps = ContentBaseProps & {
|
||||
/** Size preset. Default: `"headline"`. */
|
||||
sizePreset?: "headline" | "section";
|
||||
/** Variant. Default: `"heading"` for heading-eligible presets. */
|
||||
variant?: "heading";
|
||||
/** Optional secondary icon rendered in the icon row (ContentXl only). */
|
||||
moreIcon1?: IconFunctionComponent;
|
||||
/** Optional tertiary icon rendered in the icon row (ContentXl only). */
|
||||
moreIcon2?: IconFunctionComponent;
|
||||
variant?: "heading" | "section";
|
||||
};
|
||||
|
||||
type LgContentProps = ContentBaseProps & {
|
||||
/** Size preset. Default: `"headline"`. */
|
||||
sizePreset?: "headline" | "section";
|
||||
/** Variant. */
|
||||
variant: "section";
|
||||
};
|
||||
|
||||
type MdContentProps = ContentBaseProps & {
|
||||
type LabelContentProps = ContentBaseProps & {
|
||||
sizePreset: "main-content" | "main-ui" | "secondary";
|
||||
variant?: "section";
|
||||
/** When `true`, renders "(Optional)" beside the title in the muted font variant. */
|
||||
@@ -94,24 +80,20 @@ type MdContentProps = ContentBaseProps & {
|
||||
tag?: TagProps;
|
||||
};
|
||||
|
||||
/** ContentSm does not support descriptions or inline editing. */
|
||||
type SmContentProps = Omit<
|
||||
/** BodyLayout does not support descriptions or inline editing. */
|
||||
type BodyContentProps = Omit<
|
||||
ContentBaseProps,
|
||||
"description" | "editable" | "onTitleChange"
|
||||
> & {
|
||||
sizePreset: "main-content" | "main-ui" | "secondary";
|
||||
variant: "body";
|
||||
/** Layout orientation. Default: `"inline"`. */
|
||||
orientation?: ContentSmOrientation;
|
||||
orientation?: BodyOrientation;
|
||||
/** Title prominence. Default: `"default"`. */
|
||||
prominence?: ContentSmProminence;
|
||||
prominence?: BodyProminence;
|
||||
};
|
||||
|
||||
type ContentProps =
|
||||
| XlContentProps
|
||||
| LgContentProps
|
||||
| MdContentProps
|
||||
| SmContentProps;
|
||||
type ContentProps = HeadingContentProps | LabelContentProps | BodyContentProps;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Content — routes to the appropriate internal layout
|
||||
@@ -129,42 +111,34 @@ function Content(props: ContentProps) {
|
||||
|
||||
let layout: React.ReactNode = null;
|
||||
|
||||
// ContentXl / ContentLg: headline/section presets
|
||||
// Heading layout: headline/section presets with heading/section variant
|
||||
if (sizePreset === "headline" || sizePreset === "section") {
|
||||
if (variant === "heading") {
|
||||
layout = (
|
||||
<ContentXl
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<ContentXlProps, "sizePreset">)}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
layout = (
|
||||
<ContentLg
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<ContentLgProps, "sizePreset">)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ContentMd: main-content/main-ui/secondary with section variant
|
||||
else if (variant === "section" || variant === "heading") {
|
||||
layout = (
|
||||
<ContentMd
|
||||
<HeadingLayout
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<ContentMdProps, "sizePreset">)}
|
||||
variant={variant as HeadingLayoutProps["variant"]}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// ContentSm: main-content/main-ui/secondary with body variant
|
||||
// Label layout: main-content/main-ui/secondary with section variant
|
||||
else if (variant === "section" || variant === "heading") {
|
||||
layout = (
|
||||
<LabelLayout
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<LabelLayoutProps, "sizePreset">)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Body layout: main-content/main-ui/secondary with body variant
|
||||
else if (variant === "body") {
|
||||
layout = (
|
||||
<ContentSm
|
||||
<BodyLayout
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<
|
||||
React.ComponentProps<typeof ContentSm>,
|
||||
React.ComponentProps<typeof BodyLayout>,
|
||||
"sizePreset"
|
||||
>)}
|
||||
/>
|
||||
@@ -193,8 +167,7 @@ export {
|
||||
type ContentProps,
|
||||
type SizePreset,
|
||||
type ContentVariant,
|
||||
type XlContentProps,
|
||||
type LgContentProps,
|
||||
type MdContentProps,
|
||||
type SmContentProps,
|
||||
type HeadingContentProps,
|
||||
type LabelContentProps,
|
||||
type BodyContentProps,
|
||||
};
|
||||
|
||||
@@ -1,145 +1,41 @@
|
||||
/* ===========================================================================
|
||||
Content — ContentXl
|
||||
/* ---------------------------------------------------------------------------
|
||||
Content — HeadingLayout
|
||||
|
||||
Icon row on top (flex-col). Icon row contains main icon + optional
|
||||
moreIcon1 / moreIcon2 in a flex-row.
|
||||
Two icon placement modes (driven by variant):
|
||||
left (variant="section") : flex-row — icon beside content
|
||||
top (variant="heading") : flex-col — icon above content
|
||||
|
||||
Sizing (icon size, gap, padding, font, line-height) is driven by the
|
||||
sizePreset prop via inline styles + Tailwind classes in the component.
|
||||
=========================================================================== */
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Layout — flex-col (icon row above body)
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl {
|
||||
@apply flex flex-col items-start;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Icon row — flex-row containing main icon + more icons
|
||||
Layout — icon placement
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl-icon-row {
|
||||
@apply flex flex-row items-center;
|
||||
.opal-content-heading {
|
||||
@apply flex items-start;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Icons
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl-icon-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
.opal-content-heading[data-icon-placement="left"] {
|
||||
@apply flex-row;
|
||||
}
|
||||
|
||||
.opal-content-xl-more-icon-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.opal-content-xl-icon {
|
||||
color: var(--text-04);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Body column
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl-body {
|
||||
@apply flex flex-1 flex-col items-start;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Title row — title (or input) + edit button
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl-title-row {
|
||||
@apply flex items-center w-full;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.opal-content-xl-title {
|
||||
@apply text-left overflow-hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-box-orient: vertical;
|
||||
-webkit-line-clamp: 1;
|
||||
padding: 0 0.125rem;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
.opal-content-xl-input-sizer {
|
||||
display: inline-grid;
|
||||
align-items: stretch;
|
||||
}
|
||||
|
||||
.opal-content-xl-input-sizer > * {
|
||||
grid-area: 1 / 1;
|
||||
padding: 0 0.125rem;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
.opal-content-xl-input-mirror {
|
||||
visibility: hidden;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.opal-content-xl-input {
|
||||
@apply bg-transparent outline-none border-none;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Edit button — visible only on hover of the outer container
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl-edit-button {
|
||||
@apply opacity-0 transition-opacity shrink-0;
|
||||
}
|
||||
|
||||
.opal-content-xl:hover .opal-content-xl-edit-button {
|
||||
@apply opacity-100;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Description
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-xl-description {
|
||||
@apply text-left w-full;
|
||||
padding: 0 0.125rem;
|
||||
}
|
||||
|
||||
/* ===========================================================================
|
||||
Content — ContentLg
|
||||
|
||||
Always inline (flex-row) — icon beside content.
|
||||
|
||||
Sizing (icon size, gap, padding, font, line-height) is driven by the
|
||||
sizePreset prop via inline styles + Tailwind classes in the component.
|
||||
=========================================================================== */
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Layout
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-lg {
|
||||
@apply flex flex-row items-start;
|
||||
.opal-content-heading[data-icon-placement="top"] {
|
||||
@apply flex-col;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Icon
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-lg-icon-container {
|
||||
.opal-content-heading-icon-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.opal-content-lg-icon {
|
||||
.opal-content-heading-icon {
|
||||
color: var(--text-04);
|
||||
}
|
||||
|
||||
@@ -147,7 +43,7 @@
|
||||
Body column
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-lg-body {
|
||||
.opal-content-heading-body {
|
||||
@apply flex flex-1 flex-col items-start;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
@@ -156,12 +52,12 @@
|
||||
Title row — title (or input) + edit button
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-lg-title-row {
|
||||
.opal-content-heading-title-row {
|
||||
@apply flex items-center w-full;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.opal-content-lg-title {
|
||||
.opal-content-heading-title {
|
||||
@apply text-left overflow-hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-box-orient: vertical;
|
||||
@@ -170,23 +66,23 @@
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
.opal-content-lg-input-sizer {
|
||||
.opal-content-heading-input-sizer {
|
||||
display: inline-grid;
|
||||
align-items: stretch;
|
||||
}
|
||||
|
||||
.opal-content-lg-input-sizer > * {
|
||||
.opal-content-heading-input-sizer > * {
|
||||
grid-area: 1 / 1;
|
||||
padding: 0 0.125rem;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
.opal-content-lg-input-mirror {
|
||||
.opal-content-heading-input-mirror {
|
||||
visibility: hidden;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.opal-content-lg-input {
|
||||
.opal-content-heading-input {
|
||||
@apply bg-transparent outline-none border-none;
|
||||
}
|
||||
|
||||
@@ -194,11 +90,11 @@
|
||||
Edit button — visible only on hover of the outer container
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-lg-edit-button {
|
||||
.opal-content-heading-edit-button {
|
||||
@apply opacity-0 transition-opacity shrink-0;
|
||||
}
|
||||
|
||||
.opal-content-lg:hover .opal-content-lg-edit-button {
|
||||
.opal-content-heading:hover .opal-content-heading-edit-button {
|
||||
@apply opacity-100;
|
||||
}
|
||||
|
||||
@@ -206,13 +102,13 @@
|
||||
Description
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-lg-description {
|
||||
.opal-content-heading-description {
|
||||
@apply text-left w-full;
|
||||
padding: 0 0.125rem;
|
||||
}
|
||||
|
||||
/* ===========================================================================
|
||||
Content — ContentMd
|
||||
Content — LabelLayout
|
||||
|
||||
Always inline (flex-row). Icon color varies per sizePreset and is applied
|
||||
via Tailwind class from the component.
|
||||
@@ -222,7 +118,7 @@
|
||||
Layout
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md {
|
||||
.opal-content-label {
|
||||
@apply flex flex-row items-start;
|
||||
}
|
||||
|
||||
@@ -230,7 +126,7 @@
|
||||
Icon
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-icon-container {
|
||||
.opal-content-label-icon-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
@@ -240,7 +136,7 @@
|
||||
Body column
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-body {
|
||||
.opal-content-label-body {
|
||||
@apply flex flex-1 flex-col items-start;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
@@ -249,12 +145,12 @@
|
||||
Title row — title (or input) + edit button
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-title-row {
|
||||
.opal-content-label-title-row {
|
||||
@apply flex items-center w-full;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.opal-content-md-title {
|
||||
.opal-content-label-title {
|
||||
@apply text-left overflow-hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-box-orient: vertical;
|
||||
@@ -263,23 +159,23 @@
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
.opal-content-md-input-sizer {
|
||||
.opal-content-label-input-sizer {
|
||||
display: inline-grid;
|
||||
align-items: stretch;
|
||||
}
|
||||
|
||||
.opal-content-md-input-sizer > * {
|
||||
.opal-content-label-input-sizer > * {
|
||||
grid-area: 1 / 1;
|
||||
padding: 0 0.125rem;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
.opal-content-md-input-mirror {
|
||||
.opal-content-label-input-mirror {
|
||||
visibility: hidden;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.opal-content-md-input {
|
||||
.opal-content-label-input {
|
||||
@apply bg-transparent outline-none border-none;
|
||||
}
|
||||
|
||||
@@ -287,7 +183,7 @@
|
||||
Aux icon
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-aux-icon {
|
||||
.opal-content-label-aux-icon {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
@@ -297,11 +193,11 @@
|
||||
Edit button — visible only on hover of the outer container
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-edit-button {
|
||||
.opal-content-label-edit-button {
|
||||
@apply opacity-0 transition-opacity shrink-0;
|
||||
}
|
||||
|
||||
.opal-content-md:hover .opal-content-md-edit-button {
|
||||
.opal-content-label:hover .opal-content-label-edit-button {
|
||||
@apply opacity-100;
|
||||
}
|
||||
|
||||
@@ -309,13 +205,13 @@
|
||||
Description
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-description {
|
||||
.opal-content-label-description {
|
||||
@apply text-left w-full;
|
||||
padding: 0 0.125rem;
|
||||
}
|
||||
|
||||
/* ===========================================================================
|
||||
Content — ContentSm
|
||||
Content — BodyLayout
|
||||
|
||||
Three orientation modes (driven by orientation prop):
|
||||
inline : flex-row — icon left, title right
|
||||
@@ -330,19 +226,19 @@
|
||||
Layout — orientation
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-sm {
|
||||
.opal-content-body {
|
||||
@apply flex items-start;
|
||||
}
|
||||
|
||||
.opal-content-sm[data-orientation="inline"] {
|
||||
.opal-content-body[data-orientation="inline"] {
|
||||
@apply flex-row;
|
||||
}
|
||||
|
||||
.opal-content-sm[data-orientation="vertical"] {
|
||||
.opal-content-body[data-orientation="vertical"] {
|
||||
@apply flex-col;
|
||||
}
|
||||
|
||||
.opal-content-sm[data-orientation="reverse"] {
|
||||
.opal-content-body[data-orientation="reverse"] {
|
||||
@apply flex-row-reverse;
|
||||
}
|
||||
|
||||
@@ -350,7 +246,7 @@
|
||||
Icon
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-sm-icon-container {
|
||||
.opal-content-body-icon-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
@@ -360,7 +256,7 @@
|
||||
Title
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-sm-title {
|
||||
.opal-content-body-title {
|
||||
@apply text-left overflow-hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-box-orient: vertical;
|
||||
|
||||
@@ -8,7 +8,7 @@ Layout primitives for composing icon + title + description rows. These component
|
||||
|
||||
| Component | Description | Docs |
|
||||
|---|---|---|
|
||||
| [`Content`](./Content/README.md) | Icon + title + description row. Routes to an internal layout (`ContentLg`, `ContentMd`, or `ContentSm`) based on `sizePreset` and `variant`. | [Content README](./Content/README.md) |
|
||||
| [`Content`](./Content/README.md) | Icon + title + description row. Routes to an internal layout (`HeadingLayout`, `LabelLayout`, or `BodyLayout`) based on `sizePreset` and `variant`. | [Content README](./Content/README.md) |
|
||||
| [`ContentAction`](./ContentAction/README.md) | Wraps `Content` in a flex-row with an optional `rightChildren` slot for action buttons. Adds padding alignment via the shared `SizeVariant` scale. | [ContentAction README](./ContentAction/README.md) |
|
||||
|
||||
## Quick Start
|
||||
@@ -88,7 +88,6 @@ These are not exported — `Content` routes to them automatically:
|
||||
|
||||
| Layout | Used when | File |
|
||||
|---|---|---|
|
||||
| `ContentXl` | `sizePreset` is `headline` or `section` with `variant="heading"` | `Content/ContentXl.tsx` |
|
||||
| `ContentLg` | `sizePreset` is `headline` or `section` with `variant="section"` | `Content/ContentLg.tsx` |
|
||||
| `ContentMd` | `sizePreset` is `main-content`, `main-ui`, or `secondary` with `variant="section"` | `Content/ContentMd.tsx` |
|
||||
| `ContentSm` | `variant="body"` | `Content/ContentSm.tsx` |
|
||||
| `HeadingLayout` | `sizePreset` is `headline` or `section` | `Content/HeadingLayout.tsx` |
|
||||
| `LabelLayout` | `sizePreset` is `main-content`, `main-ui`, or `secondary` with `variant="section"` | `Content/LabelLayout.tsx` |
|
||||
| `BodyLayout` | `variant="body"` | `Content/BodyLayout.tsx` |
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
// - Interactive.Container (height + min-width + padding)
|
||||
// - Button (icon sizing)
|
||||
// - ContentAction (padding only)
|
||||
// - Content (ContentLg / ContentMd) (edit-button size)
|
||||
// - Content (HeadingLayout / LabelLayout) (edit-button size)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
|
||||
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
|
||||
import { getImageGenForm } from "./forms";
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
|
||||
import {
|
||||
IMAGE_PROVIDER_GROUPS,
|
||||
ImageProvider,
|
||||
@@ -23,14 +23,13 @@ import Message from "@/refresh-components/messages/Message";
|
||||
|
||||
export default function ImageGenerationContent() {
|
||||
const {
|
||||
data: llmProviderResponse,
|
||||
data: llmProviders = [],
|
||||
error: llmError,
|
||||
mutate: refetchProviders,
|
||||
} = useSWR<LLMProviderResponse<LLMProviderView>>(
|
||||
} = useSWR<LLMProviderView[]>(
|
||||
"/api/admin/llm/provider?include_image_gen=true",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
const llmProviders = llmProviderResponse?.providers ?? [];
|
||||
|
||||
const {
|
||||
data: configs = [],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { FormikProps } from "formik";
|
||||
import { ImageProvider } from "../constants";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
|
||||
import {
|
||||
ImageGenerationConfigView,
|
||||
ImageGenerationCredentials,
|
||||
|
||||
84
web/src/app/admin/configuration/llm/LLMConfiguration.tsx
Normal file
84
web/src/app/admin/configuration/llm/LLMConfiguration.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
"use client";
|
||||
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useSWR from "swr";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Title from "@/components/ui/title";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { LLMProviderView } from "./interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||
import { OpenAIForm } from "./forms/OpenAIForm";
|
||||
import { AnthropicForm } from "./forms/AnthropicForm";
|
||||
import { OllamaForm } from "./forms/OllamaForm";
|
||||
import { AzureForm } from "./forms/AzureForm";
|
||||
import { BedrockForm } from "./forms/BedrockForm";
|
||||
import { VertexAIForm } from "./forms/VertexAIForm";
|
||||
import { OpenRouterForm } from "./forms/OpenRouterForm";
|
||||
import { getFormForExistingProvider } from "./forms/getForm";
|
||||
import { CustomForm } from "./forms/CustomForm";
|
||||
|
||||
export function LLMConfiguration() {
|
||||
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
if (!existingLlmProviders) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
const isFirstProvider = existingLlmProviders.length === 0;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Title className="mb-2">Enabled LLM Providers</Title>
|
||||
|
||||
{existingLlmProviders.length > 0 ? (
|
||||
<>
|
||||
<Text as="p" className="mb-4">
|
||||
If multiple LLM providers are enabled, the default provider will be
|
||||
used for all "Default" Assistants. For user-created
|
||||
Assistants, you can select the LLM provider/model that best fits the
|
||||
use case!
|
||||
</Text>
|
||||
<div className="flex flex-col gap-y-4">
|
||||
{[...existingLlmProviders]
|
||||
.sort((a, b) => {
|
||||
if (a.is_default_provider && !b.is_default_provider) return -1;
|
||||
if (!a.is_default_provider && b.is_default_provider) return 1;
|
||||
return 0;
|
||||
})
|
||||
.map((llmProvider) => (
|
||||
<div key={llmProvider.id}>
|
||||
{getFormForExistingProvider(llmProvider)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<Callout type="warning" title="No LLM providers configured yet">
|
||||
Please set one up below in order to start using Onyx!
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<Title className="mb-2 mt-6">Add LLM Provider</Title>
|
||||
<Text as="p" className="mb-4">
|
||||
Add a new LLM provider by either selecting from one of the default
|
||||
providers or by specifying your own custom LLM provider.
|
||||
</Text>
|
||||
|
||||
<div className="flex flex-col gap-y-4">
|
||||
<OpenAIForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<AnthropicForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<OllamaForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<AzureForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<BedrockForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<VertexAIForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<OpenRouterForm shouldMarkAsDefault={isFirstProvider} />
|
||||
|
||||
<CustomForm shouldMarkAsDefault={isFirstProvider} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import { ArrayHelpers, FieldArray, FormikProps, useField } from "formik";
|
||||
import { ModelConfiguration } from "@/interfaces/llm";
|
||||
import { ModelConfiguration } from "./interfaces";
|
||||
import { ManualErrorMessage, TextFormField } from "@/components/Field";
|
||||
import { useEffect, useState } from "react";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
function ModelConfigurationRow({
|
||||
name,
|
||||
index,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
export const LLM_ADMIN_URL = "/api/admin/llm";
|
||||
export const LLM_PROVIDERS_ADMIN_URL = `${LLM_ADMIN_URL}/provider`;
|
||||
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
||||
|
||||
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
|
||||
"/api/admin/llm/provider-contextual-cost";
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps } from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -21,19 +21,15 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
export const ANTHROPIC_PROVIDER_NAME = "anthropic";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
|
||||
|
||||
export function AnthropicModal({
|
||||
export function AnthropicForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Anthropic"
|
||||
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -56,6 +52,7 @@ export function AnthropicModal({
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
// Default to auto mode for new Anthropic providers
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -28,7 +28,7 @@ import Separator from "@/refresh-components/Separator";
|
||||
export const AZURE_PROVIDER_NAME = "azure";
|
||||
const AZURE_DISPLAY_NAME = "Microsoft Azure Cloud";
|
||||
|
||||
interface AzureModalValues extends BaseLLMFormValues {
|
||||
interface AzureFormValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
target_uri: string;
|
||||
api_base?: string;
|
||||
@@ -47,19 +47,15 @@ const buildTargetUri = (existingLlmProvider?: LLMProviderView): string => {
|
||||
return `${existingLlmProvider.api_base}/openai/deployments/${deploymentName}/chat/completions?api-version=${existingLlmProvider.api_version}`;
|
||||
};
|
||||
|
||||
export function AzureModal({
|
||||
export function AzureForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName={AZURE_DISPLAY_NAME}
|
||||
providerEndpoint={AZURE_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -74,7 +70,7 @@ export function AzureModal({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: AzureModalValues = {
|
||||
const initialValues: AzureFormValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -101,7 +97,7 @@ export function AzureModal({
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
// Parse target_uri to extract api_base, api_version, and deployment_name
|
||||
let processedValues: AzureModalValues = { ...values };
|
||||
let processedValues: AzureFormValues = { ...values };
|
||||
|
||||
if (values.target_uri) {
|
||||
try {
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
} from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -27,7 +27,7 @@ import {
|
||||
} from "./formUtils";
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { DisplayModels } from "./components/DisplayModels";
|
||||
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
|
||||
import { fetchBedrockModels } from "../utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
@@ -65,7 +65,7 @@ const FIELD_AWS_ACCESS_KEY_ID = "custom_config.AWS_ACCESS_KEY_ID";
|
||||
const FIELD_AWS_SECRET_ACCESS_KEY = "custom_config.AWS_SECRET_ACCESS_KEY";
|
||||
const FIELD_AWS_BEARER_TOKEN_BEDROCK = "custom_config.AWS_BEARER_TOKEN_BEDROCK";
|
||||
|
||||
interface BedrockModalValues extends BaseLLMFormValues {
|
||||
interface BedrockFormValues extends BaseLLMFormValues {
|
||||
custom_config: {
|
||||
AWS_REGION_NAME: string;
|
||||
BEDROCK_AUTH_METHOD?: string;
|
||||
@@ -75,8 +75,8 @@ interface BedrockModalValues extends BaseLLMFormValues {
|
||||
};
|
||||
}
|
||||
|
||||
interface BedrockModalInternalsProps {
|
||||
formikProps: FormikProps<BedrockModalValues>;
|
||||
interface BedrockFormInternalsProps {
|
||||
formikProps: FormikProps<BedrockFormValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
@@ -87,7 +87,7 @@ interface BedrockModalInternalsProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
function BedrockModalInternals({
|
||||
function BedrockFormInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
@@ -97,7 +97,7 @@ function BedrockModalInternals({
|
||||
testError,
|
||||
mutate,
|
||||
onClose,
|
||||
}: BedrockModalInternalsProps) {
|
||||
}: BedrockFormInternalsProps) {
|
||||
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
|
||||
|
||||
// Clean up unused auth fields when tab changes
|
||||
@@ -258,11 +258,9 @@ function BedrockModalInternals({
|
||||
);
|
||||
}
|
||||
|
||||
export function BedrockModal({
|
||||
export function BedrockForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
@@ -270,8 +268,6 @@ export function BedrockModal({
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName={BEDROCK_DISPLAY_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -286,7 +282,7 @@ export function BedrockModal({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: BedrockModalValues = {
|
||||
const initialValues: BedrockFormValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -356,7 +352,7 @@ export function BedrockModal({
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BedrockModalInternals
|
||||
<BedrockFormInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
@@ -7,7 +7,7 @@
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
|
||||
import { CustomModal } from "./CustomModal";
|
||||
import { CustomForm } from "./CustomForm";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
// Mock SWR's mutate function and useSWR
|
||||
@@ -116,10 +116,11 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
name: "My Custom Provider",
|
||||
provider: "openai",
|
||||
api_key: "test-key",
|
||||
default_model_name: "gpt-4",
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal />);
|
||||
render(<CustomForm />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "My Custom Provider",
|
||||
@@ -176,7 +177,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ detail: "Invalid API key" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal />);
|
||||
render(<CustomForm />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "Bad Provider",
|
||||
@@ -223,13 +224,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
api_key: "old-key",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
default_model_name: "claude-3-opus",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-3-opus",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
supports_image_input: null,
|
||||
},
|
||||
],
|
||||
custom_config: {},
|
||||
@@ -238,6 +239,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
groups: [],
|
||||
personas: [],
|
||||
deployment_name: null,
|
||||
is_default_provider: false,
|
||||
default_vision_model: null,
|
||||
is_default_vision_provider: null,
|
||||
};
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
@@ -252,7 +256,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ ...existingProvider, api_key: "new-key" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal existingLlmProvider={existingProvider} />);
|
||||
render(<CustomForm existingLlmProvider={existingProvider} />);
|
||||
|
||||
// For existing provider, click "Edit" button to open modal
|
||||
const editButton = screen.getByRole("button", { name: /edit/i });
|
||||
@@ -303,13 +307,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
api_key: "old-key",
|
||||
api_base: "https://example-openai-compatible.local/v1",
|
||||
api_version: "",
|
||||
default_model_name: "gpt-oss-20b-bw-failover",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "gpt-oss-20b-bw-failover",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
supports_image_input: null,
|
||||
},
|
||||
],
|
||||
custom_config: {},
|
||||
@@ -318,6 +322,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
groups: [],
|
||||
personas: [],
|
||||
deployment_name: null,
|
||||
is_default_provider: false,
|
||||
default_vision_model: null,
|
||||
is_default_vision_provider: null,
|
||||
};
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
@@ -336,21 +343,19 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
name: "gpt-oss-20b-bw-failover",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
supports_image_input: null,
|
||||
},
|
||||
{
|
||||
name: "nemotron",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
supports_image_input: null,
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal existingLlmProvider={existingProvider} />);
|
||||
render(<CustomForm existingLlmProvider={existingProvider} />);
|
||||
|
||||
const editButton = screen.getByRole("button", { name: /edit/i });
|
||||
await user.click(editButton);
|
||||
@@ -418,7 +423,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal shouldMarkAsDefault={true} />);
|
||||
render(<CustomForm shouldMarkAsDefault={true} />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "New Default Provider",
|
||||
@@ -458,7 +463,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ detail: "Database error" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal />);
|
||||
render(<CustomForm />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "Test Provider",
|
||||
@@ -494,7 +499,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ id: 1, name: "Provider with Custom Config" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal />);
|
||||
render(<CustomForm />);
|
||||
|
||||
// Open modal
|
||||
const openButton = screen.getByRole("button", {
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
Formik,
|
||||
ErrorMessage,
|
||||
} from "formik";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
|
||||
import { DisplayNameField } from "./components/DisplayNameField";
|
||||
@@ -21,7 +21,7 @@ import {
|
||||
} from "./formUtils";
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { ModelConfigurationField } from "@/app/admin/configuration/llm/ModelConfigurationField";
|
||||
import { ModelConfigurationField } from "../ModelConfigurationField";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
@@ -38,11 +38,9 @@ function customConfigProcessing(customConfigsList: [string, string][]) {
|
||||
return customConfig;
|
||||
}
|
||||
|
||||
export function CustomModal({
|
||||
export function CustomForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
@@ -50,8 +48,6 @@ export function CustomModal({
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
buttonMode={!existingLlmProvider}
|
||||
buttonText="Add Custom LLM Provider"
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -72,15 +68,7 @@ export function CustomModal({
|
||||
...modelConfiguration,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
})
|
||||
) ?? [
|
||||
{
|
||||
name: "",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
) ?? [{ name: "", is_visible: true, max_input_tokens: null }],
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
? Object.entries(existingLlmProvider.custom_config)
|
||||
: [],
|
||||
@@ -124,8 +112,7 @@ export function CustomModal({
|
||||
name: mc.name,
|
||||
is_visible: mc.is_visible,
|
||||
max_input_tokens: mc.max_input_tokens ?? null,
|
||||
supports_image_input: mc.supports_image_input ?? false,
|
||||
supports_reasoning: mc.supports_reasoning ?? false,
|
||||
supports_image_input: null,
|
||||
}))
|
||||
.filter(
|
||||
(mc) => mc.name === values.default_model_name || mc.is_visible
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
} from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -24,20 +24,20 @@ import {
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { DisplayModels } from "./components/DisplayModels";
|
||||
import { useEffect, useState } from "react";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
import { fetchOllamaModels } from "../utils";
|
||||
|
||||
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
|
||||
interface OllamaModalValues extends BaseLLMFormValues {
|
||||
interface OllamaFormValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface OllamaModalContentProps {
|
||||
formikProps: FormikProps<OllamaModalValues>;
|
||||
interface OllamaFormContentProps {
|
||||
formikProps: FormikProps<OllamaFormValues>;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
@@ -48,7 +48,7 @@ interface OllamaModalContentProps {
|
||||
isFormValid: boolean;
|
||||
}
|
||||
|
||||
function OllamaModalContent({
|
||||
function OllamaFormContent({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
@@ -58,7 +58,7 @@ function OllamaModalContent({
|
||||
mutate,
|
||||
onClose,
|
||||
isFormValid,
|
||||
}: OllamaModalContentProps) {
|
||||
}: OllamaFormContentProps) {
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -131,11 +131,9 @@ function OllamaModalContent({
|
||||
);
|
||||
}
|
||||
|
||||
export function OllamaModal({
|
||||
export function OllamaForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
@@ -143,8 +141,6 @@ export function OllamaModal({
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Ollama"
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -159,7 +155,7 @@ export function OllamaModal({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: OllamaModalValues = {
|
||||
const initialValues: OllamaFormValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -216,7 +212,7 @@ export function OllamaModal({
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<OllamaModalContent
|
||||
<OllamaFormContent
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Form, Formik } from "formik";
|
||||
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps } from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
|
||||
import { DisplayNameField } from "./components/DisplayNameField";
|
||||
@@ -19,19 +19,15 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
export const OPENAI_PROVIDER_NAME = "openai";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
|
||||
|
||||
export function OpenAIModal({
|
||||
export function OpenAIForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="OpenAI"
|
||||
providerEndpoint={OPENAI_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -53,6 +49,7 @@ export function OpenAIModal({
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
// Default to auto mode for new OpenAI providers
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
LLMProviderFormProps,
|
||||
ModelConfiguration,
|
||||
OpenRouterModelResponse,
|
||||
} from "@/interfaces/llm";
|
||||
} from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -32,7 +32,7 @@ const OPENROUTER_DISPLAY_NAME = "OpenRouter";
|
||||
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
|
||||
const OPENROUTER_MODELS_API_URL = "/api/admin/llm/openrouter/available-models";
|
||||
|
||||
interface OpenRouterModalValues extends BaseLLMFormValues {
|
||||
interface OpenRouterFormValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
@@ -80,7 +80,6 @@ async function fetchOpenRouterModels(params: {
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: false,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
@@ -91,11 +90,9 @@ async function fetchOpenRouterModels(params: {
|
||||
}
|
||||
}
|
||||
|
||||
export function OpenRouterModal({
|
||||
export function OpenRouterForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
@@ -104,8 +101,6 @@ export function OpenRouterModal({
|
||||
providerName={OPENROUTER_DISPLAY_NAME}
|
||||
providerEndpoint={OPENROUTER_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -120,7 +115,7 @@ export function OpenRouterModal({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: OpenRouterModalValues = {
|
||||
const initialValues: OpenRouterFormValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { TextFormField, FileUploadFormField } from "@/components/Field";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps } from "../interfaces";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -25,26 +25,22 @@ const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
|
||||
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
|
||||
const VERTEXAI_DEFAULT_LOCATION = "global";
|
||||
|
||||
interface VertexAIModalValues extends BaseLLMFormValues {
|
||||
interface VertexAIFormValues extends BaseLLMFormValues {
|
||||
custom_config: {
|
||||
vertex_credentials: string;
|
||||
vertex_location: string;
|
||||
};
|
||||
}
|
||||
|
||||
export function VertexAIModal({
|
||||
export function VertexAIForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName={VERTEXAI_DISPLAY_NAME}
|
||||
providerEndpoint={VERTEXAI_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -59,12 +55,13 @@ export function VertexAIModal({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: VertexAIModalValues = {
|
||||
const initialValues: VertexAIFormValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
),
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
// Default to auto mode for new Vertex AI providers
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
|
||||
import { ModelConfiguration, SimpleKnownModel } from "../../interfaces";
|
||||
import { FormikProps } from "formik";
|
||||
import { BaseLLMFormValues } from "../formUtils";
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState, useEffect } from "react";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import { ModelConfiguration } from "@/interfaces/llm";
|
||||
import { ModelConfiguration } from "../../interfaces";
|
||||
|
||||
interface FetchModelsButtonProps {
|
||||
onFetch: () => Promise<{ models: ModelConfiguration[]; error?: string }>;
|
||||
@@ -2,9 +2,8 @@ import { LoadingAnimation } from "@/components/Loading";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgTrash } from "@opal/icons";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { deleteLlmProvider } from "@/lib/llmConfig/svc";
|
||||
import { LLMProviderView } from "../../interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
|
||||
|
||||
interface FormActionButtonsProps {
|
||||
isTesting: boolean;
|
||||
@@ -26,14 +25,41 @@ export function FormActionButtons({
|
||||
const handleDelete = async () => {
|
||||
if (!existingLlmProvider) return;
|
||||
|
||||
try {
|
||||
await deleteLlmProvider(existingLlmProvider.id);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
alert(`Failed to delete provider: ${message}`);
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
alert(`Failed to delete provider: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the deleted provider was the default, set the first remaining provider as default
|
||||
if (existingLlmProvider.is_default_provider) {
|
||||
const remainingProvidersResponse = await fetch(LLM_PROVIDERS_ADMIN_URL);
|
||||
if (remainingProvidersResponse.ok) {
|
||||
const remainingProviders = await remainingProvidersResponse.json();
|
||||
|
||||
if (remainingProviders.length > 0) {
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!setDefaultResponse.ok) {
|
||||
console.error("Failed to set new default provider");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -6,7 +6,7 @@ import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
} from "../../interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
@@ -14,8 +14,7 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { setDefaultLlmModel } from "@/lib/llmConfig/svc";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
|
||||
|
||||
export interface ProviderFormContext {
|
||||
onClose: () => void;
|
||||
@@ -36,10 +35,6 @@ interface ProviderFormEntrypointWrapperProps {
|
||||
buttonMode?: boolean;
|
||||
/** Custom button text for buttonMode (defaults to "Add {providerName}") */
|
||||
buttonText?: string;
|
||||
/** Controlled open state — when defined, the wrapper renders only a modal (no card/button UI) */
|
||||
open?: boolean;
|
||||
/** Callback when controlled modal requests close */
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
export function ProviderFormEntrypointWrapper({
|
||||
@@ -49,11 +44,8 @@ export function ProviderFormEntrypointWrapper({
|
||||
existingLlmProvider,
|
||||
buttonMode,
|
||||
buttonText,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: ProviderFormEntrypointWrapperProps) {
|
||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||
const isControlled = open !== undefined;
|
||||
|
||||
// Shared hooks
|
||||
const { mutate } = useSWRConfig();
|
||||
@@ -62,45 +54,33 @@ export function ProviderFormEntrypointWrapper({
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
// Suppress SWR when controlled + closed to avoid unnecessary API calls
|
||||
const swrKey =
|
||||
providerEndpoint && !(isControlled && !open)
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null;
|
||||
|
||||
// Fetch model configurations for this provider
|
||||
const { data: wellKnownLLMProvider } = useSWR<WellKnownLLMProviderDescriptor>(
|
||||
swrKey,
|
||||
providerEndpoint
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const onClose = () => {
|
||||
if (isControlled) {
|
||||
onOpenChange?.(false);
|
||||
} else {
|
||||
setFormIsVisible(false);
|
||||
}
|
||||
};
|
||||
const onClose = () => setFormIsVisible(false);
|
||||
|
||||
async function handleSetAsDefault(): Promise<void> {
|
||||
if (!existingLlmProvider) return;
|
||||
|
||||
const firstVisibleModel = existingLlmProvider.model_configurations.find(
|
||||
(m) => m.is_visible
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!firstVisibleModel) {
|
||||
toast.error("No visible models available for this provider.");
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(`Failed to set provider as default: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await setDefaultLlmModel(existingLlmProvider.id, firstVisibleModel.name);
|
||||
await mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
toast.success("Provider set as default successfully!");
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
toast.error(`Failed to set provider as default: ${message}`);
|
||||
}
|
||||
await mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
toast.success("Provider set as default successfully!");
|
||||
}
|
||||
|
||||
const context: ProviderFormContext = {
|
||||
@@ -113,31 +93,6 @@ export function ProviderFormEntrypointWrapper({
|
||||
wellKnownLLMProvider,
|
||||
};
|
||||
|
||||
const defaultTitle = `${existingLlmProvider ? "Configure" : "Setup"} ${
|
||||
existingLlmProvider?.name ? `"${existingLlmProvider.name}"` : providerName
|
||||
}`;
|
||||
|
||||
function renderModal(isVisible: boolean, title?: string) {
|
||||
if (!isVisible) return null;
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={title ?? defaultTitle}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>{children(context)}</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
// Controlled mode: render nothing when closed, render only modal when open
|
||||
if (isControlled) {
|
||||
return renderModal(!!open);
|
||||
}
|
||||
|
||||
// Button mode: simple button that opens a modal
|
||||
if (buttonMode && !existingLlmProvider) {
|
||||
return (
|
||||
@@ -145,7 +100,19 @@ export function ProviderFormEntrypointWrapper({
|
||||
<Button action onClick={() => setFormIsVisible(true)}>
|
||||
{buttonText ?? `Add ${providerName}`}
|
||||
</Button>
|
||||
{renderModal(formIsVisible, `Setup ${providerName}`)}
|
||||
|
||||
{formIsVisible && (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={`Setup ${providerName}`}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>{children(context)}</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -168,18 +135,24 @@ export function ProviderFormEntrypointWrapper({
|
||||
<Text as="p" secondaryBody text03 className="italic">
|
||||
({providerName})
|
||||
</Text>
|
||||
<Text
|
||||
as="p"
|
||||
className={cn("text-action-link-05", "cursor-pointer")}
|
||||
onClick={handleSetAsDefault}
|
||||
>
|
||||
Set as default
|
||||
</Text>
|
||||
{!existingLlmProvider.is_default_provider && (
|
||||
<Text
|
||||
as="p"
|
||||
className={cn("text-action-link-05", "cursor-pointer")}
|
||||
onClick={handleSetAsDefault}
|
||||
>
|
||||
Set as default
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{existingLlmProvider && (
|
||||
<div className="my-auto ml-3">
|
||||
<Badge variant="success">Enabled</Badge>
|
||||
{existingLlmProvider.is_default_provider ? (
|
||||
<Badge variant="agent">Default</Badge>
|
||||
) : (
|
||||
<Badge variant="success">Enabled</Badge>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -209,7 +182,22 @@ export function ProviderFormEntrypointWrapper({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{renderModal(formIsVisible)}
|
||||
{formIsVisible && (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={`${existingLlmProvider ? "Configure" : "Setup"} ${
|
||||
existingLlmProvider?.name
|
||||
? `"${existingLlmProvider.name}"`
|
||||
: providerName
|
||||
}`}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>{children(context)}</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -2,11 +2,8 @@ import {
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
} from "../interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "../constants";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
@@ -18,7 +15,10 @@ export const buildDefaultInitialValues = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
modelConfigurations?: ModelConfiguration[]
|
||||
) => {
|
||||
const defaultModelName = modelConfigurations?.[0]?.name ?? "";
|
||||
const defaultModelName =
|
||||
existingLlmProvider?.default_model_name ??
|
||||
modelConfigurations?.[0]?.name ??
|
||||
"";
|
||||
|
||||
// Auto mode must be explicitly enabled by the user
|
||||
// Default to false for new providers, preserve existing value when editing
|
||||
@@ -119,7 +119,6 @@ export const filterModelConfigurations = (
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
)
|
||||
@@ -142,7 +141,6 @@ export const getAutoModeModelConfigurations = (
|
||||
is_visible: modelConfiguration.is_visible,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
);
|
||||
@@ -225,8 +223,6 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: finalDefaultModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
});
|
||||
setIsTesting(false);
|
||||
@@ -251,7 +247,6 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
@@ -267,16 +262,12 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
|
||||
if (shouldMarkAsDefault) {
|
||||
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: finalDefaultModelName,
|
||||
}),
|
||||
});
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!setDefaultResponse.ok) {
|
||||
const errorMsg = (await setDefaultResponse.json()).detail;
|
||||
toast.error(`Failed to set provider as default: ${errorMsg}`);
|
||||
44
web/src/app/admin/configuration/llm/forms/getForm.tsx
Normal file
44
web/src/app/admin/configuration/llm/forms/getForm.tsx
Normal file
@@ -0,0 +1,44 @@
|
||||
import { LLMProviderName, LLMProviderView } from "../interfaces";
|
||||
import { AnthropicForm } from "./AnthropicForm";
|
||||
import { OpenAIForm } from "./OpenAIForm";
|
||||
import { OllamaForm } from "./OllamaForm";
|
||||
import { AzureForm } from "./AzureForm";
|
||||
import { VertexAIForm } from "./VertexAIForm";
|
||||
import { OpenRouterForm } from "./OpenRouterForm";
|
||||
import { CustomForm } from "./CustomForm";
|
||||
import { BedrockForm } from "./BedrockForm";
|
||||
|
||||
export function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
provider.provider === LLMProviderName.OPENAI &&
|
||||
provider.api_key &&
|
||||
!provider.api_base &&
|
||||
Object.keys(provider.custom_config || {}).length === 0
|
||||
);
|
||||
}
|
||||
|
||||
export const getFormForExistingProvider = (provider: LLMProviderView) => {
|
||||
switch (provider.provider) {
|
||||
case LLMProviderName.OPENAI:
|
||||
// "openai" as a provider name can be used for litellm proxy / any OpenAI-compatible provider
|
||||
if (detectIfRealOpenAIProvider(provider)) {
|
||||
return <OpenAIForm existingLlmProvider={provider} />;
|
||||
} else {
|
||||
return <CustomForm existingLlmProvider={provider} />;
|
||||
}
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return <AnthropicForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterForm existingLlmProvider={provider} />;
|
||||
default:
|
||||
return <CustomForm existingLlmProvider={provider} />;
|
||||
}
|
||||
};
|
||||
@@ -13,8 +13,8 @@ export interface ModelConfiguration {
|
||||
name: string;
|
||||
is_visible: boolean;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
supports_image_input: boolean | null;
|
||||
supports_reasoning?: boolean;
|
||||
display_name?: string;
|
||||
provider_display_name?: string;
|
||||
vendor?: string;
|
||||
@@ -30,6 +30,7 @@ export interface SimpleKnownModel {
|
||||
export interface WellKnownLLMProviderDescriptor {
|
||||
name: string;
|
||||
known_models: ModelConfiguration[];
|
||||
|
||||
recommended_default_model: SimpleKnownModel | null;
|
||||
}
|
||||
|
||||
@@ -39,31 +40,44 @@ export interface LLMModelDescriptor {
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
export interface LLMProviderView {
|
||||
id: number;
|
||||
export interface LLMProvider {
|
||||
name: string;
|
||||
provider: string;
|
||||
api_key: string | null;
|
||||
api_base: string | null;
|
||||
api_version: string | null;
|
||||
custom_config: { [key: string]: string } | null;
|
||||
default_model_name: string;
|
||||
is_public: boolean;
|
||||
is_auto_mode: boolean;
|
||||
groups: number[];
|
||||
personas: number[];
|
||||
deployment_name: string | null;
|
||||
default_vision_model: string | null;
|
||||
is_default_vision_provider: boolean | null;
|
||||
model_configurations: ModelConfiguration[];
|
||||
}
|
||||
|
||||
export interface LLMProviderView extends LLMProvider {
|
||||
id: number;
|
||||
is_default_provider: boolean | null;
|
||||
}
|
||||
|
||||
export interface VisionProvider extends LLMProviderView {
|
||||
vision_models: string[];
|
||||
}
|
||||
|
||||
export interface LLMProviderDescriptor {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
provider_display_name: string;
|
||||
provider_display_name?: string;
|
||||
default_model_name: string;
|
||||
is_default_provider: boolean | null;
|
||||
is_default_vision_provider?: boolean | null;
|
||||
default_vision_model?: string | null;
|
||||
is_public?: boolean;
|
||||
groups?: number[];
|
||||
personas?: number[];
|
||||
model_configurations: ModelConfiguration[];
|
||||
}
|
||||
|
||||
@@ -88,22 +102,9 @@ export interface BedrockModelResponse {
|
||||
supports_image_input: boolean;
|
||||
}
|
||||
|
||||
export interface DefaultModel {
|
||||
provider_id: number;
|
||||
model_name: string;
|
||||
}
|
||||
|
||||
export interface LLMProviderResponse<T> {
|
||||
providers: T[];
|
||||
default_text: DefaultModel | null;
|
||||
default_vision: DefaultModel | null;
|
||||
}
|
||||
|
||||
export interface LLMProviderFormProps {
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
// Param types for model fetching functions - use snake_case to match API structure
|
||||
@@ -1,7 +1,14 @@
|
||||
"use client";
|
||||
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { LLMConfiguration } from "./LLMConfiguration";
|
||||
import { SvgCpu } from "@opal/icons";
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="LLM Setup" icon={SvgCpu} />
|
||||
|
||||
<LLMConfiguration />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user