mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-07 08:35:47 +00:00
Compare commits
51 Commits
fix/remove
...
arg_packet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fff5f519bc | ||
|
|
95a47bdd97 | ||
|
|
1ec871b324 | ||
|
|
27179af627 | ||
|
|
3b0ec64f91 | ||
|
|
43bbef162e | ||
|
|
bd15ff7b4e | ||
|
|
d29791ac9c | ||
|
|
755b874bba | ||
|
|
48802618db | ||
|
|
6917953b86 | ||
|
|
e7cf027f8a | ||
|
|
41fb1480bb | ||
|
|
bdc2bfdcee | ||
|
|
8816d52b27 | ||
|
|
6590f1d7ba | ||
|
|
c527f75557 | ||
|
|
472d1788a7 | ||
|
|
3061fd7ab5 | ||
|
|
99e95f8205 | ||
|
|
e618bf8385 | ||
|
|
f4dcd130ba | ||
|
|
910718deaa | ||
|
|
1a7ca93b93 | ||
|
|
a615a920cb | ||
|
|
29d8b310b5 | ||
|
|
d1409ccafa | ||
|
|
17e1b27434 | ||
|
|
3f4a6840b1 | ||
|
|
9a4e82b91e | ||
|
|
c20cc9da19 | ||
|
|
b1f08f3964 | ||
|
|
037540225c | ||
|
|
1be71fd2af | ||
|
|
8bef7ab8fb | ||
|
|
8ca7c1af5a | ||
|
|
089fc6ed3e | ||
|
|
a25079ac23 | ||
|
|
00dca7c3ec | ||
|
|
a8c7b322cb | ||
|
|
e549944bce | ||
|
|
723637f379 | ||
|
|
e9913876c0 | ||
|
|
20948e2ea3 | ||
|
|
5dcbc91643 | ||
|
|
c8e874df49 | ||
|
|
b1a6a08eed | ||
|
|
ec3e571a7f | ||
|
|
b03a0f8cac | ||
|
|
8163ca704a | ||
|
|
b7abf3991a |
@@ -106,13 +106,34 @@ onyx-cli ask --json "What authentication methods do we support?"
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation with `citation_number` and `document_id` |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
@@ -129,6 +150,10 @@ Uses a specific Onyx agent/persona instead of the default.
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
39
.github/workflows/release-cli.yml
vendored
Normal file
39
.github/workflows/release-cli.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
name: Release CLI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "cli/v*.*.*"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: release-cli
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -22,12 +22,10 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -54,6 +55,7 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
@@ -1009,6 +1011,7 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1215,7 +1218,14 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -270,34 +270,10 @@ def upsert_llm_provider(
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
|
||||
# Build a lookup of requested visibility by model name
|
||||
requested_visibility = {
|
||||
mc.name: mc.is_visible
|
||||
for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Delete removed models
|
||||
removed_ids = [
|
||||
mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
|
||||
]
|
||||
|
||||
# Prevent removing and hiding the default model
|
||||
if default_model:
|
||||
for name, mc in existing_by_name.items():
|
||||
if mc.id == default_model.id:
|
||||
if name not in models_to_exist:
|
||||
raise ValueError(
|
||||
f"Cannot remove the default model '{name}'. "
|
||||
"Please change the default model before removing."
|
||||
)
|
||||
if not requested_visibility.get(name, True):
|
||||
raise ValueError(
|
||||
f"Cannot hide the default model '{name}'. "
|
||||
"Please change the default model before hiding."
|
||||
)
|
||||
|
||||
if removed_ids:
|
||||
db_session.query(ModelConfiguration).filter(
|
||||
ModelConfiguration.id.in_(removed_ids)
|
||||
@@ -562,6 +538,7 @@ def fetch_default_model(
|
||||
.options(selectinload(ModelConfiguration.llm_provider))
|
||||
.join(LLMModelFlow)
|
||||
.where(
|
||||
ModelConfiguration.is_visible == True, # noqa: E712
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
|
||||
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
103
backend/onyx/document_index/FILTER_SEMANTICS.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -698,41 +698,6 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -758,41 +723,53 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -37,30 +34,22 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -73,16 +62,12 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -104,8 +89,7 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -114,16 +98,14 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -135,8 +117,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -147,8 +129,7 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -160,73 +141,94 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# filter_str += _build_kg_filter(
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# )
|
||||
# ))
|
||||
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# )
|
||||
# ))
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -1512,6 +1512,10 @@
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-6": {
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1526,6 +1530,10 @@
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-6": {
|
||||
"display_name": "Claude Sonnet 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-sonnet-4-5-20250929": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -1,37 +1,8 @@
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
|
||||
def _fallback_bedrock_regions() -> list[str]:
|
||||
# Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available.
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-gov-east-1",
|
||||
"us-gov-west-1",
|
||||
"us-west-2",
|
||||
"ap-northeast-1",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ap-east-1",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-west-2",
|
||||
]
|
||||
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama_chat"
|
||||
@@ -51,13 +22,6 @@ OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
@@ -65,13 +29,6 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
AWS_REGION_NAME_KWARG = "aws_region_name"
|
||||
AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME"
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4-6",
|
||||
"display_name": "Claude Sonnet 4.6"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-5",
|
||||
"display_name": "Claude Opus 4.5"
|
||||
|
||||
@@ -6758,12 +6758,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "8.2.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
|
||||
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz",
|
||||
"integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.0.1"
|
||||
"ip-address": "10.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
@@ -7556,9 +7556,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
|
||||
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
|
||||
"version": "10.1.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
|
||||
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
|
||||
@@ -65,7 +65,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -446,17 +445,16 @@ def put_llm_provider(
|
||||
not existing_provider or not existing_provider.is_auto_mode
|
||||
)
|
||||
|
||||
# When transitioning to auto mode, preserve existing model configurations
|
||||
# so the upsert doesn't try to delete them (which would trip the default
|
||||
# model protection guard). sync_auto_mode_models will handle the model
|
||||
# lifecycle afterward — adding new models, hiding removed ones, and
|
||||
# updating the default. This is safe even if sync fails: the provider
|
||||
# keeps its old models and default rather than losing them.
|
||||
if transitioning_to_auto_mode and existing_provider:
|
||||
llm_provider_upsert_request.model_configurations = [
|
||||
ModelConfigurationUpsertRequest.from_model(mc)
|
||||
for mc in existing_provider.model_configurations
|
||||
]
|
||||
# Before the upsert, check if this provider currently owns the global
|
||||
# CHAT default. The upsert may cascade-delete model_configurations
|
||||
# (and their flow mappings), so we need to remember this beforehand.
|
||||
was_default_provider = False
|
||||
if existing_provider and transitioning_to_auto_mode:
|
||||
current_default = fetch_default_llm_model(db_session)
|
||||
was_default_provider = (
|
||||
current_default is not None
|
||||
and current_default.llm_provider_id == existing_provider.id
|
||||
)
|
||||
|
||||
try:
|
||||
result = upsert_llm_provider(
|
||||
@@ -470,6 +468,7 @@ def put_llm_provider(
|
||||
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
@@ -479,6 +478,20 @@ def put_llm_provider(
|
||||
updated_provider,
|
||||
config,
|
||||
)
|
||||
|
||||
# If this provider was the default before the transition,
|
||||
# restore the default using the recommended model.
|
||||
if was_default_provider:
|
||||
recommended = config.get_default_model(
|
||||
llm_provider_upsert_request.provider
|
||||
)
|
||||
if recommended:
|
||||
update_default_provider(
|
||||
provider_id=updated_provider.id,
|
||||
model_name=recommended.name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Refresh result with synced models
|
||||
result = LLMProviderView.from_model(updated_provider)
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -259,6 +260,15 @@ class CustomToolDelta(BaseObj):
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
# File Reader Packets
|
||||
################################################
|
||||
@@ -379,6 +389,7 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -56,3 +56,23 @@ def get_built_in_tool_ids() -> list[str]:
|
||||
|
||||
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
|
||||
return BUILT_IN_TOOL_MAP[in_code_tool_id]
|
||||
|
||||
|
||||
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
|
||||
"""Build a mapping from LLM-facing tool name to tool class."""
|
||||
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
|
||||
for cls in BUILT_IN_TOOL_MAP.values():
|
||||
name_attr = cls.__dict__.get("name")
|
||||
if isinstance(name_attr, property) and name_attr.fget is not None:
|
||||
tool_name = name_attr.fget(cls)
|
||||
elif isinstance(name_attr, str):
|
||||
tool_name = name_attr
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
|
||||
)
|
||||
result[tool_name] = cls
|
||||
return result
|
||||
|
||||
|
||||
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()
|
||||
|
||||
@@ -92,3 +92,7 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return False
|
||||
|
||||
@@ -376,3 +376,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return True
|
||||
|
||||
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
17
backend/onyx/utils/jsonriver/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
jsonriver - A streaming JSON parser for Python
|
||||
|
||||
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
|
||||
Gives you a sequence of increasingly complete values.
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from .parse import _Parser as Parser
|
||||
from .parse import JsonObject
|
||||
from .parse import JsonValue
|
||||
|
||||
__all__ = ["Parser", "JsonValue", "JsonObject"]
|
||||
__version__ = "0.0.1"
|
||||
419
backend/onyx/utils/jsonriver/parse.py
Normal file
419
backend/onyx/utils/jsonriver/parse.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
JSON parser for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from .tokenize import _Input
|
||||
from .tokenize import json_token_type_to_string
|
||||
from .tokenize import JsonTokenType
|
||||
from .tokenize import Tokenizer
|
||||
|
||||
|
||||
# Type definitions for JSON values
|
||||
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
|
||||
JsonObject = dict[str, JsonValue]
|
||||
|
||||
|
||||
class _StateEnum(IntEnum):
|
||||
"""Parser state machine states"""
|
||||
|
||||
Initial = 0
|
||||
InString = 1
|
||||
InArray = 2
|
||||
InObjectExpectingKey = 3
|
||||
InObjectExpectingValue = 4
|
||||
|
||||
|
||||
class _State:
|
||||
"""Base class for parser states"""
|
||||
|
||||
type: _StateEnum
|
||||
value: JsonValue | tuple[str, JsonObject] | None
|
||||
|
||||
|
||||
class _InitialState(_State):
|
||||
"""Initial state before any parsing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.Initial
|
||||
self.value = None
|
||||
|
||||
|
||||
class _InStringState(_State):
|
||||
"""State while parsing a string"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InString
|
||||
self.value = ""
|
||||
|
||||
|
||||
class _InArrayState(_State):
|
||||
"""State while parsing an array"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InArray
|
||||
self.value: list[JsonValue] = []
|
||||
|
||||
|
||||
class _InObjectExpectingKeyState(_State):
|
||||
"""State while parsing an object, expecting a key"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingKey
|
||||
self.value: JsonObject = {}
|
||||
|
||||
|
||||
class _InObjectExpectingValueState(_State):
|
||||
"""State while parsing an object, expecting a value"""
|
||||
|
||||
def __init__(self, key: str, obj: JsonObject) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingValue
|
||||
self.value = (key, obj)
|
||||
|
||||
|
||||
# Sentinel value to distinguish "not set" from "set to None/null"
|
||||
class _Unset:
|
||||
pass
|
||||
|
||||
|
||||
_UNSET = _Unset()
|
||||
|
||||
|
||||
class _Parser:
|
||||
"""
|
||||
Incremental JSON parser
|
||||
|
||||
Feed chunks of JSON text via feed() and get back progressively
|
||||
more complete JSON values.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_stack: list[_State] = [_InitialState()]
|
||||
self._toplevel_value: JsonValue | _Unset = _UNSET
|
||||
self._input = _Input()
|
||||
self.tokenizer = Tokenizer(self._input, self)
|
||||
self._finished = False
|
||||
self._progressed = False
|
||||
self._prev_snapshot: JsonValue | _Unset = _UNSET
|
||||
|
||||
def feed(self, chunk: str) -> list[JsonValue]:
|
||||
"""
|
||||
Feed a chunk of JSON text and return deltas from the previous state.
|
||||
|
||||
Each element in the returned list represents what changed since the
|
||||
last yielded value. For dicts, only changed/new keys are included,
|
||||
with string values containing only the newly appended characters.
|
||||
"""
|
||||
if self._finished:
|
||||
return []
|
||||
|
||||
self._input.feed(chunk)
|
||||
return self._collect_deltas()
|
||||
|
||||
@staticmethod
|
||||
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
|
||||
if prev is None:
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(prev, dict):
|
||||
result: JsonObject = {}
|
||||
for key in current:
|
||||
cur_val = current[key]
|
||||
prev_val = prev.get(key)
|
||||
if key not in prev:
|
||||
result[key] = cur_val
|
||||
elif isinstance(cur_val, str) and isinstance(prev_val, str):
|
||||
if cur_val != prev_val:
|
||||
result[key] = cur_val[len(prev_val) :]
|
||||
elif isinstance(cur_val, list) and isinstance(prev_val, list):
|
||||
if cur_val != prev_val:
|
||||
new_items = cur_val[len(prev_val) :]
|
||||
# check if the last existing element was updated
|
||||
if prev_val and cur_val[len(prev_val) - 1] != prev_val[-1]:
|
||||
result[key] = [cur_val[len(prev_val) - 1]] + new_items
|
||||
elif new_items:
|
||||
result[key] = new_items
|
||||
elif cur_val != prev_val:
|
||||
result[key] = cur_val
|
||||
return result if result else None
|
||||
|
||||
if isinstance(current, str) and isinstance(prev, str):
|
||||
delta = current[len(prev) :]
|
||||
return delta if delta else None
|
||||
|
||||
if isinstance(current, list) and isinstance(prev, list):
|
||||
if current != prev:
|
||||
new_items = current[len(prev) :]
|
||||
if prev and current[len(prev) - 1] != prev[-1]:
|
||||
return [current[len(prev) - 1]] + new_items
|
||||
return new_items if new_items else None
|
||||
return None
|
||||
|
||||
if current != prev:
|
||||
return current
|
||||
return None
|
||||
|
||||
def finish(self) -> list[JsonValue]:
|
||||
"""Signal that no more chunks will be fed. Validates trailing content.
|
||||
|
||||
Returns any final deltas produced by flushing pending tokens (e.g.
|
||||
numbers, which have no terminator and wait for more input).
|
||||
"""
|
||||
self._input.mark_complete()
|
||||
# Pump once more so the tokenizer can emit tokens that were waiting
|
||||
# for more input (e.g. numbers need buffer_complete to finalize).
|
||||
results = self._collect_deltas()
|
||||
self._input.expect_end_of_content()
|
||||
return results
|
||||
|
||||
def _collect_deltas(self) -> list[JsonValue]:
|
||||
"""Run one pump cycle and return any deltas produced."""
|
||||
results: list[JsonValue] = []
|
||||
while True:
|
||||
self._progressed = False
|
||||
self.tokenizer.pump()
|
||||
|
||||
if self._progressed:
|
||||
if self._toplevel_value is _UNSET:
|
||||
raise RuntimeError(
|
||||
"Internal error: toplevel_value should not be unset "
|
||||
"after progressing"
|
||||
)
|
||||
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
|
||||
if isinstance(self._prev_snapshot, _Unset):
|
||||
results.append(current)
|
||||
else:
|
||||
delta = self._compute_delta(self._prev_snapshot, current)
|
||||
if delta is not None:
|
||||
results.append(delta)
|
||||
self._prev_snapshot = current
|
||||
else:
|
||||
if not self._state_stack:
|
||||
self._finished = True
|
||||
break
|
||||
return results
|
||||
|
||||
# TokenHandler protocol implementation
|
||||
|
||||
def handle_null(self) -> None:
|
||||
"""Handle null token"""
|
||||
self._handle_value_token(JsonTokenType.Null, None)
|
||||
|
||||
def handle_boolean(self, value: bool) -> None:
|
||||
"""Handle boolean token"""
|
||||
self._handle_value_token(JsonTokenType.Boolean, value)
|
||||
|
||||
def handle_number(self, value: float) -> None:
|
||||
"""Handle number token"""
|
||||
self._handle_value_token(JsonTokenType.Number, value)
|
||||
|
||||
def handle_string_start(self) -> None:
|
||||
"""Handle string start token"""
|
||||
state = self._current_state()
|
||||
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(JsonTokenType.StringStart, None)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
self._state_stack.append(_InStringState())
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
sv = self._progress_value(JsonTokenType.StringStart, None)
|
||||
obj[key] = sv
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
def handle_string_middle(self, value: str) -> None:
|
||||
"""Handle string middle token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
if len(self._state_stack) >= 2:
|
||||
prev = self._state_stack[-2]
|
||||
if prev.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
else:
|
||||
self._progressed = True
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
assert isinstance(state.value, str)
|
||||
state.value += value
|
||||
|
||||
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_string_end(self) -> None:
|
||||
"""Handle string end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
self._state_stack.pop()
|
||||
parent_state = self._state_stack[-1] if self._state_stack else None
|
||||
assert isinstance(state.value, str)
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_array_start(self) -> None:
|
||||
"""Handle array start token"""
|
||||
self._handle_value_token(JsonTokenType.ArrayStart, None)
|
||||
|
||||
def handle_array_end(self) -> None:
|
||||
"""Handle array end token"""
|
||||
state = self._current_state()
|
||||
if state.type != _StateEnum.InArray:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
|
||||
)
|
||||
self._state_stack.pop()
|
||||
|
||||
def handle_object_start(self) -> None:
|
||||
"""Handle object start token"""
|
||||
self._handle_value_token(JsonTokenType.ObjectStart, None)
|
||||
|
||||
def handle_object_end(self) -> None:
|
||||
"""Handle object end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type in (
|
||||
_StateEnum.InObjectExpectingKey,
|
||||
_StateEnum.InObjectExpectingValue,
|
||||
):
|
||||
self._state_stack.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
|
||||
)
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _current_state(self) -> _State:
|
||||
"""Get current parser state"""
|
||||
if not self._state_stack:
|
||||
raise ValueError("Unexpected trailing input")
|
||||
return self._state_stack[-1]
|
||||
|
||||
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
|
||||
"""Handle a complete value token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(token_type, value)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(token_type, value)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
if token_type != JsonTokenType.StringStart:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
v = self._progress_value(token_type, value)
|
||||
obj[key] = v
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of object expecting key"
|
||||
)
|
||||
|
||||
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
|
||||
"""Update parent container with updated string value"""
|
||||
if parent_state is None:
|
||||
self._toplevel_value = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InArray:
|
||||
arr = cast(list[JsonValue], parent_state.value)
|
||||
arr[-1] = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], parent_state.value)
|
||||
obj[key] = updated
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingKey:
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
obj = cast(JsonObject, parent_state.value)
|
||||
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
|
||||
|
||||
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
|
||||
"""Create initial value for a token and push appropriate state"""
|
||||
if token_type == JsonTokenType.Null:
|
||||
return None
|
||||
|
||||
elif token_type == JsonTokenType.Boolean:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.Number:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.StringStart:
|
||||
string_state = _InStringState()
|
||||
self._state_stack.append(string_state)
|
||||
return ""
|
||||
|
||||
elif token_type == JsonTokenType.ArrayStart:
|
||||
array_state = _InArrayState()
|
||||
self._state_stack.append(array_state)
|
||||
return array_state.value
|
||||
|
||||
elif token_type == JsonTokenType.ObjectStart:
|
||||
object_state = _InObjectExpectingKeyState()
|
||||
self._state_stack.append(object_state)
|
||||
return object_state.value
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected token type: {json_token_type_to_string(token_type)}"
|
||||
)
|
||||
518
backend/onyx/utils/jsonriver/tokenize.py
Normal file
518
backend/onyx/utils/jsonriver/tokenize.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""
|
||||
JSON tokenizer for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenHandler(Protocol):
|
||||
"""Protocol for handling JSON tokens"""
|
||||
|
||||
def handle_null(self) -> None: ...
|
||||
def handle_boolean(self, value: bool) -> None: ...
|
||||
def handle_number(self, value: float) -> None: ...
|
||||
def handle_string_start(self) -> None: ...
|
||||
def handle_string_middle(self, value: str) -> None: ...
|
||||
def handle_string_end(self) -> None: ...
|
||||
def handle_array_start(self) -> None: ...
|
||||
def handle_array_end(self) -> None: ...
|
||||
def handle_object_start(self) -> None: ...
|
||||
def handle_object_end(self) -> None: ...
|
||||
|
||||
|
||||
class JsonTokenType(IntEnum):
|
||||
"""Types of JSON tokens"""
|
||||
|
||||
Null = 0
|
||||
Boolean = 1
|
||||
Number = 2
|
||||
StringStart = 3
|
||||
StringMiddle = 4
|
||||
StringEnd = 5
|
||||
ArrayStart = 6
|
||||
ArrayEnd = 7
|
||||
ObjectStart = 8
|
||||
ObjectEnd = 9
|
||||
|
||||
|
||||
def json_token_type_to_string(token_type: JsonTokenType) -> str:
|
||||
"""Convert token type to readable string"""
|
||||
names = {
|
||||
JsonTokenType.Null: "null",
|
||||
JsonTokenType.Boolean: "boolean",
|
||||
JsonTokenType.Number: "number",
|
||||
JsonTokenType.StringStart: "string start",
|
||||
JsonTokenType.StringMiddle: "string middle",
|
||||
JsonTokenType.StringEnd: "string end",
|
||||
JsonTokenType.ArrayStart: "array start",
|
||||
JsonTokenType.ArrayEnd: "array end",
|
||||
JsonTokenType.ObjectStart: "object start",
|
||||
JsonTokenType.ObjectEnd: "object end",
|
||||
}
|
||||
return names[token_type]
|
||||
|
||||
|
||||
class _State(IntEnum):
|
||||
"""Internal tokenizer states"""
|
||||
|
||||
ExpectingValue = 0
|
||||
InString = 1
|
||||
StartArray = 2
|
||||
AfterArrayValue = 3
|
||||
StartObject = 4
|
||||
AfterObjectKey = 5
|
||||
AfterObjectValue = 6
|
||||
BeforeObjectKey = 7
|
||||
|
||||
|
||||
# Regex for validating JSON numbers
|
||||
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
|
||||
|
||||
|
||||
def _parse_json_number(s: str) -> float:
|
||||
"""Parse a JSON number string, validating format"""
|
||||
if not _JSON_NUMBER_PATTERN.match(s):
|
||||
raise ValueError("Invalid number")
|
||||
return float(s)
|
||||
|
||||
|
||||
class _Input:
|
||||
"""
|
||||
Input buffer for chunk-based JSON parsing
|
||||
|
||||
Manages buffering of input chunks and provides methods for
|
||||
consuming and inspecting the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._start_index = 0
|
||||
self.buffer_complete = False
|
||||
self.more_content_expected = True
|
||||
|
||||
def feed(self, chunk: str) -> None:
|
||||
"""Add a chunk of data to the buffer"""
|
||||
self._buffer += chunk
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Signal that no more chunks will be fed"""
|
||||
self.buffer_complete = True
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""Number of characters remaining in buffer"""
|
||||
return len(self._buffer) - self._start_index
|
||||
|
||||
def advance(self, length: int) -> None:
|
||||
"""Advance the start position by length characters"""
|
||||
self._start_index += length
|
||||
|
||||
def peek(self, offset: int) -> str | None:
|
||||
"""Peek at character at offset, or None if not available"""
|
||||
idx = self._start_index + offset
|
||||
if idx < len(self._buffer):
|
||||
return self._buffer[idx]
|
||||
return None
|
||||
|
||||
def peek_char_code(self, offset: int) -> int:
|
||||
"""Get character code at offset"""
|
||||
return ord(self._buffer[self._start_index + offset])
|
||||
|
||||
def slice(self, start: int, end: int) -> str:
|
||||
"""Slice buffer from start to end (relative to current position)"""
|
||||
return self._buffer[self._start_index + start : self._start_index + end]
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Commit consumed content, removing it from buffer"""
|
||||
if self._start_index > 0:
|
||||
self._buffer = self._buffer[self._start_index :]
|
||||
self._start_index = 0
|
||||
|
||||
def remaining(self) -> str:
|
||||
"""Get all remaining content in buffer"""
|
||||
return self._buffer[self._start_index :]
|
||||
|
||||
def expect_end_of_content(self) -> None:
|
||||
"""Verify no non-whitespace content remains"""
|
||||
self.more_content_expected = False
|
||||
self.commit()
|
||||
self.skip_past_whitespace()
|
||||
if self.length != 0:
|
||||
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
|
||||
|
||||
def skip_past_whitespace(self) -> None:
|
||||
"""Skip whitespace characters"""
|
||||
i = self._start_index
|
||||
while i < len(self._buffer):
|
||||
c = ord(self._buffer[i])
|
||||
if c in (32, 9, 10, 13): # space, tab, \n, \r
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
self._start_index = i
|
||||
|
||||
def try_to_take_prefix(self, prefix: str) -> bool:
|
||||
"""Try to consume prefix from buffer, return True if successful"""
|
||||
if self._buffer.startswith(prefix, self._start_index):
|
||||
self._start_index += len(prefix)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_to_take(self, length: int) -> str | None:
|
||||
"""Try to take length characters, or None if not enough available"""
|
||||
if self.length < length:
|
||||
return None
|
||||
result = self._buffer[self._start_index : self._start_index + length]
|
||||
self._start_index += length
|
||||
return result
|
||||
|
||||
def try_to_take_char_code(self) -> int | None:
|
||||
"""Try to take a single character as char code, or None if buffer empty"""
|
||||
if self.length == 0:
|
||||
return None
|
||||
code = ord(self._buffer[self._start_index])
|
||||
self._start_index += 1
|
||||
return code
|
||||
|
||||
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
|
||||
"""
|
||||
Consume input up to first quote or backslash
|
||||
|
||||
Returns tuple of (consumed_content, pattern_found)
|
||||
"""
|
||||
buf = self._buffer
|
||||
i = self._start_index
|
||||
while i < len(buf):
|
||||
c = ord(buf[i])
|
||||
if c <= 0x1F:
|
||||
raise ValueError("Unescaped control character in string")
|
||||
if c == 34 or c == 92: # " or \
|
||||
result = buf[self._start_index : i]
|
||||
self._start_index = i
|
||||
return (result, True)
|
||||
i += 1
|
||||
|
||||
result = buf[self._start_index :]
|
||||
self._start_index = len(buf)
|
||||
return (result, False)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer for chunk-based JSON parsing
|
||||
|
||||
Processes chunks fed into its input buffer and calls handler methods
|
||||
as JSON tokens are recognized.
|
||||
"""
|
||||
|
||||
def __init__(self, input: _Input, handler: TokenHandler) -> None:
|
||||
self.input = input
|
||||
self._handler = handler
|
||||
self._stack: list[_State] = [_State.ExpectingValue]
|
||||
self._emitted_tokens = 0
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if tokenization is complete"""
|
||||
return len(self._stack) == 0 and self.input.length == 0
|
||||
|
||||
def pump(self) -> None:
|
||||
"""Process all available tokens in the buffer"""
|
||||
while True:
|
||||
before = self._emitted_tokens
|
||||
self._tokenize_more()
|
||||
if self._emitted_tokens == before:
|
||||
self.input.commit()
|
||||
return
|
||||
|
||||
def _tokenize_more(self) -> None:
|
||||
"""Process one step of tokenization based on current state"""
|
||||
if not self._stack:
|
||||
return
|
||||
|
||||
state = self._stack[-1]
|
||||
|
||||
if state == _State.ExpectingValue:
|
||||
self._tokenize_value()
|
||||
elif state == _State.InString:
|
||||
self._tokenize_string()
|
||||
elif state == _State.StartArray:
|
||||
self._tokenize_array_start()
|
||||
elif state == _State.AfterArrayValue:
|
||||
self._tokenize_after_array_value()
|
||||
elif state == _State.StartObject:
|
||||
self._tokenize_object_start()
|
||||
elif state == _State.AfterObjectKey:
|
||||
self._tokenize_after_object_key()
|
||||
elif state == _State.AfterObjectValue:
|
||||
self._tokenize_after_object_value()
|
||||
elif state == _State.BeforeObjectKey:
|
||||
self._tokenize_before_object_key()
|
||||
|
||||
def _tokenize_value(self) -> None:
|
||||
"""Tokenize a JSON value"""
|
||||
self.input.skip_past_whitespace()
|
||||
|
||||
if self.input.try_to_take_prefix("null"):
|
||||
self._handler.handle_null()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("true"):
|
||||
self._handler.handle_boolean(True)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("false"):
|
||||
self._handler.handle_boolean(False)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.length > 0:
|
||||
ch = self.input.peek_char_code(0)
|
||||
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
|
||||
# Scan for end of number
|
||||
i = 0
|
||||
while i < self.input.length:
|
||||
c = self.input.peek_char_code(i)
|
||||
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if i == self.input.length and not self.input.buffer_complete:
|
||||
# Need more input (numbers have no terminator)
|
||||
self.input.more_content_expected = False
|
||||
return
|
||||
|
||||
number_chars = self.input.slice(0, i)
|
||||
self.input.advance(i)
|
||||
number = _parse_json_number(number_chars)
|
||||
self._handler.handle_number(number)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
self.input.more_content_expected = True
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix('"'):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("["):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartArray)
|
||||
self._handler.handle_array_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_array_start()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("{"):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartObject)
|
||||
self._handler.handle_object_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_object_start()
|
||||
return
|
||||
|
||||
def _tokenize_string(self) -> None:
|
||||
"""Tokenize string content"""
|
||||
while True:
|
||||
chunk, interrupted = self.input.take_until_quote_or_backslash()
|
||||
if chunk:
|
||||
self._handler.handle_string_middle(chunk)
|
||||
self._emitted_tokens += 1
|
||||
elif not interrupted:
|
||||
return
|
||||
|
||||
if interrupted:
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
next_char = self.input.peek(0)
|
||||
if next_char == '"':
|
||||
self.input.advance(1)
|
||||
self._handler.handle_string_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
# Handle escape sequences
|
||||
next_char2 = self.input.peek(1)
|
||||
if next_char2 is None:
|
||||
return
|
||||
|
||||
value: str
|
||||
if next_char2 == "u":
|
||||
# Unicode escape: need 4 hex digits
|
||||
if self.input.length < 6:
|
||||
return
|
||||
|
||||
code = 0
|
||||
for j in range(2, 6):
|
||||
c = self.input.peek_char_code(j)
|
||||
if 48 <= c <= 57: # 0-9
|
||||
digit = c - 48
|
||||
elif 65 <= c <= 70: # A-F
|
||||
digit = c - 55
|
||||
elif 97 <= c <= 102: # a-f
|
||||
digit = c - 87
|
||||
else:
|
||||
raise ValueError("Bad Unicode escape in JSON")
|
||||
code = (code << 4) | digit
|
||||
|
||||
self.input.advance(6)
|
||||
self._handler.handle_string_middle(chr(code))
|
||||
self._emitted_tokens += 1
|
||||
continue
|
||||
|
||||
elif next_char2 == "n":
|
||||
value = "\n"
|
||||
elif next_char2 == "r":
|
||||
value = "\r"
|
||||
elif next_char2 == "t":
|
||||
value = "\t"
|
||||
elif next_char2 == "b":
|
||||
value = "\b"
|
||||
elif next_char2 == "f":
|
||||
value = "\f"
|
||||
elif next_char2 == "\\":
|
||||
value = "\\"
|
||||
elif next_char2 == "/":
|
||||
value = "/"
|
||||
elif next_char2 == '"':
|
||||
value = '"'
|
||||
else:
|
||||
raise ValueError("Bad escape in string")
|
||||
|
||||
self.input.advance(2)
|
||||
self._handler.handle_string_middle(value)
|
||||
self._emitted_tokens += 1
|
||||
|
||||
def _tokenize_array_start(self) -> None:
|
||||
"""Tokenize start of array (check for empty or first element)"""
|
||||
self.input.skip_past_whitespace()
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("]"):
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterArrayValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
|
||||
def _tokenize_after_array_value(self) -> None:
|
||||
"""Tokenize after an array value (expect , or ])"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x5D: # ]
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_object_start(self) -> None:
|
||||
"""Tokenize start of object (check for empty or first key)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_key(self) -> None:
|
||||
"""Tokenize after object key (expect :)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x3A: # :
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_value(self) -> None:
|
||||
"""Tokenize after object value (expect , or })"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.BeforeObjectKey)
|
||||
self._tokenize_before_object_key()
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected , or }} after object value, got {chr(next_char)!r}"
|
||||
)
|
||||
|
||||
def _tokenize_before_object_key(self) -> None:
|
||||
"""Tokenize before object key (after comma)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
@@ -1152,179 +1152,3 @@ class TestAutoModeTransitionsAndResync:
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_updates_default_when_recommended_default_changes(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and a sync arrives with a
|
||||
different recommended default model (both models still in config),
|
||||
the global default should be updated to the new recommendation.
|
||||
|
||||
Steps:
|
||||
1. Create auto-mode provider with config v1: default=gpt-4o.
|
||||
2. Set gpt-4o as the global CHAT default.
|
||||
3. Re-sync with config v2: default=gpt-4o-mini (gpt-4o still present).
|
||||
4. Verify the CHAT default switched to gpt-4o-mini and both models
|
||||
remain visible.
|
||||
"""
|
||||
config_v1 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
config_v2 = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o-mini",
|
||||
additional_models=["gpt-4o"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config_v1,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o as the global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
default_before = fetch_default_llm_model(db_session)
|
||||
assert default_before is not None
|
||||
assert default_before.name == "gpt-4o"
|
||||
|
||||
# Re-sync with config v2 (recommended default changed)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config_v2,
|
||||
)
|
||||
assert changes > 0, "Sync should report changes when default switches"
|
||||
|
||||
# Both models should remain visible
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
visibility = {
|
||||
mc.name: mc.is_visible for mc in provider.model_configurations
|
||||
}
|
||||
assert visibility["gpt-4o"] is True
|
||||
assert visibility["gpt-4o-mini"] is True
|
||||
|
||||
# The CHAT default should now be gpt-4o-mini
|
||||
default_after = fetch_default_llm_model(db_session)
|
||||
assert default_after is not None
|
||||
assert (
|
||||
default_after.name == "gpt-4o-mini"
|
||||
), f"Default should be updated to 'gpt-4o-mini', got '{default_after.name}'"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
def test_sync_idempotent_when_default_already_matches(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""When the provider owns the CHAT default and it already matches the
|
||||
recommended default, re-syncing should report zero changes.
|
||||
|
||||
This is a regression test for the bug where changes was unconditionally
|
||||
incremented even when the default was already correct.
|
||||
"""
|
||||
config = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
return_value=config,
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set gpt-4o (the recommended default) as global CHAT default
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# First sync to stabilize state
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
|
||||
# Second sync — default already matches, should be a no-op
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
changes = sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=config,
|
||||
)
|
||||
assert changes == 0, (
|
||||
f"Expected 0 changes when default already matches recommended, "
|
||||
f"got {changes}"
|
||||
)
|
||||
|
||||
# Default should still be gpt-4o
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
assert default_model is not None
|
||||
assert default_model.name == "gpt-4o"
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
"""
|
||||
This should act as the main point of reference for testing that default model
|
||||
logic is consisten.
|
||||
|
||||
-
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
|
||||
|
||||
def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
models: list[ModelConfigurationUpsertRequest] | None = None,
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider with multiple models."""
|
||||
if models is None:
|
||||
models = [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True, supports_image_input=False
|
||||
),
|
||||
]
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=models,
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_provider(db_session: Session, name: str) -> None:
|
||||
"""Helper to clean up a test provider by name."""
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if provider:
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_name(db_session: Session) -> Generator[str, None, None]:
|
||||
"""Generate a unique provider name for each test, with automatic cleanup."""
|
||||
name = f"test-provider-{uuid4().hex[:8]}"
|
||||
yield name
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, name)
|
||||
|
||||
|
||||
class TestDefaultModelProtection:
|
||||
"""Tests that the default model cannot be removed or hidden."""
|
||||
|
||||
def test_cannot_remove_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default text model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to update the provider without the default model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_hide_default_text_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Setting is_visible=False on the default text model should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to hide the default model
|
||||
with pytest.raises(ValueError, match="Cannot hide the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=False
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_cannot_remove_default_vision_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing the default vision model from a provider should raise ValueError."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
# Set gpt-4o as both the text and vision default
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
update_default_vision_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Try to remove the default vision model
|
||||
with pytest.raises(ValueError, match="Cannot remove the default model"):
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def test_can_remove_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Removing a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Remove gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_names = {mc.name for mc in updated.model_configurations}
|
||||
assert "gpt-4o" in model_names
|
||||
assert "gpt-4o-mini" not in model_names
|
||||
|
||||
def test_can_hide_non_default_model(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""Hiding a non-default model should succeed."""
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
update_default_provider(provider.id, "gpt-4o", db_session)
|
||||
|
||||
# Hide gpt-4o-mini (not default) — should succeed
|
||||
updated = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o", is_visible=True, supports_image_input=True
|
||||
),
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=False
|
||||
),
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
model_visibility = {
|
||||
mc.name: mc.is_visible for mc in updated.model_configurations
|
||||
}
|
||||
assert model_visibility["gpt-4o"] is True
|
||||
assert model_visibility["gpt-4o-mini"] is False
|
||||
@@ -950,6 +950,7 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
|
||||
@@ -1294,9 +1295,18 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type="python",
|
||||
argument_deltas={"code": code},
|
||||
),
|
||||
),
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
),
|
||||
forward=False,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
|
||||
630
backend/tests/unit/onyx/chat/test_argument_delta_streaming.py
Normal file
630
backend/tests/unit/onyx/chat/test_argument_delta_streaming.py
Normal file
@@ -0,0 +1,630 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _make_tool_call_delta(
|
||||
index: int = 0,
|
||||
name: str | None = None,
|
||||
arguments: str | None = None,
|
||||
function_is_none: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
|
||||
delta = MagicMock()
|
||||
delta.index = index
|
||||
if function_is_none:
|
||||
delta.function = None
|
||||
else:
|
||||
delta.function = MagicMock()
|
||||
delta.function.name = name
|
||||
delta.function.arguments = arguments
|
||||
return delta
|
||||
|
||||
|
||||
def _make_placement() -> Placement:
|
||||
return Placement(turn_index=0, tab_index=0)
|
||||
|
||||
|
||||
def _mock_tool_class(emit: bool = True) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
cls.should_emit_argument_deltas.return_value = emit
|
||||
return cls
|
||||
|
||||
|
||||
def _collect(
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
delta: MagicMock,
|
||||
placement: Placement | None = None,
|
||||
parsers: dict[int, Parser] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Run maybe_emit_argument_delta and return the yielded packets."""
|
||||
return list(
|
||||
maybe_emit_argument_delta(
|
||||
tc_map,
|
||||
delta,
|
||||
placement or _make_placement(),
|
||||
parsers if parsers is not None else {},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _stream_fragments(
|
||||
fragments: list[str],
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
placement: Placement | None = None,
|
||||
) -> list[str]:
|
||||
"""Feed fragments into maybe_emit_argument_delta one by one, returning
|
||||
all emitted content values concatenated per-key as a flat list."""
|
||||
pl = placement or _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
emitted: list[str] = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
delta = _make_tool_call_delta(arguments=frag)
|
||||
for packet in maybe_emit_argument_delta(tc_map, delta, pl, parsers=parsers):
|
||||
obj = packet.obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
for value in obj.argument_deltas.values():
|
||||
emitted.append(value)
|
||||
return emitted
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaGuards:
|
||||
"""Tests for conditions that cause no packet to be emitted."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_does_not_opt_in(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Tools that return False from should_emit_argument_deltas emit nothing."""
|
||||
mock_get_tool.return_value = _mock_tool_class(emit=False)
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_class_unknown(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = None
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_no_argument_fragment(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_key_value_not_started(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Key exists in JSON but its string value hasn't begun yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Only the opening brace has arrived — no key to stream yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": "{"}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaBasic:
|
||||
"""Tests for correct packet content and incremental emission."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "', "print(1)", '"}']
|
||||
|
||||
pl = _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
all_packets = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=frag), pl, parsers
|
||||
)
|
||||
all_packets.extend(packets)
|
||||
|
||||
assert len(all_packets) >= 1
|
||||
# Verify packet structure
|
||||
obj = all_packets[0].obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
assert obj.tool_type == "python"
|
||||
# All emitted content should reconstruct the value
|
||||
full_code = ""
|
||||
for p in all_packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
if "code" in p.obj.argument_deltas:
|
||||
full_code += p.obj.argument_deltas["code"]
|
||||
assert full_code == "print(1)"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_only_new_content_on_subsequent_call(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""After a first emission, subsequent calls emit only the diff."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# First fragment opens the string
|
||||
tc_map[0]["arguments"] = '{"code": "abc'
|
||||
packets_1 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments='{"code": "abc'), pl, parsers
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "abc"
|
||||
|
||||
# Second fragment appends more
|
||||
tc_map[0]["arguments"] = '{"code": "abcdef'
|
||||
packets_2 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments="def"), pl, parsers
|
||||
)
|
||||
code_2 = ""
|
||||
for p in packets_2:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_2 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_2 == "def"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""When a second key starts, emissions switch to that key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'", "output": "hello',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "x" in full
|
||||
assert "hello" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains the end of one value and the start of the next key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'y", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "xy" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An empty string value has nothing to emit."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# Opening quote just arrived, value is empty
|
||||
tc_map[0]["arguments"] = '{"code": "'
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments='{"code": "'))
|
||||
# No string content yet, so either no packet or empty deltas
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
assert p.obj.argument_deltas.get("code", "") == ""
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaDecoding:
|
||||
"""Tests verifying that JSON escape sequences are properly decoded."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "line1\\nline2"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "line1\nline2"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\tindented"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "\tindented"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "say \\"hi\\""}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == 'say "hi"'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "path\\\\dir"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "path\\dir"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\u0041"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "A"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_escape_at_end_decoded_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A trailing backslash (incomplete escape) is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\', 'n"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "hello\n"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_unicode_escape_completed_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A partial \\uXX sequence is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\u00', '41"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "helloA"
|
||||
|
||||
|
||||
class TestArgumentDeltaStreamingE2E:
|
||||
"""Simulates realistic sequences of LLM argument deltas to verify
|
||||
the full pipeline produces correct decoded output."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"',
|
||||
"code",
|
||||
'": "',
|
||||
"print(",
|
||||
"'hello')",
|
||||
"\\n",
|
||||
"print(",
|
||||
"'world')",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "print('hello')\nprint('world')"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code with tabs and newlines."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"if True:",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"pass",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "if True:\n\tpass"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An escape sequence split across two fragments (backslash in one,
|
||||
'n' in the next) should still decode correctly."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "hello',
|
||||
"\\",
|
||||
"n",
|
||||
'world"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "hello\nworld"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams a multi-line function with multiple escape sequences."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"def foo():",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"x = 1",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"return x",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "def foo():\n\tx = 1\n\treturn x"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code first, then a second key (language) — both decoded."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = 1",
|
||||
'", "language": "',
|
||||
"python",
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
# Should have emissions for both keys
|
||||
full = "".join(emitted)
|
||||
assert "x = 1" in full
|
||||
assert "python" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
|
||||
The escaped quotes inside the *outer* JSON value should prevent the
|
||||
inner `"key":` from being mistaken for a top-level JSON key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
|
||||
# The inner quotes are escaped as \" in the JSON value.
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = {",
|
||||
'\\"key\\"',
|
||||
": ",
|
||||
'\\"val\\"',
|
||||
"}",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'x = {"key": "val"}'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Colons inside the string value should not confuse key detection."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"url = ",
|
||||
'\\"https://example.com\\"',
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'url = "https://example.com"'
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaEdgeCases:
|
||||
"""Edge cases not covered by the standard test classes."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Some delta chunks have function=None (e.g. role-only deltas)."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
|
||||
assert _collect(tc_map, delta) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Two tool calls streaming at different indices in parallel."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""},
|
||||
1: {"id": "tc_2", "name": "python", "arguments": ""},
|
||||
}
|
||||
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# Feed full JSON to index 0
|
||||
tc_map[0]["arguments"] = '{"code": "aaa"}'
|
||||
packets_0 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=0, arguments='{"code": "aaa"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_0 = ""
|
||||
for p in packets_0:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_0 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_0 == "aaa"
|
||||
|
||||
# Feed full JSON to index 1
|
||||
tc_map[1]["arguments"] = '{"code": "bbb"}'
|
||||
packets_1 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=1, arguments='{"code": "bbb"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "bbb"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains four complete key-value pairs."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
full = '{"a": "one", "b": "two", "c": "three", "d": "four"}'
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
tc_map[0]["arguments"] = full
|
||||
parsers: dict[int, Parser] = {}
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=full), parsers=parsers
|
||||
)
|
||||
|
||||
# Collect all argument deltas across packets
|
||||
all_deltas: dict[str, str] = {}
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
for k, v in p.obj.argument_deltas.items():
|
||||
all_deltas[k] = all_deltas.get(k, "") + v
|
||||
|
||||
assert all_deltas == {
|
||||
"a": "one",
|
||||
"b": "two",
|
||||
"c": "three",
|
||||
"d": "four",
|
||||
}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_on_second_arg_after_first_complete(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""First argument is fully complete; delta only adds to the second."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
|
||||
fragments = [
|
||||
'{"code": "print(1)", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "print(1)" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Non-string values (numbers, booleans, null) are skipped — they are
|
||||
available in the final tool-call kickoff packet. String arguments
|
||||
following them are still emitted."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"timeout": 30, "code": "hello"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert full == "hello"
|
||||
394
backend/tests/unit/onyx/utils/test_json_river.py
Normal file
394
backend/tests/unit/onyx/utils/test_json_river.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""Tests for the jsonriver incremental JSON parser."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.jsonriver import JsonValue
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _all_deltas(chunks: list[str]) -> list[JsonValue]:
|
||||
"""Feed chunks one at a time and collect all emitted deltas."""
|
||||
parser = Parser()
|
||||
deltas: list[JsonValue] = []
|
||||
for chunk in chunks:
|
||||
deltas.extend(parser.feed(chunk))
|
||||
deltas.extend(parser.finish())
|
||||
return deltas
|
||||
|
||||
|
||||
class TestParseComplete:
|
||||
"""Parsing complete JSON in a single chunk."""
|
||||
|
||||
def test_simple_object(self) -> None:
|
||||
deltas = _all_deltas(['{"a": 1}'])
|
||||
assert any(r == {"a": 1.0} or r == {"a": 1} for r in deltas)
|
||||
|
||||
def test_simple_array(self) -> None:
|
||||
deltas = _all_deltas(["[1, 2, 3]"])
|
||||
assert any(isinstance(r, list) for r in deltas)
|
||||
|
||||
def test_simple_string(self) -> None:
|
||||
deltas = _all_deltas(['"hello"'])
|
||||
assert "hello" in deltas or any("hello" in str(r) for r in deltas)
|
||||
|
||||
def test_null(self) -> None:
|
||||
deltas = _all_deltas(["null"])
|
||||
assert None in deltas
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
deltas = _all_deltas(["true"])
|
||||
assert True in deltas
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
deltas = _all_deltas(["false"])
|
||||
assert any(r is False for r in deltas)
|
||||
|
||||
def test_number(self) -> None:
|
||||
deltas = _all_deltas(["42"])
|
||||
assert 42.0 in deltas
|
||||
|
||||
def test_negative_number(self) -> None:
|
||||
deltas = _all_deltas(["-3.14"])
|
||||
assert any(abs(r - (-3.14)) < 1e-10 for r in deltas if isinstance(r, float))
|
||||
|
||||
def test_empty_object(self) -> None:
|
||||
deltas = _all_deltas(["{}"])
|
||||
assert {} in deltas
|
||||
|
||||
def test_empty_array(self) -> None:
|
||||
deltas = _all_deltas(["[]"])
|
||||
assert [] in deltas
|
||||
|
||||
|
||||
class TestStreamingDeltas:
|
||||
"""Incremental feeding produces correct deltas."""
|
||||
|
||||
def test_object_string_value_streamed_char_by_char(self) -> None:
|
||||
chunks = list('{"code": "abc"}')
|
||||
deltas = _all_deltas(chunks)
|
||||
str_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "code" in d:
|
||||
val = d["code"]
|
||||
if isinstance(val, str):
|
||||
str_parts.append(val)
|
||||
assert "".join(str_parts) == "abc"
|
||||
|
||||
def test_object_streamed_in_two_halves(self) -> None:
|
||||
deltas = _all_deltas(['{"name": "Al', 'ice"}'])
|
||||
str_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "name" in d:
|
||||
val = d["name"]
|
||||
if isinstance(val, str):
|
||||
str_parts.append(val)
|
||||
assert "".join(str_parts) == "Alice"
|
||||
|
||||
def test_multiple_keys_streamed(self) -> None:
|
||||
deltas = _all_deltas(['{"a": "x', '", "b": "y"}'])
|
||||
a_parts: list[str] = []
|
||||
b_parts: list[str] = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict):
|
||||
if "a" in d and isinstance(d["a"], str):
|
||||
a_parts.append(d["a"])
|
||||
if "b" in d and isinstance(d["b"], str):
|
||||
b_parts.append(d["b"])
|
||||
assert "".join(a_parts) == "x"
|
||||
assert "".join(b_parts) == "y"
|
||||
|
||||
def test_deltas_only_contain_new_string_content(self) -> None:
|
||||
parser = Parser()
|
||||
d1 = parser.feed('{"msg": "hel')
|
||||
d2 = parser.feed('lo"}')
|
||||
parser.finish()
|
||||
|
||||
msg_parts = []
|
||||
for d in d1 + d2:
|
||||
if isinstance(d, dict) and "msg" in d:
|
||||
val = d["msg"]
|
||||
if isinstance(val, str):
|
||||
msg_parts.append(val)
|
||||
assert "".join(msg_parts) == "hello"
|
||||
|
||||
# Each delta should only contain new chars, not repeat previous ones
|
||||
if len(msg_parts) == 2:
|
||||
assert msg_parts[0] == "hel"
|
||||
assert msg_parts[1] == "lo"
|
||||
|
||||
|
||||
class TestEscapeSequences:
|
||||
"""JSON escape sequences are decoded correctly, even across chunk boundaries."""
|
||||
|
||||
def test_newline_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"text": "line1\\nline2"}'])
|
||||
text_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "text" in d and isinstance(d["text"], str):
|
||||
text_parts.append(d["text"])
|
||||
assert "".join(text_parts) == "line1\nline2"
|
||||
|
||||
def test_tab_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"t": "a\\tb"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "t" in d and isinstance(d["t"], str):
|
||||
parts.append(d["t"])
|
||||
assert "".join(parts) == "a\tb"
|
||||
|
||||
def test_escaped_quote(self) -> None:
|
||||
deltas = _all_deltas(['{"q": "say \\"hi\\""}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "q" in d and isinstance(d["q"], str):
|
||||
parts.append(d["q"])
|
||||
assert "".join(parts) == 'say "hi"'
|
||||
|
||||
def test_unicode_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"u": "\\u0041\\u0042"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "u" in d and isinstance(d["u"], str):
|
||||
parts.append(d["u"])
|
||||
assert "".join(parts) == "AB"
|
||||
|
||||
def test_escape_split_across_chunks(self) -> None:
|
||||
deltas = _all_deltas(['{"x": "a\\', 'nb"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "x" in d and isinstance(d["x"], str):
|
||||
parts.append(d["x"])
|
||||
assert "".join(parts) == "a\nb"
|
||||
|
||||
def test_unicode_escape_split_across_chunks(self) -> None:
|
||||
deltas = _all_deltas(['{"u": "\\u00', '41"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "u" in d and isinstance(d["u"], str):
|
||||
parts.append(d["u"])
|
||||
assert "".join(parts) == "A"
|
||||
|
||||
def test_backslash_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"p": "c:\\\\dir"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "p" in d and isinstance(d["p"], str):
|
||||
parts.append(d["p"])
|
||||
assert "".join(parts) == "c:\\dir"
|
||||
|
||||
|
||||
class TestNestedStructures:
|
||||
"""Nested objects and arrays produce correct deltas."""
|
||||
|
||||
def test_nested_object(self) -> None:
|
||||
deltas = _all_deltas(['{"outer": {"inner": "val"}}'])
|
||||
found = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "outer" in d:
|
||||
outer = d["outer"]
|
||||
if isinstance(outer, dict) and "inner" in outer:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
def test_array_of_strings(self) -> None:
|
||||
deltas = _all_deltas(['["a', '", "b"]'])
|
||||
all_items: list[str] = []
|
||||
for d in deltas:
|
||||
if isinstance(d, list):
|
||||
for item in d:
|
||||
if isinstance(item, str):
|
||||
all_items.append(item)
|
||||
elif isinstance(d, str):
|
||||
all_items.append(d)
|
||||
joined = "".join(all_items)
|
||||
assert "a" in joined
|
||||
assert "b" in joined
|
||||
|
||||
def test_object_with_number_and_bool(self) -> None:
|
||||
deltas = _all_deltas(['{"count": 42, "active": true}'])
|
||||
has_count = False
|
||||
has_active = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict):
|
||||
if "count" in d and d["count"] == 42.0:
|
||||
has_count = True
|
||||
if "active" in d and d["active"] is True:
|
||||
has_active = True
|
||||
assert has_count
|
||||
assert has_active
|
||||
|
||||
def test_object_with_null_value(self) -> None:
|
||||
deltas = _all_deltas(['{"key": null}'])
|
||||
found = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "key" in d and d["key"] is None:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
|
||||
class TestComputeDelta:
|
||||
"""Direct tests for the _compute_delta static method."""
|
||||
|
||||
def test_none_prev_returns_current(self) -> None:
|
||||
assert Parser._compute_delta(None, {"a": "b"}) == {"a": "b"}
|
||||
|
||||
def test_string_delta(self) -> None:
|
||||
assert Parser._compute_delta("hel", "hello") == "lo"
|
||||
|
||||
def test_string_no_change(self) -> None:
|
||||
assert Parser._compute_delta("same", "same") is None
|
||||
|
||||
def test_dict_new_key(self) -> None:
|
||||
assert Parser._compute_delta({"a": "x"}, {"a": "x", "b": "y"}) == {"b": "y"}
|
||||
|
||||
def test_dict_string_append(self) -> None:
|
||||
assert Parser._compute_delta({"code": "def"}, {"code": "def hello()"}) == {
|
||||
"code": " hello()"
|
||||
}
|
||||
|
||||
def test_dict_no_change(self) -> None:
|
||||
assert Parser._compute_delta({"a": 1}, {"a": 1}) is None
|
||||
|
||||
def test_list_new_items(self) -> None:
|
||||
assert Parser._compute_delta([1, 2], [1, 2, 3]) == [3]
|
||||
|
||||
def test_list_last_item_updated(self) -> None:
|
||||
assert Parser._compute_delta(["a"], ["ab"]) == ["ab"]
|
||||
|
||||
def test_list_no_change(self) -> None:
|
||||
assert Parser._compute_delta([1, 2], [1, 2]) is None
|
||||
|
||||
def test_primitive_change(self) -> None:
|
||||
assert Parser._compute_delta(1, 2) == 2
|
||||
|
||||
def test_primitive_no_change(self) -> None:
|
||||
assert Parser._compute_delta(42, 42) is None
|
||||
|
||||
|
||||
class TestParserLifecycle:
|
||||
"""Edge cases around parser state and lifecycle."""
|
||||
|
||||
def test_feed_after_finish_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
parser.feed('{"a": 1}')
|
||||
parser.finish()
|
||||
assert parser.feed("more") == []
|
||||
|
||||
def test_empty_feed_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
assert parser.feed("") == []
|
||||
|
||||
def test_whitespace_only_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
assert parser.feed(" ") == []
|
||||
|
||||
def test_finish_with_trailing_whitespace(self) -> None:
|
||||
parser = Parser()
|
||||
# Trailing whitespace terminates the number, so feed() emits it
|
||||
deltas = parser.feed("42 ")
|
||||
assert 42.0 in deltas
|
||||
parser.finish() # Should not raise
|
||||
|
||||
def test_finish_with_trailing_content_raises(self) -> None:
|
||||
parser = Parser()
|
||||
# Feed a complete JSON value followed by non-whitespace in one chunk
|
||||
parser.feed('{"a": 1} extra')
|
||||
with pytest.raises(ValueError, match="Unexpected trailing"):
|
||||
parser.finish()
|
||||
|
||||
def test_finish_flushes_pending_number(self) -> None:
|
||||
parser = Parser()
|
||||
deltas = parser.feed("42")
|
||||
# Number has no terminator, so feed() can't emit it yet
|
||||
assert deltas == []
|
||||
final = parser.finish()
|
||||
assert 42.0 in final
|
||||
|
||||
|
||||
class TestToolCallSimulation:
|
||||
"""Simulate the LLM tool-call streaming use case."""
|
||||
|
||||
def test_python_tool_call_streaming(self) -> None:
|
||||
full_json = json.dumps({"code": "print('hello world')"})
|
||||
chunk_size = 5
|
||||
chunks = [
|
||||
full_json[i : i + chunk_size] for i in range(0, len(full_json), chunk_size)
|
||||
]
|
||||
|
||||
parser = Parser()
|
||||
code_parts: list[str] = []
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
assert "".join(code_parts) == "print('hello world')"
|
||||
|
||||
def test_multi_arg_tool_call(self) -> None:
|
||||
full = '{"query": "search term", "num_results": 5}'
|
||||
chunks = [full[:15], full[15:30], full[30:]]
|
||||
|
||||
parser = Parser()
|
||||
query_parts: list[str] = []
|
||||
has_num_results = False
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict):
|
||||
if "query" in delta and isinstance(delta["query"], str):
|
||||
query_parts.append(delta["query"])
|
||||
if "num_results" in delta:
|
||||
has_num_results = True
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict):
|
||||
if "query" in delta and isinstance(delta["query"], str):
|
||||
query_parts.append(delta["query"])
|
||||
if "num_results" in delta:
|
||||
has_num_results = True
|
||||
assert "".join(query_parts) == "search term"
|
||||
assert has_num_results
|
||||
|
||||
def test_code_with_newlines_and_escapes(self) -> None:
|
||||
code = 'def greet(name):\n print(f"Hello, {name}!")\n return True'
|
||||
full = json.dumps({"code": code})
|
||||
chunk_size = 8
|
||||
chunks = [full[i : i + chunk_size] for i in range(0, len(full), chunk_size)]
|
||||
|
||||
parser = Parser()
|
||||
code_parts: list[str] = []
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
assert "".join(code_parts) == code
|
||||
|
||||
def test_single_char_streaming(self) -> None:
|
||||
full = '{"key": "value"}'
|
||||
parser = Parser()
|
||||
key_parts: list[str] = []
|
||||
for ch in full:
|
||||
for delta in parser.feed(ch):
|
||||
if isinstance(delta, dict) and "key" in delta:
|
||||
val = delta["key"]
|
||||
if isinstance(val, str):
|
||||
key_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "key" in delta:
|
||||
val = delta["key"]
|
||||
if isinstance(val, str):
|
||||
key_parts.append(val)
|
||||
assert "".join(key_parts) == "value"
|
||||
@@ -20,8 +20,6 @@ from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# Import the function under test
|
||||
|
||||
|
||||
class TestBuildVespaFilters:
|
||||
def test_empty_filters(self) -> None:
|
||||
@@ -179,11 +177,27 @@ class TestBuildVespaFilters:
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
def test_user_project_filter(self) -> None:
|
||||
"""Test user project filtering (replacement for user folder IDs)."""
|
||||
# Single project id
|
||||
"""Test user project filtering.
|
||||
|
||||
project_id alone does NOT trigger a knowledge scope restriction
|
||||
(an agent with no explicit knowledge should search everything).
|
||||
It only participates when explicit knowledge filters are present.
|
||||
"""
|
||||
# project_id alone → no restriction
|
||||
filters = IndexFilters(access_control_list=[], project_id=789)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({USER_PROJECT} contains "789") and ' == result
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
# project_id with user_file_ids → both OR'd
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[], project_id=789, user_file_ids=[id1]
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
|
||||
== result
|
||||
)
|
||||
|
||||
# No project id
|
||||
filters = IndexFilters(access_control_list=[], project_id=None)
|
||||
@@ -217,7 +231,11 @@ class TestBuildVespaFilters:
|
||||
)
|
||||
|
||||
def test_combined_filters(self) -> None:
|
||||
"""Test combining multiple filter types."""
|
||||
"""Test combining multiple filter types.
|
||||
|
||||
Knowledge-scope filters (document_set, user_file_ids, project_id,
|
||||
persona_id) are OR'd together, while all other filters are AND'd.
|
||||
"""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=["user1", "group1"],
|
||||
@@ -231,7 +249,6 @@ class TestBuildVespaFilters:
|
||||
|
||||
result = build_vespa_filters(filters)
|
||||
|
||||
# Build expected result piece by piece for readability
|
||||
expected = f"!({HIDDEN}=true) and "
|
||||
expected += (
|
||||
'(access_control_list contains "user1" or '
|
||||
@@ -239,9 +256,13 @@ class TestBuildVespaFilters:
|
||||
)
|
||||
expected += f'({SOURCE_TYPE} contains "web") and '
|
||||
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
|
||||
expected += f'({DOCUMENT_SETS} contains "set1") and '
|
||||
expected += f'({DOCUMENT_ID} contains "{str(id1)}") and '
|
||||
expected += f'({USER_PROJECT} contains "789") and '
|
||||
# Knowledge scope filters are OR'd together
|
||||
expected += (
|
||||
f'(({DOCUMENT_SETS} contains "set1")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f' or ({USER_PROJECT} contains "789")'
|
||||
f") and "
|
||||
)
|
||||
cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
|
||||
expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
|
||||
@@ -251,6 +272,32 @@ class TestBuildVespaFilters:
|
||||
result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
assert expected[:-5] == result_no_trailing # Remove trailing " and "
|
||||
|
||||
def test_knowledge_scope_single_filter_not_wrapped(self) -> None:
|
||||
"""When only one knowledge-scope filter is present it should not
|
||||
be wrapped in an extra OR group."""
|
||||
filters = IndexFilters(access_control_list=[], document_set=["set1"])
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
|
||||
|
||||
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
|
||||
"""Document set filter and user file IDs must be OR'd so that
|
||||
connector documents (in the set) and user files (with specific
|
||||
IDs) can both be found."""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[],
|
||||
document_set=["engineering"],
|
||||
user_file_ids=[id1],
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
expected = (
|
||||
f"!({HIDDEN}=true) and "
|
||||
f'(({DOCUMENT_SETS} contains "engineering")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f") and "
|
||||
)
|
||||
assert expected == result
|
||||
|
||||
def test_empty_or_none_values(self) -> None:
|
||||
"""Test with empty or None values in filter lists."""
|
||||
# Empty strings in document set
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# Onyx CLI
|
||||
|
||||
[](https://github.com/onyx-dot-app/onyx/actions/workflows/release-cli.yml)
|
||||
[](https://pypi.org/project/onyx-cli/)
|
||||
|
||||
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
|
||||
|
||||
## Installation
|
||||
@@ -28,7 +31,7 @@ Environment variables override config file values:
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Server base URL (default: `http://localhost:3000`) |
|
||||
| `ONYX_SERVER_URL` | No | Server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
@@ -68,17 +71,17 @@ onyx-cli agents --json
|
||||
| `ask` | Ask a one-shot question (non-interactive) |
|
||||
| `agents` | List available agents |
|
||||
| `configure` | Configure server URL and API key |
|
||||
| `validate-config` | Validate configuration and test connection |
|
||||
|
||||
## Slash Commands (in TUI)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show help message |
|
||||
| `/new` | Start a new chat session |
|
||||
| `/clear` | Clear chat and start a new session |
|
||||
| `/agent` | List and switch agents |
|
||||
| `/attach <path>` | Attach a file to next message |
|
||||
| `/sessions` | List recent chat sessions |
|
||||
| `/clear` | Clear the chat display |
|
||||
| `/configure` | Re-run connection setup |
|
||||
| `/connectors` | Open connectors in browser |
|
||||
| `/settings` | Open settings in browser |
|
||||
@@ -116,3 +119,43 @@ go build -o onyx-cli .
|
||||
# Lint
|
||||
staticcheck ./...
|
||||
```
|
||||
|
||||
## Publishing to PyPI
|
||||
|
||||
The CLI is distributed as a Python package via [PyPI](https://pypi.org/project/onyx-cli/). The build system uses [hatchling](https://hatch.pypa.io/) with [manygo](https://github.com/nicholasgasior/manygo) to cross-compile Go binaries into platform-specific wheels.
|
||||
|
||||
### CI release (recommended)
|
||||
|
||||
Tag a release and push — the `release-cli.yml` workflow builds wheels for all platforms and publishes to PyPI automatically:
|
||||
|
||||
```shell
|
||||
tag --prefix cli
|
||||
```
|
||||
|
||||
To do this manually:
|
||||
|
||||
```shell
|
||||
git tag cli/v0.1.0
|
||||
git push origin cli/v0.1.0
|
||||
```
|
||||
|
||||
The workflow builds wheels for: linux/amd64, linux/arm64, darwin/amd64, darwin/arm64, windows/amd64, windows/arm64.
|
||||
|
||||
### Manual release
|
||||
|
||||
Build a wheel locally with `uv`. Set `GOOS` and `GOARCH` to cross-compile for other platforms (Go handles this natively — no cross-compiler needed):
|
||||
|
||||
```shell
|
||||
# Build for current platform
|
||||
uv build --wheel
|
||||
|
||||
# Cross-compile for a different platform
|
||||
GOOS=linux GOARCH=amd64 uv build --wheel
|
||||
|
||||
# Upload to PyPI
|
||||
uv publish
|
||||
```
|
||||
|
||||
### Versioning
|
||||
|
||||
Versions are derived from git tags with the `cli/` prefix (e.g. `cli/v0.1.0`). The tag is parsed by `internal/_version.py` and injected into the Go binary via `-ldflags` at build time.
|
||||
|
||||
43
cli/hatch_build.py
Normal file
43
cli/hatch_build.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
import manygo
|
||||
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
|
||||
|
||||
|
||||
class CustomBuildHook(BuildHookInterface):
|
||||
"""Build hook to compile the Go binary and include it in the wheel."""
|
||||
|
||||
def initialize(self, version: Any, build_data: Any) -> None: # noqa: ARG002
|
||||
"""Build the Go binary before packaging."""
|
||||
build_data["pure_python"] = False
|
||||
|
||||
# Set platform tag for cross-compilation
|
||||
goos = os.getenv("GOOS")
|
||||
goarch = os.getenv("GOARCH")
|
||||
if manygo.is_goos(goos) and manygo.is_goarch(goarch):
|
||||
build_data["tag"] = "py3-none-" + manygo.get_platform_tag(
|
||||
goos=goos,
|
||||
goarch=goarch,
|
||||
)
|
||||
|
||||
# Get config and environment
|
||||
binary_name = self.config["binary_name"]
|
||||
tag_prefix = self.config.get("tag_prefix", binary_name)
|
||||
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{tag_prefix}/")
|
||||
commit = os.getenv("GITHUB_SHA", "none")
|
||||
|
||||
# Build the Go binary if it doesn't exist
|
||||
# Build the Go binary (always rebuild to ensure correct version injection)
|
||||
if not os.path.exists(binary_name):
|
||||
print(f"Building Go binary '{binary_name}'...")
|
||||
pkg = "github.com/onyx-dot-app/onyx/cli/cmd"
|
||||
ldflags = f"-X {pkg}.version={tag}" f" -X {pkg}.commit={commit}" " -s -w"
|
||||
subprocess.check_call( # noqa: S603
|
||||
["go", "build", f"-ldflags={ldflags}", "-o", binary_name],
|
||||
)
|
||||
|
||||
build_data["shared_scripts"] = {binary_name: binary_name}
|
||||
11
cli/internal/_version.py
Normal file
11
cli/internal/_version.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
# Must match tag_prefix in pyproject.toml [tool.hatch.build.targets.wheel.hooks.custom]
|
||||
TAG_PREFIX = "cli"
|
||||
|
||||
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix(f"{TAG_PREFIX}/")
|
||||
_match = re.search(r"v?(\d+\.\d+\.\d+)", _tag)
|
||||
__version__ = _match.group(1) if _match else "0.0.0"
|
||||
39
cli/pyproject.toml
Normal file
39
cli/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[build-system]
|
||||
requires = ["hatchling", "go-bin~=1.24.11", "manygo"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "onyx-cli"
|
||||
readme = "README.md"
|
||||
description = "Terminal interface for chatting with your Onyx agent"
|
||||
authors = [{ name = "Onyx AI", email = "founders@onyx.app" }]
|
||||
requires-python = ">=3.9"
|
||||
keywords = [
|
||||
"onyx", "cli", "chat", "ai", "enterprise-search",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Go",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
[tool.hatch.build]
|
||||
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py", "README.md"]
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "code"
|
||||
path = "internal/_version.py"
|
||||
|
||||
[tool.hatch.build.targets.wheel.hooks.custom]
|
||||
path = "hatch_build.py"
|
||||
binary_name = "onyx-cli"
|
||||
tag_prefix = "cli"
|
||||
|
||||
[tool.uv]
|
||||
managed = false
|
||||
104
desktop/package-lock.json
generated
104
desktop/package-lock.json
generated
@@ -8,16 +8,16 @@
|
||||
"name": "onyx-desktop",
|
||||
"version": "0.0.0-dev",
|
||||
"dependencies": {
|
||||
"@tauri-apps/api": "^2.0.0"
|
||||
"@tauri-apps/api": "^2.10.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tauri-apps/cli": "^2.0.0"
|
||||
"@tauri-apps/cli": "^2.10.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/api": {
|
||||
"version": "2.9.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/api/-/api-2.9.1.tgz",
|
||||
"integrity": "sha512-IGlhP6EivjXHepbBic618GOmiWe4URJiIeZFlB7x3czM0yDHHYviH1Xvoiv4FefdkQtn6v7TuwWCRfOGdnVUGw==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/api/-/api-2.10.1.tgz",
|
||||
"integrity": "sha512-hKL/jWf293UDSUN09rR69hrToyIXBb8CjGaWC7gfinvnQrBVvnLr08FeFi38gxtugAVyVcTa5/FD/Xnkb1siBw==",
|
||||
"license": "Apache-2.0 OR MIT",
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
@@ -25,9 +25,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli/-/cli-2.9.6.tgz",
|
||||
"integrity": "sha512-3xDdXL5omQ3sPfBfdC8fCtDKcnyV7OqyzQgfyT5P3+zY6lcPqIYKQBvUasNvppi21RSdfhy44ttvJmftb0PCDw==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli/-/cli-2.10.1.tgz",
|
||||
"integrity": "sha512-jQNGF/5quwORdZSSLtTluyKQ+o6SMa/AUICfhf4egCGFdMHqWssApVgYSbg+jmrZoc8e1DscNvjTnXtlHLS11g==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0 OR MIT",
|
||||
"bin": {
|
||||
@@ -41,23 +41,23 @@
|
||||
"url": "https://opencollective.com/tauri"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@tauri-apps/cli-darwin-arm64": "2.9.6",
|
||||
"@tauri-apps/cli-darwin-x64": "2.9.6",
|
||||
"@tauri-apps/cli-linux-arm-gnueabihf": "2.9.6",
|
||||
"@tauri-apps/cli-linux-arm64-gnu": "2.9.6",
|
||||
"@tauri-apps/cli-linux-arm64-musl": "2.9.6",
|
||||
"@tauri-apps/cli-linux-riscv64-gnu": "2.9.6",
|
||||
"@tauri-apps/cli-linux-x64-gnu": "2.9.6",
|
||||
"@tauri-apps/cli-linux-x64-musl": "2.9.6",
|
||||
"@tauri-apps/cli-win32-arm64-msvc": "2.9.6",
|
||||
"@tauri-apps/cli-win32-ia32-msvc": "2.9.6",
|
||||
"@tauri-apps/cli-win32-x64-msvc": "2.9.6"
|
||||
"@tauri-apps/cli-darwin-arm64": "2.10.1",
|
||||
"@tauri-apps/cli-darwin-x64": "2.10.1",
|
||||
"@tauri-apps/cli-linux-arm-gnueabihf": "2.10.1",
|
||||
"@tauri-apps/cli-linux-arm64-gnu": "2.10.1",
|
||||
"@tauri-apps/cli-linux-arm64-musl": "2.10.1",
|
||||
"@tauri-apps/cli-linux-riscv64-gnu": "2.10.1",
|
||||
"@tauri-apps/cli-linux-x64-gnu": "2.10.1",
|
||||
"@tauri-apps/cli-linux-x64-musl": "2.10.1",
|
||||
"@tauri-apps/cli-win32-arm64-msvc": "2.10.1",
|
||||
"@tauri-apps/cli-win32-ia32-msvc": "2.10.1",
|
||||
"@tauri-apps/cli-win32-x64-msvc": "2.10.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-darwin-arm64": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-darwin-arm64/-/cli-darwin-arm64-2.9.6.tgz",
|
||||
"integrity": "sha512-gf5no6N9FCk1qMrti4lfwP77JHP5haASZgVbBgpZG7BUepB3fhiLCXGUK8LvuOjP36HivXewjg72LTnPDScnQQ==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-darwin-arm64/-/cli-darwin-arm64-2.10.1.tgz",
|
||||
"integrity": "sha512-Z2OjCXiZ+fbYZy7PmP3WRnOpM9+Fy+oonKDEmUE6MwN4IGaYqgceTjwHucc/kEEYZos5GICve35f7ZiizgqEnQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -72,9 +72,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-darwin-x64": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-darwin-x64/-/cli-darwin-x64-2.9.6.tgz",
|
||||
"integrity": "sha512-oWh74WmqbERwwrwcueJyY6HYhgCksUc6NT7WKeXyrlY/FPmNgdyQAgcLuTSkhRFuQ6zh4Np1HZpOqCTpeZBDcw==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-darwin-x64/-/cli-darwin-x64-2.10.1.tgz",
|
||||
"integrity": "sha512-V/irQVvjPMGOTQqNj55PnQPVuH4VJP8vZCN7ajnj+ZS8Kom1tEM2hR3qbbIRoS3dBKs5mbG8yg1WC+97dq17Pw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -89,9 +89,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-linux-arm-gnueabihf": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-arm-gnueabihf/-/cli-linux-arm-gnueabihf-2.9.6.tgz",
|
||||
"integrity": "sha512-/zde3bFroFsNXOHN204DC2qUxAcAanUjVXXSdEGmhwMUZeAQalNj5cz2Qli2elsRjKN/hVbZOJj0gQ5zaYUjSg==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-arm-gnueabihf/-/cli-linux-arm-gnueabihf-2.10.1.tgz",
|
||||
"integrity": "sha512-Hyzwsb4VnCWKGfTw+wSt15Z2pLw2f0JdFBfq2vHBOBhvg7oi6uhKiF87hmbXOBXUZaGkyRDkCHsdzJcIfoJC2w==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
@@ -106,9 +106,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-linux-arm64-gnu": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-arm64-gnu/-/cli-linux-arm64-gnu-2.9.6.tgz",
|
||||
"integrity": "sha512-pvbljdhp9VOo4RnID5ywSxgBs7qiylTPlK56cTk7InR3kYSTJKYMqv/4Q/4rGo/mG8cVppesKIeBMH42fw6wjg==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-arm64-gnu/-/cli-linux-arm64-gnu-2.10.1.tgz",
|
||||
"integrity": "sha512-OyOYs2t5GkBIvyWjA1+h4CZxTcdz1OZPCWAPz5DYEfB0cnWHERTnQ/SLayQzncrT0kwRoSfSz9KxenkyJoTelA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -123,9 +123,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-linux-arm64-musl": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-arm64-musl/-/cli-linux-arm64-musl-2.9.6.tgz",
|
||||
"integrity": "sha512-02TKUndpodXBCR0oP//6dZWGYcc22Upf2eP27NvC6z0DIqvkBBFziQUcvi2n6SrwTRL0yGgQjkm9K5NIn8s6jw==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-arm64-musl/-/cli-linux-arm64-musl-2.10.1.tgz",
|
||||
"integrity": "sha512-MIj78PDDGjkg3NqGptDOGgfXks7SYJwhiMh8SBoZS+vfdz7yP5jN18bNaLnDhsVIPARcAhE1TlsZe/8Yxo2zqg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -140,9 +140,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-linux-riscv64-gnu": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-riscv64-gnu/-/cli-linux-riscv64-gnu-2.9.6.tgz",
|
||||
"integrity": "sha512-fmp1hnulbqzl1GkXl4aTX9fV+ubHw2LqlLH1PE3BxZ11EQk+l/TmiEongjnxF0ie4kV8DQfDNJ1KGiIdWe1GvQ==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-riscv64-gnu/-/cli-linux-riscv64-gnu-2.10.1.tgz",
|
||||
"integrity": "sha512-X0lvOVUg8PCVaoEtEAnpxmnkwlE1gcMDTqfhbefICKDnOTJ5Est3qL0SrWxizDackIOKBcvtpejrSiVpuJI1kw==",
|
||||
"cpu": [
|
||||
"riscv64"
|
||||
],
|
||||
@@ -157,9 +157,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-linux-x64-gnu": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-x64-gnu/-/cli-linux-x64-gnu-2.9.6.tgz",
|
||||
"integrity": "sha512-vY0le8ad2KaV1PJr+jCd8fUF9VOjwwQP/uBuTJvhvKTloEwxYA/kAjKK9OpIslGA9m/zcnSo74czI6bBrm2sYA==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-x64-gnu/-/cli-linux-x64-gnu-2.10.1.tgz",
|
||||
"integrity": "sha512-2/12bEzsJS9fAKybxgicCDFxYD1WEI9kO+tlDwX5znWG2GwMBaiWcmhGlZ8fi+DMe9CXlcVarMTYc0L3REIRxw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -174,9 +174,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-linux-x64-musl": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-x64-musl/-/cli-linux-x64-musl-2.9.6.tgz",
|
||||
"integrity": "sha512-TOEuB8YCFZTWVDzsO2yW0+zGcoMiPPwcUgdnW1ODnmgfwccpnihDRoks+ABT1e3fHb1ol8QQWsHSCovb3o2ENQ==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-linux-x64-musl/-/cli-linux-x64-musl-2.10.1.tgz",
|
||||
"integrity": "sha512-Y8J0ZzswPz50UcGOFuXGEMrxbjwKSPgXftx5qnkuMs2rmwQB5ssvLb6tn54wDSYxe7S6vlLob9vt0VKuNOaCIQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -191,9 +191,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-win32-arm64-msvc": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-win32-arm64-msvc/-/cli-win32-arm64-msvc-2.9.6.tgz",
|
||||
"integrity": "sha512-ujmDGMRc4qRLAnj8nNG26Rlz9klJ0I0jmZs2BPpmNNf0gM/rcVHhqbEkAaHPTBVIrtUdf7bGvQAD2pyIiUrBHQ==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-win32-arm64-msvc/-/cli-win32-arm64-msvc-2.10.1.tgz",
|
||||
"integrity": "sha512-iSt5B86jHYAPJa/IlYw++SXtFPGnWtFJriHn7X0NFBVunF6zu9+/zOn8OgqIWSl8RgzhLGXQEEtGBdR4wzpVgg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -208,9 +208,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-win32-ia32-msvc": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-win32-ia32-msvc/-/cli-win32-ia32-msvc-2.9.6.tgz",
|
||||
"integrity": "sha512-S4pT0yAJgFX8QRCyKA1iKjZ9Q/oPjCZf66A/VlG5Yw54Nnr88J1uBpmenINbXxzyhduWrIXBaUbEY1K80ZbpMg==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-win32-ia32-msvc/-/cli-win32-ia32-msvc-2.10.1.tgz",
|
||||
"integrity": "sha512-gXyxgEzsFegmnWywYU5pEBURkcFN/Oo45EAwvZrHMh+zUSEAvO5E8TXsgPADYm31d1u7OQU3O3HsYfVBf2moHw==",
|
||||
"cpu": [
|
||||
"ia32"
|
||||
],
|
||||
@@ -225,9 +225,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tauri-apps/cli-win32-x64-msvc": {
|
||||
"version": "2.9.6",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-win32-x64-msvc/-/cli-win32-x64-msvc-2.9.6.tgz",
|
||||
"integrity": "sha512-ldWuWSSkWbKOPjQMJoYVj9wLHcOniv7diyI5UAJ4XsBdtaFB0pKHQsqw/ItUma0VXGC7vB4E9fZjivmxur60aw==",
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@tauri-apps/cli-win32-x64-msvc/-/cli-win32-x64-msvc-2.10.1.tgz",
|
||||
"integrity": "sha512-6Cn7YpPFwzChy0ERz6djKEmUehWrYlM+xTaNzGPgZocw3BD7OfwfWHKVWxXzdjEW2KfKkHddfdxK1XXTYqBRLg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
||||
@@ -9,10 +9,10 @@
|
||||
"build:dmg": "tauri build --target universal-apple-darwin",
|
||||
"build:linux": "tauri build --bundles deb,rpm"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tauri-apps/cli": "^2.0.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tauri-apps/api": "^2.0.0"
|
||||
"@tauri-apps/api": "^2.10.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tauri-apps/cli": "^2.10.1"
|
||||
}
|
||||
}
|
||||
|
||||
1174
desktop/src-tauri/Cargo.lock
generated
1174
desktop/src-tauri/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -6,18 +6,18 @@ authors = ["you"]
|
||||
edition = "2021"
|
||||
|
||||
[build-dependencies]
|
||||
tauri-build = { version = "2.0", features = [] }
|
||||
tauri-build = { version = "2.5", features = [] }
|
||||
|
||||
[dependencies]
|
||||
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.0"
|
||||
tauri-plugin-window-state = "2.0"
|
||||
tauri = { version = "2.10", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.3.5"
|
||||
tauri-plugin-window-state = "2.4.1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
directories = "5.0"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
window-vibrancy = "0.5"
|
||||
window-vibrancy = "0.7.1"
|
||||
url = "2.5"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -26,20 +26,16 @@ class CustomBuildHook(BuildHookInterface):
|
||||
|
||||
# Get config and environment
|
||||
binary_name = self.config["binary_name"]
|
||||
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{binary_name}/")
|
||||
tag_prefix = self.config.get("tag_prefix", binary_name)
|
||||
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{tag_prefix}/")
|
||||
commit = os.getenv("GITHUB_SHA", "none")
|
||||
|
||||
# Build the Go binary if it doesn't exist
|
||||
if not os.path.exists(binary_name):
|
||||
print(f"Building Go binary '{binary_name}'...")
|
||||
ldflags = f"-X main.version={tag} -X main.commit={commit} -s -w"
|
||||
subprocess.check_call( # noqa: S603
|
||||
[
|
||||
"go",
|
||||
"build",
|
||||
f"-ldflags=-X main.version={tag} -X main.commit={commit} -s -w",
|
||||
"-o",
|
||||
binary_name,
|
||||
],
|
||||
["go", "build", f"-ldflags={ldflags}", "-o", binary_name],
|
||||
)
|
||||
|
||||
build_data["shared_scripts"] = {binary_name: binary_name}
|
||||
|
||||
@@ -3,6 +3,9 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
|
||||
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix("ods/")
|
||||
# Must match tag_prefix in pyproject.toml [tool.hatch.build.targets.wheel.hooks.custom]
|
||||
TAG_PREFIX: str = "ods"
|
||||
|
||||
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix(f"{TAG_PREFIX}/")
|
||||
_match = re.search(r"v?(\d+\.\d+\.\d+)", _tag)
|
||||
__version__ = _match.group(1) if _match else "0.0.0"
|
||||
|
||||
@@ -14,7 +14,9 @@ keywords = [
|
||||
classifiers = [
|
||||
"Programming Language :: Go",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
dependencies = [
|
||||
@@ -27,7 +29,7 @@ dependencies = [
|
||||
Repository = "https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
[tool.hatch.build]
|
||||
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py"]
|
||||
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py", "README.md"]
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "code"
|
||||
@@ -36,6 +38,7 @@ path = "internal/_version.py"
|
||||
[tool.hatch.build.targets.wheel.hooks.custom]
|
||||
path = "hatch_build.py"
|
||||
binary_name = "ods"
|
||||
tag_prefix = "ods"
|
||||
|
||||
[tool.uv]
|
||||
managed = false
|
||||
|
||||
928
web/package-lock.json
generated
928
web/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -132,7 +132,7 @@
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"identity-obj-proxy": "^3.0.0",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0",
|
||||
"jest-environment-jsdom": "^30.2.0",
|
||||
"prettier": "3.1.0",
|
||||
"stats.js": "^0.17.0",
|
||||
"tailwindcss": "^3.4.17",
|
||||
|
||||
@@ -6,7 +6,7 @@ import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { FormikField } from "@/refresh-components/form/FormikField";
|
||||
import { FormField } from "@/refresh-components/form/FormField";
|
||||
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
|
||||
@@ -107,26 +107,25 @@ export default function OnyxApiKeyForm({
|
||||
<FormField name="role" state={state} className="w-full">
|
||||
<FormField.Label>Role:</FormField.Label>
|
||||
<FormField.Control>
|
||||
<InputComboBox
|
||||
<InputSelect
|
||||
value={field.value}
|
||||
onValueChange={(value) => helper.setValue(value)}
|
||||
options={[
|
||||
{
|
||||
label: USER_ROLE_LABELS[UserRole.LIMITED],
|
||||
value: UserRole.LIMITED.toString(),
|
||||
},
|
||||
{
|
||||
label: USER_ROLE_LABELS[UserRole.BASIC],
|
||||
value: UserRole.BASIC.toString(),
|
||||
},
|
||||
{
|
||||
label: USER_ROLE_LABELS[UserRole.ADMIN],
|
||||
value: UserRole.ADMIN.toString(),
|
||||
},
|
||||
]}
|
||||
placeholder="Select a role"
|
||||
strict
|
||||
/>
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a role" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item
|
||||
value={UserRole.LIMITED.toString()}
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.LIMITED]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value={UserRole.BASIC.toString()}>
|
||||
{USER_ROLE_LABELS[UserRole.BASIC]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value={UserRole.ADMIN.toString()}>
|
||||
{USER_ROLE_LABELS[UserRole.ADMIN]}
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</FormField.Control>
|
||||
<FormField.Description>
|
||||
Select the role for this API key. Limited has access to
|
||||
|
||||
@@ -7,15 +7,13 @@ import { InMessageImage } from "@/app/app/components/files/images/InMessageImage
|
||||
import CsvContent from "@/components/tools/CSVContent";
|
||||
import TextViewModal from "@/sections/modals/TextViewModal";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { cn } from "@/lib/utils";
|
||||
import ExpandableContentWrapper from "@/components/tools/ExpandableContentWrapper";
|
||||
|
||||
interface FileDisplayProps {
|
||||
files: FileDescriptor[];
|
||||
alignBubble?: boolean;
|
||||
}
|
||||
|
||||
export default function FileDisplay({ files, alignBubble }: FileDisplayProps) {
|
||||
export default function FileDisplay({ files }: FileDisplayProps) {
|
||||
const [close, setClose] = useState(true);
|
||||
const [previewingFile, setPreviewingFile] = useState<FileDescriptor | null>(
|
||||
null
|
||||
@@ -43,59 +41,47 @@ export default function FileDisplay({ files, alignBubble }: FileDisplayProps) {
|
||||
)}
|
||||
|
||||
{textFiles.length > 0 && (
|
||||
<div
|
||||
id="onyx-file"
|
||||
className={cn("m-2 auto", alignBubble && "ml-auto")}
|
||||
>
|
||||
<div className="flex flex-col items-end gap-2">
|
||||
{textFiles.map((file) => (
|
||||
<Attachment
|
||||
key={file.id}
|
||||
fileName={file.name || file.id}
|
||||
open={() => setPreviewingFile(file)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
<div id="onyx-file" className="flex flex-col items-end gap-2 py-2">
|
||||
{textFiles.map((file) => (
|
||||
<Attachment
|
||||
key={file.id}
|
||||
fileName={file.name || file.id}
|
||||
open={() => setPreviewingFile(file)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{imageFiles.length > 0 && (
|
||||
<div
|
||||
id="onyx-image"
|
||||
className={cn("m-2 auto", alignBubble && "ml-auto")}
|
||||
>
|
||||
<div className="flex flex-col items-end gap-2">
|
||||
{imageFiles.map((file) => (
|
||||
<InMessageImage key={file.id} fileId={file.id} />
|
||||
))}
|
||||
</div>
|
||||
<div id="onyx-image" className="flex flex-col items-end gap-2 py-2">
|
||||
{imageFiles.map((file) => (
|
||||
<InMessageImage key={file.id} fileId={file.id} />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{csvFiles.length > 0 && (
|
||||
<div className={cn("m-2 auto", alignBubble && "ml-auto")}>
|
||||
<div className="flex flex-col items-end gap-2">
|
||||
{csvFiles.map((file) => {
|
||||
return (
|
||||
<div key={file.id} className="w-fit">
|
||||
{close ? (
|
||||
<>
|
||||
<ExpandableContentWrapper
|
||||
fileDescriptor={file}
|
||||
close={() => setClose(false)}
|
||||
ContentComponent={CsvContent}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Attachment
|
||||
open={() => setClose(true)}
|
||||
fileName={file.name || file.id}
|
||||
<div className="flex flex-col items-end gap-2 py-2">
|
||||
{csvFiles.map((file) => {
|
||||
return (
|
||||
<div key={file.id} className="w-fit">
|
||||
{close ? (
|
||||
<>
|
||||
<ExpandableContentWrapper
|
||||
fileDescriptor={file}
|
||||
close={() => setClose(false)}
|
||||
ContentComponent={CsvContent}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<Attachment
|
||||
open={() => setClose(true)}
|
||||
fileName={file.name || file.id}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -195,7 +195,7 @@ const HumanMessage = React.memo(function HumanMessage({
|
||||
id="onyx-human-message"
|
||||
className="group flex flex-col justify-end w-full relative"
|
||||
>
|
||||
<FileDisplay alignBubble files={files || []} />
|
||||
<FileDisplay files={files || []} />
|
||||
{isEditing ? (
|
||||
<MessageEditing
|
||||
content={content}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import React, { JSX, memo } from "react";
|
||||
import {
|
||||
ChatPacket,
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
ImageGenerationToolPacket,
|
||||
Packet,
|
||||
PacketType,
|
||||
ReasoningPacket,
|
||||
SearchToolStart,
|
||||
StopReason,
|
||||
ToolCallArgumentDelta,
|
||||
} from "../../services/streamingModels";
|
||||
import {
|
||||
FullChatState,
|
||||
@@ -26,7 +29,6 @@ import { DeepResearchPlanRenderer } from "./timeline/renderers/deepresearch/Deep
|
||||
import { ResearchAgentRenderer } from "./timeline/renderers/deepresearch/ResearchAgentRenderer";
|
||||
import { WebSearchToolRenderer } from "./timeline/renderers/search/WebSearchToolRenderer";
|
||||
import { InternalSearchToolRenderer } from "./timeline/renderers/search/InternalSearchToolRenderer";
|
||||
import { SearchToolStart } from "../../services/streamingModels";
|
||||
|
||||
// Different types of chat packets using discriminated unions
|
||||
interface GroupedPackets {
|
||||
@@ -56,7 +58,12 @@ function isImageToolPacket(packet: Packet) {
|
||||
}
|
||||
|
||||
function isPythonToolPacket(packet: Packet) {
|
||||
return packet.obj.type === PacketType.PYTHON_TOOL_START;
|
||||
return (
|
||||
packet.obj.type === PacketType.PYTHON_TOOL_START ||
|
||||
(packet.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(packet.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
|
||||
);
|
||||
}
|
||||
|
||||
function isCustomToolPacket(packet: Packet) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import {
|
||||
Stop,
|
||||
ImageGenerationToolDelta,
|
||||
MessageStart,
|
||||
ToolCallArgumentDelta,
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
} from "@/app/app/services/streamingModels";
|
||||
import { CitationMap } from "@/app/app/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
@@ -138,6 +140,7 @@ const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.IMAGE_GENERATION_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.TOOL_CALL_ARGUMENT_DELTA,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.FILE_READER_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
@@ -149,9 +152,16 @@ const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
]);
|
||||
|
||||
function hasContentPackets(packets: Packet[]): boolean {
|
||||
return packets.some((packet) =>
|
||||
CONTENT_PACKET_TYPES_SET.has(packet.obj.type as PacketType)
|
||||
);
|
||||
return packets.some((packet) => {
|
||||
const type = packet.obj.type as PacketType;
|
||||
if (type === PacketType.TOOL_CALL_ARGUMENT_DELTA) {
|
||||
return (
|
||||
(packet.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON
|
||||
);
|
||||
}
|
||||
return CONTENT_PACKET_TYPES_SET.has(type);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
import { Packet, PacketType } from "@/app/app/services/streamingModels";
|
||||
import {
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
Packet,
|
||||
PacketType,
|
||||
ToolCallArgumentDelta,
|
||||
} from "@/app/app/services/streamingModels";
|
||||
|
||||
// Packet types with renderers supporting collapsed streaming mode
|
||||
// Packet types with renderers supporting collapsed streaming mode.
|
||||
// TOOL_CALL_ARGUMENT_DELTA is intentionally excluded here because it requires
|
||||
// a tool_type check — it's handled separately in stepSupportsCollapsedStreaming.
|
||||
export const COLLAPSED_STREAMING_PACKET_TYPES = new Set<PacketType>([
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
@@ -21,7 +28,13 @@ export const isSearchToolPackets = (packets: Packet[]): boolean =>
|
||||
|
||||
// Check if packets belong to a python tool
|
||||
export const isPythonToolPackets = (packets: Packet[]): boolean =>
|
||||
packets.some((p) => p.obj.type === PacketType.PYTHON_TOOL_START);
|
||||
packets.some(
|
||||
(p) =>
|
||||
p.obj.type === PacketType.PYTHON_TOOL_START ||
|
||||
(p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(p.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
|
||||
);
|
||||
|
||||
// Check if packets belong to reasoning
|
||||
export const isReasoningPackets = (packets: Packet[]): boolean =>
|
||||
@@ -29,8 +42,12 @@ export const isReasoningPackets = (packets: Packet[]): boolean =>
|
||||
|
||||
// Check if step supports collapsed streaming rendering mode
|
||||
export const stepSupportsCollapsedStreaming = (packets: Packet[]): boolean =>
|
||||
packets.some((p) =>
|
||||
COLLAPSED_STREAMING_PACKET_TYPES.has(p.obj.type as PacketType)
|
||||
packets.some(
|
||||
(p) =>
|
||||
COLLAPSED_STREAMING_PACKET_TYPES.has(p.obj.type as PacketType) ||
|
||||
(p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(p.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
|
||||
);
|
||||
|
||||
// Check if packets have content worth rendering in collapsed streaming mode.
|
||||
@@ -67,7 +84,13 @@ export const stepHasCollapsedStreamingContent = (
|
||||
// Python tool renders code/output from the start packet onward
|
||||
if (
|
||||
packetTypes.has(PacketType.PYTHON_TOOL_START) ||
|
||||
packetTypes.has(PacketType.PYTHON_TOOL_DELTA)
|
||||
packetTypes.has(PacketType.PYTHON_TOOL_DELTA) ||
|
||||
packets.some(
|
||||
(p) =>
|
||||
p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(p.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON
|
||||
)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import {
|
||||
PythonToolPacket,
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
ToolCallArgumentDelta,
|
||||
SectionEnd,
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
} from "@/app/app/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
@@ -39,6 +41,18 @@ function HighlightedPythonCode({ code }: { code: string }) {
|
||||
|
||||
// Helper function to construct current Python execution state
|
||||
function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
// Accumulate streaming code from argument deltas (arrives before PythonToolStart)
|
||||
const streamingCode = packets
|
||||
.filter(
|
||||
(packet) =>
|
||||
packet.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(packet.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON
|
||||
)
|
||||
.map((packet) =>
|
||||
String((packet.obj as ToolCallArgumentDelta).argument_deltas.code ?? "")
|
||||
)
|
||||
.join("");
|
||||
const pythonStart = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.PYTHON_TOOL_START
|
||||
)?.obj as PythonToolStart | null;
|
||||
@@ -51,7 +65,8 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
packet.obj.type === PacketType.ERROR
|
||||
)?.obj as SectionEnd | null;
|
||||
|
||||
const code = pythonStart?.code || "";
|
||||
// Use complete code from PythonToolStart if available, else use streamed code.
|
||||
const code = pythonStart?.code || streamingCode;
|
||||
const stdout = pythonDeltas
|
||||
.map((delta) => delta?.stdout || "")
|
||||
.filter((s) => s)
|
||||
@@ -61,6 +76,7 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
.filter((s) => s)
|
||||
.join("");
|
||||
const fileIds = pythonDeltas.flatMap((delta) => delta?.file_ids || []);
|
||||
const isStreaming = !pythonStart && streamingCode.length > 0;
|
||||
const isExecuting = pythonStart && !pythonEnd;
|
||||
const isComplete = pythonStart && pythonEnd;
|
||||
const hasError = stderr.length > 0;
|
||||
@@ -70,6 +86,7 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
stdout,
|
||||
stderr,
|
||||
fileIds,
|
||||
isStreaming,
|
||||
isExecuting,
|
||||
isComplete,
|
||||
hasError,
|
||||
@@ -82,8 +99,16 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
renderType,
|
||||
children,
|
||||
}) => {
|
||||
const { code, stdout, stderr, fileIds, isExecuting, isComplete, hasError } =
|
||||
constructCurrentPythonState(packets);
|
||||
const {
|
||||
code,
|
||||
stdout,
|
||||
stderr,
|
||||
fileIds,
|
||||
isStreaming,
|
||||
isExecuting,
|
||||
isComplete,
|
||||
hasError,
|
||||
} = constructCurrentPythonState(packets);
|
||||
|
||||
useEffect(() => {
|
||||
if (isComplete) {
|
||||
@@ -92,6 +117,9 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
}, [isComplete, onComplete]);
|
||||
|
||||
const status = useMemo(() => {
|
||||
if (isStreaming) {
|
||||
return "Writing code...";
|
||||
}
|
||||
if (isExecuting) {
|
||||
return "Executing Python code...";
|
||||
}
|
||||
@@ -102,13 +130,13 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
return "Python execution completed";
|
||||
}
|
||||
return "Python execution";
|
||||
}, [isComplete, isExecuting, hasError]);
|
||||
}, [isStreaming, isComplete, isExecuting, hasError]);
|
||||
|
||||
// Shared content for all states - used by both FULL and compact modes
|
||||
const content = (
|
||||
<div className="flex flex-col mb-1 space-y-2">
|
||||
{/* Loading indicator when executing */}
|
||||
{isExecuting && (
|
||||
{/* Loading indicator when streaming or executing */}
|
||||
{(isStreaming || isExecuting) && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<div className="flex gap-0.5">
|
||||
<div className="w-1 h-1 bg-current rounded-full animate-pulse"></div>
|
||||
@@ -121,7 +149,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
style={{ animationDelay: "0.2s" }}
|
||||
></div>
|
||||
</div>
|
||||
<span>Running code...</span>
|
||||
<span>{isStreaming ? "Writing code..." : "Running code..."}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ export function isToolPacket(
|
||||
PacketType.SEARCH_TOOL_DOCUMENTS_DELTA,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_DELTA,
|
||||
PacketType.TOOL_CALL_ARGUMENT_DELTA,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.CUSTOM_TOOL_DELTA,
|
||||
PacketType.FILE_READER_START,
|
||||
|
||||
@@ -27,6 +27,9 @@ export enum PacketType {
|
||||
FETCH_TOOL_URLS = "open_url_urls",
|
||||
FETCH_TOOL_DOCUMENTS = "open_url_documents",
|
||||
|
||||
// Tool call argument delta (streams tool args before tool executes)
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta",
|
||||
|
||||
// Custom tool packets
|
||||
CUSTOM_TOOL_START = "custom_tool_start",
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta",
|
||||
@@ -59,6 +62,10 @@ export enum PacketType {
|
||||
INTERMEDIATE_REPORT_CITED_DOCS = "intermediate_report_cited_docs",
|
||||
}
|
||||
|
||||
export const CODE_INTERPRETER_TOOL_TYPES = {
|
||||
PYTHON: "python",
|
||||
} as const;
|
||||
|
||||
// Basic Message Packets
|
||||
export interface MessageStart extends BaseObj {
|
||||
id: string;
|
||||
@@ -149,6 +156,13 @@ export interface PythonToolDelta extends BaseObj {
|
||||
file_ids: string[];
|
||||
}
|
||||
|
||||
export interface ToolCallArgumentDelta extends BaseObj {
|
||||
type: "tool_call_argument_delta";
|
||||
tool_type: string;
|
||||
tool_id: string;
|
||||
argument_deltas: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface FetchToolStart extends BaseObj {
|
||||
type: "open_url_start";
|
||||
}
|
||||
@@ -294,6 +308,7 @@ export type ImageGenerationToolObj =
|
||||
export type PythonToolObj =
|
||||
| PythonToolStart
|
||||
| PythonToolDelta
|
||||
| ToolCallArgumentDelta
|
||||
| SectionEnd
|
||||
| PacketError;
|
||||
export type FetchToolObj =
|
||||
|
||||
@@ -22,9 +22,6 @@ describe("Email/Password Login Workflow", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
// Mock window.location.href for redirect testing
|
||||
delete (window as any).location;
|
||||
window.location = { href: "" } as any;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -53,9 +50,9 @@ describe("Email/Password Login Workflow", () => {
|
||||
const loginButton = screen.getByRole("button", { name: /sign in/i });
|
||||
await user.click(loginButton);
|
||||
|
||||
// After successful login, user should be redirected to /chat
|
||||
// Verify success message is shown after login
|
||||
await waitFor(() => {
|
||||
expect(window.location.href).toBe("/app");
|
||||
expect(screen.getByText(/signed in successfully\./i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify API was called with correct credentials
|
||||
@@ -114,9 +111,6 @@ describe("Email/Password Signup Workflow", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
// Mock window.location.href
|
||||
delete (window as any).location;
|
||||
window.location = { href: "" } as any;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -39,7 +39,7 @@ export function ToggleWarningModal({
|
||||
{/* Message */}
|
||||
<div className="flex justify-center">
|
||||
<Text mainUiBody text04 className="text-center">
|
||||
We recommend using <strong>Claude Opus 4.5</strong> for Crafting.
|
||||
We recommend using <strong>Claude Opus 4.6</strong> for Crafting.
|
||||
<br />
|
||||
Other models may have reduced capabilities for code creation,
|
||||
<br />
|
||||
|
||||
@@ -34,8 +34,8 @@ export const PROVIDERS: ProviderConfig[] = [
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-5", label: "Claude Opus 4.5", recommended: true },
|
||||
{ name: "claude-sonnet-4-5", label: "Claude Sonnet 4.5" },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
apiKeyUrl: "https://console.anthropic.com/dashboard",
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
export interface BuildLlmSelection {
|
||||
providerName: string; // e.g., "build-mode-anthropic" (LLMProviderDescriptor.name)
|
||||
provider: string; // e.g., "anthropic"
|
||||
modelName: string; // e.g., "claude-opus-4-5"
|
||||
modelName: string; // e.g., "claude-opus-4-6"
|
||||
}
|
||||
|
||||
// Priority order for smart default LLM selection
|
||||
const LLM_SELECTION_PRIORITY = [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-5" },
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openrouter", modelName: "minimax/minimax-m2.1" },
|
||||
] as const;
|
||||
@@ -63,11 +63,11 @@ export function getDefaultLlmSelection(
|
||||
export const RECOMMENDED_BUILD_MODELS = {
|
||||
preferred: {
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-5",
|
||||
displayName: "Claude Opus 4.5",
|
||||
modelName: "claude-opus-4-6",
|
||||
displayName: "Claude Opus 4.6",
|
||||
},
|
||||
alternatives: [
|
||||
{ provider: "anthropic", modelName: "claude-sonnet-4-5" },
|
||||
{ provider: "anthropic", modelName: "claude-sonnet-4-6" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openai", modelName: "gpt-5.1-codex" },
|
||||
{ provider: "openrouter", modelName: "minimax/minimax-m2.1" },
|
||||
@@ -148,8 +148,8 @@ export const BUILD_MODE_PROVIDERS: BuildModeProvider[] = [
|
||||
providerName: "anthropic",
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-5", label: "Claude Opus 4.5", recommended: true },
|
||||
{ name: "claude-sonnet-4-5", label: "Claude Sonnet 4.5" },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
apiKeyUrl: "https://console.anthropic.com/dashboard",
|
||||
|
||||
21
web/src/lib/llmConfig/cache.ts
Normal file
21
web/src/lib/llmConfig/cache.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { ScopedMutator } from "swr";
|
||||
import {
|
||||
LLM_CHAT_PROVIDERS_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
|
||||
const PERSONA_PROVIDER_ENDPOINT_PATTERN =
|
||||
/^\/api\/llm\/persona\/\d+\/providers$/;
|
||||
|
||||
export async function refreshLlmProviderCaches(
|
||||
mutate: ScopedMutator
|
||||
): Promise<void> {
|
||||
await Promise.all([
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL),
|
||||
mutate(LLM_CHAT_PROVIDERS_URL),
|
||||
mutate(
|
||||
(key) =>
|
||||
typeof key === "string" && PERSONA_PROVIDER_ENDPOINT_PATTERN.test(key)
|
||||
),
|
||||
]);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
export const LLM_ADMIN_URL = "/api/admin/llm";
|
||||
export const LLM_PROVIDERS_ADMIN_URL = `${LLM_ADMIN_URL}/provider`;
|
||||
export const LLM_CHAT_PROVIDERS_URL = "/api/llm/provider";
|
||||
|
||||
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
|
||||
"/api/admin/llm/provider-contextual-cost";
|
||||
|
||||
@@ -193,7 +193,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit({
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -218,6 +218,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
chatSessionId: currentChatSessionId,
|
||||
agentId: selectedAgent?.id,
|
||||
});
|
||||
const deepResearchEnabledForCurrentWorkflow =
|
||||
currentProjectId === null && deepResearchEnabled;
|
||||
|
||||
const [presentingDocument, setPresentingDocument] =
|
||||
useState<MinimalOnyxDocument | null>(null);
|
||||
@@ -435,10 +437,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit({
|
||||
message: lastUserMsg.message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
});
|
||||
}, [messageHistory, onSubmit, currentMessageFiles, deepResearchEnabled]);
|
||||
}, [
|
||||
messageHistory,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
]);
|
||||
|
||||
const toggleDocumentSidebar = useCallback(() => {
|
||||
if (!documentSidebarVisible) {
|
||||
@@ -458,7 +465,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit({
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
});
|
||||
if (showOnboarding || !onboardingDismissed) {
|
||||
finishOnboarding();
|
||||
@@ -468,7 +475,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
resetInputBar,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabled,
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
showOnboarding,
|
||||
onboardingDismissed,
|
||||
finishOnboarding,
|
||||
@@ -503,7 +510,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit({
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
});
|
||||
if (showOnboarding || !onboardingDismissed) {
|
||||
finishOnboarding();
|
||||
@@ -524,7 +531,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
resetInputBar,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabled,
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
showOnboarding,
|
||||
onboardingDismissed,
|
||||
finishOnboarding,
|
||||
@@ -709,7 +716,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
>
|
||||
{/* Main content grid — 3 rows, animated */}
|
||||
<div
|
||||
className="flex-1 w-full grid min-h-0 transition-[grid-template-rows] duration-150 ease-in-out"
|
||||
className="flex-1 w-full grid min-h-0 px-4 transition-[grid-template-rows] duration-150 ease-in-out"
|
||||
style={gridStyle}
|
||||
>
|
||||
{/* ── Top row: ChatUI / WelcomeMessage / ProjectUI ── */}
|
||||
@@ -732,7 +739,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
<ChatUI
|
||||
liveAgent={liveAgent!}
|
||||
llmManager={llmManager}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
deepResearchEnabled={
|
||||
deepResearchEnabledForCurrentWorkflow
|
||||
}
|
||||
currentMessageFiles={currentMessageFiles}
|
||||
setPresentingDocument={setPresentingDocument}
|
||||
onSubmit={onSubmit}
|
||||
@@ -828,7 +837,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
/>
|
||||
<AppInputBar
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
deepResearchEnabled={
|
||||
deepResearchEnabledForCurrentWorkflow
|
||||
}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
getProviderIcon,
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { deleteLlmProvider, setDefaultLlmModel } from "@/lib/llmConfig/svc";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Horizontal as HorizontalInput } from "@/layouts/input-layouts";
|
||||
@@ -33,7 +34,6 @@ import {
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { getModalForExistingProvider } from "@/sections/modals/llmConfig/getModal";
|
||||
import { OpenAIModal } from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import { AnthropicModal } from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
@@ -140,7 +140,7 @@ function ExistingProviderCard({
|
||||
const handleDelete = async () => {
|
||||
try {
|
||||
await deleteLlmProvider(provider.id);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
deleteModal.toggle(false);
|
||||
toast.success("Provider deleted successfully!");
|
||||
} catch (e) {
|
||||
@@ -345,7 +345,7 @@ export default function LLMConfigurationPage() {
|
||||
|
||||
try {
|
||||
await setDefaultLlmModel(providerId, modelName);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success("Default model updated successfully!");
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
|
||||
@@ -148,7 +148,7 @@ const AppInputBar = React.memo(
|
||||
classification === "search";
|
||||
|
||||
const { forcedToolIds, setForcedToolIds } = useForcedTools();
|
||||
const { currentMessageFiles, setCurrentMessageFiles } =
|
||||
const { currentMessageFiles, setCurrentMessageFiles, currentProjectId } =
|
||||
useProjectsContext();
|
||||
|
||||
const currentIndexingFiles = useMemo(() => {
|
||||
@@ -200,9 +200,17 @@ const AppInputBar = React.memo(
|
||||
const textarea = textAreaRef.current;
|
||||
if (!wrapper || !textarea) return;
|
||||
|
||||
// Reset so scrollHeight reflects actual content size
|
||||
wrapper.style.height = `${MIN_INPUT_HEIGHT}px`;
|
||||
|
||||
// scrollHeight doesn't include the wrapper's padding, so add it back
|
||||
const wrapperStyle = getComputedStyle(wrapper);
|
||||
const paddingTop = parseFloat(wrapperStyle.paddingTop);
|
||||
const paddingBottom = parseFloat(wrapperStyle.paddingBottom);
|
||||
const contentHeight = textarea.scrollHeight + paddingTop + paddingBottom;
|
||||
|
||||
wrapper.style.height = `${Math.min(
|
||||
Math.max(textarea.scrollHeight, MIN_INPUT_HEIGHT),
|
||||
Math.max(contentHeight, MIN_INPUT_HEIGHT),
|
||||
MAX_INPUT_HEIGHT
|
||||
)}px`;
|
||||
}, [message, isSearchMode]);
|
||||
@@ -358,13 +366,19 @@ const AppInputBar = React.memo(
|
||||
const showDeepResearch = useMemo(() => {
|
||||
const deepResearchGloballyEnabled =
|
||||
combinedSettings?.settings?.deep_research_enabled ?? true;
|
||||
const isProjectWorkflow = currentProjectId !== null;
|
||||
|
||||
// TODO(@yuhong): Re-enable Deep Research in Projects workflow once it is fully supported.
|
||||
// https://linear.app/onyx-app/issue/ENG-3818/re-enable-deep-research-in-projects
|
||||
return (
|
||||
!isProjectWorkflow &&
|
||||
deepResearchGloballyEnabled &&
|
||||
hasSearchToolsAvailable(selectedAgent?.tools || [])
|
||||
);
|
||||
}, [
|
||||
selectedAgent?.tools,
|
||||
combinedSettings?.settings?.deep_research_enabled,
|
||||
currentProjectId,
|
||||
]);
|
||||
|
||||
function handleKeyDownForPromptShortcuts(
|
||||
|
||||
@@ -32,6 +32,7 @@ import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ScopedMutator } from "swr";
|
||||
|
||||
export const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
|
||||
@@ -83,7 +84,7 @@ interface BedrockModalInternalsProps {
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: (key: string) => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
|
||||
@@ -164,6 +164,18 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
|
||||
// Verify SWR cache was invalidated
|
||||
expect(mockMutate).toHaveBeenCalledWith("/api/admin/llm/provider");
|
||||
expect(mockMutate).toHaveBeenCalledWith("/api/llm/provider");
|
||||
|
||||
const personaProvidersMutateCall = mockMutate.mock.calls.find(
|
||||
([key]) => typeof key === "function"
|
||||
);
|
||||
expect(personaProvidersMutateCall).toBeDefined();
|
||||
|
||||
const personaProviderFilter = personaProvidersMutateCall?.[0] as (
|
||||
key: unknown
|
||||
) => boolean;
|
||||
expect(personaProviderFilter("/api/llm/persona/42/providers")).toBe(true);
|
||||
expect(personaProviderFilter("/api/llm/provider")).toBe(false);
|
||||
});
|
||||
|
||||
test("shows error when test configuration fails", async () => {
|
||||
|
||||
@@ -27,6 +27,7 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { ScopedMutator } from "swr";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
@@ -46,7 +47,7 @@ interface LMStudioFormContentProps {
|
||||
setHasFetched: (value: boolean) => void;
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: () => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
isFormValid: boolean;
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { ScopedMutator } from "swr";
|
||||
|
||||
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
@@ -44,7 +45,7 @@ interface OllamaModalContentProps {
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: () => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
isFormValid: boolean;
|
||||
}
|
||||
|
||||
@@ -4,14 +4,15 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { SvgTrash } from "@opal/icons";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { deleteLlmProvider } from "@/lib/llmConfig/svc";
|
||||
import { ScopedMutator } from "swr";
|
||||
|
||||
interface FormActionButtonsProps {
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
mutate: (key: string) => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
isFormValid: boolean;
|
||||
}
|
||||
@@ -29,7 +30,7 @@ export function FormActionButtons({
|
||||
|
||||
try {
|
||||
await deleteLlmProvider(existingLlmProvider.id);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
onClose();
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState, ReactNode } from "react";
|
||||
import useSWR, { useSWRConfig, KeyedMutator } from "swr";
|
||||
import useSWR, { useSWRConfig, ScopedMutator } from "swr";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
LLMProviderView,
|
||||
@@ -14,12 +14,12 @@ import { Button } from "@opal/components";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { setDefaultLlmModel } from "@/lib/llmConfig/svc";
|
||||
|
||||
export interface ProviderFormContext {
|
||||
onClose: () => void;
|
||||
mutate: KeyedMutator<any>;
|
||||
mutate: ScopedMutator;
|
||||
isTesting: boolean;
|
||||
setIsTesting: (testing: boolean) => void;
|
||||
testError: string;
|
||||
@@ -95,7 +95,7 @@ export function ProviderFormEntrypointWrapper({
|
||||
|
||||
try {
|
||||
await setDefaultLlmModel(existingLlmProvider.id, firstVisibleModel.name);
|
||||
await mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success("Provider set as default successfully!");
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
|
||||
@@ -7,9 +7,11 @@ import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { ScopedMutator } from "swr";
|
||||
|
||||
// Common class names for the Form component across all LLM provider forms
|
||||
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
|
||||
@@ -105,7 +107,7 @@ export interface SubmitLLMProviderParams<
|
||||
hideSuccess?: boolean;
|
||||
setIsTesting: (testing: boolean) => void;
|
||||
setTestError: (error: string) => void;
|
||||
mutate: (key: string) => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
@@ -287,7 +289,7 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
}
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
onClose();
|
||||
|
||||
if (!hideSuccess) {
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
createMockOnboardingActions,
|
||||
createMockFetchResponses,
|
||||
MOCK_PROVIDERS,
|
||||
ANTHROPIC_DEFAULT_VISIBLE_MODELS,
|
||||
} from "./testHelpers";
|
||||
|
||||
// Mock fetch
|
||||
@@ -51,8 +50,6 @@ jest.mock("@/components/modals/ProviderModal", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock fetchModels utility - returns the curated Anthropic visible models
|
||||
// that match ANTHROPIC_VISIBLE_MODEL_NAMES from backend
|
||||
const mockFetchModels = jest.fn().mockResolvedValue({
|
||||
models: [
|
||||
{
|
||||
@@ -152,71 +149,6 @@ describe("AnthropicOnboardingForm", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Default Available Models", () => {
|
||||
/**
|
||||
* This test verifies that the exact curated list of Anthropic visible models
|
||||
* matches what's returned from /api/admin/llm/built-in/options.
|
||||
* The expected models are defined in ANTHROPIC_VISIBLE_MODEL_NAMES in
|
||||
* backend/onyx/llm/llm_provider_options.py
|
||||
*/
|
||||
test("llmDescriptor contains the correct default visible models from built-in options", () => {
|
||||
const expectedModelNames = [
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
];
|
||||
|
||||
// Verify MOCK_PROVIDERS.anthropic has the correct model configurations
|
||||
const actualModelNames = MOCK_PROVIDERS.anthropic.known_models.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
// Check that all expected models are present
|
||||
expect(actualModelNames).toEqual(
|
||||
expect.arrayContaining(expectedModelNames)
|
||||
);
|
||||
|
||||
// Check that only the expected models are present (no extras)
|
||||
expect(actualModelNames).toHaveLength(expectedModelNames.length);
|
||||
|
||||
// Verify each model has is_visible set to true
|
||||
MOCK_PROVIDERS.anthropic.known_models.forEach((config) => {
|
||||
expect(config.is_visible).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
test("ANTHROPIC_DEFAULT_VISIBLE_MODELS matches backend ANTHROPIC_VISIBLE_MODEL_NAMES", () => {
|
||||
// These are the exact model names from backend/onyx/llm/llm_provider_options.py
|
||||
// ANTHROPIC_VISIBLE_MODEL_NAMES = {"claude-opus-4-5", "claude-sonnet-4-5", "claude-haiku-4-5"}
|
||||
const backendVisibleModelNames = new Set([
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
]);
|
||||
|
||||
const testHelperModelNames = new Set(
|
||||
ANTHROPIC_DEFAULT_VISIBLE_MODELS.map((m) => m.name)
|
||||
);
|
||||
|
||||
expect(testHelperModelNames).toEqual(backendVisibleModelNames);
|
||||
});
|
||||
|
||||
test("all default models are marked as visible", () => {
|
||||
ANTHROPIC_DEFAULT_VISIBLE_MODELS.forEach((model) => {
|
||||
expect(model.is_visible).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
test("default model claude-sonnet-4-5 is set correctly in component", () => {
|
||||
// The AnthropicOnboardingForm sets DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5"
|
||||
// Verify this model exists in the default visible models
|
||||
const defaultModelExists = ANTHROPIC_DEFAULT_VISIBLE_MODELS.some(
|
||||
(m) => m.name === "claude-sonnet-4-5"
|
||||
);
|
||||
expect(defaultModelExists).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Validation", () => {
|
||||
test("submit button is disabled when form is empty", () => {
|
||||
render(<AnthropicOnboardingForm {...defaultProps} />);
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
createMockOnboardingActions,
|
||||
createMockFetchResponses,
|
||||
MOCK_PROVIDERS,
|
||||
OPENAI_DEFAULT_VISIBLE_MODELS,
|
||||
} from "./testHelpers";
|
||||
|
||||
// Mock fetch
|
||||
@@ -54,8 +53,6 @@ jest.mock("@/components/modals/ProviderModal", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock fetchModels utility - returns the curated OpenAI visible models
|
||||
// that match OPENAI_VISIBLE_MODEL_NAMES from backend
|
||||
const mockFetchModels = jest.fn().mockResolvedValue({
|
||||
models: [
|
||||
{
|
||||
@@ -173,77 +170,6 @@ describe("OpenAIOnboardingForm", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Default Available Models", () => {
|
||||
/**
|
||||
* This test verifies that the exact curated list of OpenAI visible models
|
||||
* matches what's returned from /api/admin/llm/built-in/options.
|
||||
* The expected models are defined in OPENAI_VISIBLE_MODEL_NAMES in
|
||||
* backend/onyx/llm/llm_provider_options.py
|
||||
*/
|
||||
test("llmDescriptor contains the correct default visible models from built-in options", () => {
|
||||
const expectedModelNames = [
|
||||
"gpt-5.2",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
];
|
||||
|
||||
// Verify MOCK_PROVIDERS.openai has the correct model configurations
|
||||
const actualModelNames = MOCK_PROVIDERS.openai.known_models.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
// Check that all expected models are present
|
||||
expect(actualModelNames).toEqual(
|
||||
expect.arrayContaining(expectedModelNames)
|
||||
);
|
||||
|
||||
// Check that only the expected models are present (no extras)
|
||||
expect(actualModelNames).toHaveLength(expectedModelNames.length);
|
||||
|
||||
// Verify each model has is_visible set to true
|
||||
MOCK_PROVIDERS.openai.known_models.forEach((config) => {
|
||||
expect(config.is_visible).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
test("OPENAI_DEFAULT_VISIBLE_MODELS matches backend OPENAI_VISIBLE_MODEL_NAMES", () => {
|
||||
// These are the exact model names from backend/onyx/llm/llm_provider_options.py
|
||||
// OPENAI_VISIBLE_MODEL_NAMES = {"gpt-5.2", "gpt-5-mini", "o1", "o3-mini", "gpt-4o", "gpt-4o-mini"}
|
||||
const backendVisibleModelNames = new Set([
|
||||
"gpt-5.2",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
]);
|
||||
|
||||
const testHelperModelNames = new Set(
|
||||
OPENAI_DEFAULT_VISIBLE_MODELS.map((m) => m.name)
|
||||
);
|
||||
|
||||
expect(testHelperModelNames).toEqual(backendVisibleModelNames);
|
||||
});
|
||||
|
||||
test("all default models are marked as visible", () => {
|
||||
OPENAI_DEFAULT_VISIBLE_MODELS.forEach((model) => {
|
||||
expect(model.is_visible).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
test("default model gpt-5.2 is set correctly in component", () => {
|
||||
// The OpenAIOnboardingForm sets DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2"
|
||||
// Verify this model exists in the default visible models
|
||||
const defaultModelExists = OPENAI_DEFAULT_VISIBLE_MODELS.some(
|
||||
(m) => m.name === "gpt-5.2"
|
||||
);
|
||||
expect(defaultModelExists).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Validation", () => {
|
||||
test("submit button is disabled when form is empty", () => {
|
||||
render(<OpenAIOnboardingForm {...defaultProps} />);
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
createMockOnboardingActions,
|
||||
createMockFetchResponses,
|
||||
MOCK_PROVIDERS,
|
||||
VERTEXAI_DEFAULT_VISIBLE_MODELS,
|
||||
} from "./testHelpers";
|
||||
|
||||
// Mock fetch
|
||||
@@ -51,8 +50,6 @@ jest.mock("@/components/modals/ProviderModal", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock fetchModels utility - returns the curated Vertex AI visible models
|
||||
// that match VERTEXAI_VISIBLE_MODEL_NAMES from backend
|
||||
jest.mock("@/app/admin/configuration/llm/utils", () => ({
|
||||
canProviderFetchModels: jest.fn().mockReturnValue(true),
|
||||
fetchModels: jest.fn().mockResolvedValue({
|
||||
@@ -154,71 +151,6 @@ describe("VertexAIOnboardingForm", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Default Available Models", () => {
|
||||
/**
|
||||
* This test verifies that the exact curated list of Vertex AI visible models
|
||||
* matches what's returned from /api/admin/llm/built-in/options.
|
||||
* The expected models are defined in VERTEXAI_VISIBLE_MODEL_NAMES in
|
||||
* backend/onyx/llm/llm_provider_options.py
|
||||
*/
|
||||
test("llmDescriptor contains the correct default visible models from built-in options", () => {
|
||||
const expectedModelNames = [
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
];
|
||||
|
||||
// Verify MOCK_PROVIDERS.vertexAi has the correct model configurations
|
||||
const actualModelNames = MOCK_PROVIDERS.vertexAi.known_models.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
// Check that all expected models are present
|
||||
expect(actualModelNames).toEqual(
|
||||
expect.arrayContaining(expectedModelNames)
|
||||
);
|
||||
|
||||
// Check that only the expected models are present (no extras)
|
||||
expect(actualModelNames).toHaveLength(expectedModelNames.length);
|
||||
|
||||
// Verify each model has is_visible set to true
|
||||
MOCK_PROVIDERS.vertexAi.known_models.forEach((config) => {
|
||||
expect(config.is_visible).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
test("VERTEXAI_DEFAULT_VISIBLE_MODELS matches backend VERTEXAI_VISIBLE_MODEL_NAMES", () => {
|
||||
// These are the exact model names from backend/onyx/llm/llm_provider_options.py
|
||||
// VERTEXAI_VISIBLE_MODEL_NAMES = {"gemini-2.5-flash", "gemini-2.5-flash-lite", "gemini-2.5-pro"}
|
||||
const backendVisibleModelNames = new Set([
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
]);
|
||||
|
||||
const testHelperModelNames = new Set(
|
||||
VERTEXAI_DEFAULT_VISIBLE_MODELS.map((m) => m.name)
|
||||
);
|
||||
|
||||
expect(testHelperModelNames).toEqual(backendVisibleModelNames);
|
||||
});
|
||||
|
||||
test("all default models are marked as visible", () => {
|
||||
VERTEXAI_DEFAULT_VISIBLE_MODELS.forEach((model) => {
|
||||
expect(model.is_visible).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
test("default model gemini-2.5-pro is set correctly in component", () => {
|
||||
// The VertexAIOnboardingForm sets DEFAULT_DEFAULT_MODEL_NAME = "gemini-2.5-pro"
|
||||
// Verify this model exists in the default visible models
|
||||
const defaultModelExists = VERTEXAI_DEFAULT_VISIBLE_MODELS.some(
|
||||
(m) => m.name === "gemini-2.5-pro"
|
||||
);
|
||||
expect(defaultModelExists).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Validation", () => {
|
||||
test("submit button is disabled when form is empty", () => {
|
||||
render(<VertexAIOnboardingForm {...defaultProps} />);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
/**
|
||||
* Shared test helpers and mocks for onboarding form tests
|
||||
*/
|
||||
import React from "react";
|
||||
|
||||
// Mock Element.prototype.scrollIntoView for JSDOM (not implemented in jsdom)
|
||||
Element.prototype.scrollIntoView = jest.fn();
|
||||
@@ -161,11 +160,6 @@ export async function waitForModalOpen(screen: any, waitFor: any) {
|
||||
/**
|
||||
* Common provider descriptors for testing
|
||||
*/
|
||||
/**
|
||||
* The curated list of OpenAI visible models that are returned by
|
||||
* /api/admin/llm/built-in/options. This must match OPENAI_VISIBLE_MODEL_NAMES
|
||||
* in backend/onyx/llm/llm_provider_options.py
|
||||
*/
|
||||
export const OPENAI_DEFAULT_VISIBLE_MODELS = [
|
||||
{
|
||||
name: "gpt-5.2",
|
||||
@@ -211,11 +205,6 @@ export const OPENAI_DEFAULT_VISIBLE_MODELS = [
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
* The curated list of Anthropic visible models that are returned by
|
||||
* /api/admin/llm/built-in/options. This must match ANTHROPIC_VISIBLE_MODEL_NAMES
|
||||
* in backend/onyx/llm/llm_provider_options.py
|
||||
*/
|
||||
export const ANTHROPIC_DEFAULT_VISIBLE_MODELS = [
|
||||
{
|
||||
name: "claude-opus-4-5",
|
||||
@@ -240,11 +229,6 @@ export const ANTHROPIC_DEFAULT_VISIBLE_MODELS = [
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
* The curated list of Vertex AI visible models that are returned by
|
||||
* /api/admin/llm/built-in/options. This must match VERTEXAI_VISIBLE_MODEL_NAMES
|
||||
* in backend/onyx/llm/llm_provider_options.py
|
||||
*/
|
||||
export const VERTEXAI_DEFAULT_VISIBLE_MODELS = [
|
||||
{
|
||||
name: "gemini-2.5-flash",
|
||||
|
||||
@@ -292,7 +292,10 @@ test.describe("Assistant Creation and Edit Verification", () => {
|
||||
expect(agentIdMatch).toBeTruthy();
|
||||
const agentId = agentIdMatch ? agentIdMatch[1] : null;
|
||||
expect(agentId).not.toBeNull();
|
||||
await expectScreenshot(page, { name: "welcome-page-with-assistant" });
|
||||
await expectScreenshot(page, {
|
||||
name: "welcome-page-with-assistant",
|
||||
hide: ["[data-testid='AppInputBar/llm-popover-trigger']"],
|
||||
});
|
||||
|
||||
// Store assistant ID for cleanup
|
||||
knowledgeAssistantId = Number(agentId);
|
||||
|
||||
Reference in New Issue
Block a user