Compare commits

..

27 Commits

Author SHA1 Message Date
Nik
061749bd0c fix(chat): address review comments and improve multi-model UX
Upstream fixes from 4a (selector/hook):
- Fix model_provider mapping: use m.provider instead of m.name
- Pass isLoading to ModelListContent for loading state
- Cap restoreFromModelNames at MAX_MODELS
- Remove redundant llmManager from useEffect deps
- Use Separator component instead of raw divs
- Add two-click-zone on SelectButton (main area→replace, X→remove)
- Anchor popover above the specific clicked pill via virtualRef

Upstream fixes from 4b (panels):
- Use responses.find() instead of array indexing by modelIndex
- Guard preferredIdx === -1 to prevent NaN in carousel transform
- Remove unused SvgEyeClosed import, isHighlighted field, modelIndex prop

Local (4c):
- Conditional fade: only show gradient when panel overflows preferred height
2026-04-02 03:20:24 -07:00
Nik
4adc97cd1b fix(chat): remove drag logic from carousel for text selectability
Remove all pointer-drag interaction code from MultiModelResponseView.
This restores text selection in carousel panels. Drag-to-switch can
be re-added later when needed.
2026-04-02 02:58:12 -07:00
Nik
27dbb47b96 fix(multi-model): replace React pointer handlers with window listeners for carousel drag
setPointerCapture + React synthetic event delegation conflict caused both
click-to-select and drag-scroll to break in selection mode. Replace
onPointerMove/onPointerUp React handlers with native window.addEventListener
closures created inside onPointerDown — window listeners fire regardless of
cursor position and don't require pointer capture.

- Removes dragStartX, baseTranslateX, dragCurrentDelta, isDraggingRef refs
- Adds dragCleanupRef for unmount-safe listener removal
- Adds e.preventDefault() in pointerdown to block text-selection drag
- Captures preferredIndex/responses/hiddenPanels at press time to avoid stale closures
2026-04-02 02:17:16 -07:00
Nik
da8c93b0c7 fix(chat): defer setPointerCapture to drag threshold to restore panel clicks
setPointerCapture on pointerdown was redirecting pointerup to the
container for plain clicks, causing the browser to fire click on the
container instead of the child panel — breaking panel selection. Now
capture is set only once the 5px drag threshold is crossed in
pointermove, so clicks pass through normally.
2026-04-02 02:17:16 -07:00
Nik
3b635d5878 feat(chat): add drag-scroll to selection mode carousel
Pointer drag on the carousel container moves the track directly via DOM
transform (no re-renders during drag). On release, if the drag exceeds
80px the adjacent visible panel becomes preferred and the existing
width+position animation snaps it into center; shorter drags snap back
to the current preferred panel. Clicks after a drag are suppressed via
justDraggedRef so they don't accidentally change the selection.
2026-04-02 02:17:16 -07:00
Nik
ab90cf1066 fix(chat): use pixel values for carousel centering transform
translateX(calc(50% - ...)) used 50% of the track's own width, not
the container's. Use measured trackContainerW for correct centering.
2026-04-02 02:17:15 -07:00
Nik
5de4f57c42 docs(chat): add JSDoc to MultiModelResponseView and MultiModelPanel 2026-04-02 02:17:15 -07:00
Nik
d6f3add517 feat(chat): add multi-model response panels
MultiModelResponseView renders 2-3 LLM responses side-by-side with
generation and selection modes. MultiModelPanel wraps AgentMessage for
each model's output. Adds hideFooter prop to AgentMessage for
non-preferred panel styling.
2026-04-02 02:17:15 -07:00
Nik
3ca304d228 fix(chat): move X icon inside SelectButton, use size md
Match Figma mock: rightIcon={SvgX} puts the remove icon inside the
pill. Size md makes pills larger. Clicking the pill removes the model.
2026-04-02 02:17:06 -07:00
Nik
a5c982fb75 fix(chat): anchor popover to pills, fix separator when at max models
Popover flew to upper-left at 3 models because no Trigger existed.
Wrap pills in Popover.Anchor so content positions correctly. Leading
separator now only shown when + button is visible.
2026-04-02 02:10:21 -07:00
Nik
2d3c838707 refactor(chat): extract ModelListContent, drop accordion for flat section headers
ModelSelector was reimplementing LLMPopover internals (search, accordion,
item rendering). Extract shared ModelListContent component that uses the
ResourcePopover pattern — flat section headers with provider icon + separator
line instead of collapsible accordions. ModelSelector now just handles
selection logic and pill rendering.
2026-04-02 02:02:10 -07:00
Nik
d936efeffd refactor(chat): reuse opal components in ModelSelector
Replace hand-rolled ModelPill with opal SelectButton, custom ModelItem
with LineItem, raw Radix AccordionPrimitive with shared Accordion, and
reuse buildLlmOptions/groupLlmOptions from LLMPopover.
2026-04-02 02:02:10 -07:00
Nik
e1e7693870 feat(chat): add multi-model selector and chat hook
ModelSelector component with add/remove/replace semantics for up to 3
concurrent LLMs. useMultiModelChat hook manages model selection state
and streaming coordination.
2026-04-02 02:02:10 -07:00
Nikolas Garza
df14bbe0e2 feat(chat): add frontend types and API helpers for multi-model streaming (#9648) 2026-04-02 08:52:21 +00:00
Nikolas Garza
3db1ad82ce feat(chat): add multi-model parallel streaming backend (#9647) 2026-04-02 08:03:54 +00:00
Yuhong Sun
1e7882529c docs: Chat README (#9853) 2026-04-02 00:47:31 -07:00
Raunak Bhagat
5d405cfa2d refactor: Hook Extensions edits (#9831) 2026-04-02 06:55:59 +00:00
Nikolas Garza
de3a253ea9 refactor(emitter): replace bus-polling with merge-queue (#9803) 2026-04-02 06:34:12 +00:00
Justin Tahara
d6946a66a5 fix(llm): Azure custom model support + Mistral tool call message ordering (#9729) 2026-04-02 06:22:13 +00:00
Raunak Bhagat
11835a0268 fix(opal): guard opal/interactive's onClick handlers against React portal event bubbling (#9850) 2026-04-02 05:37:00 +00:00
Danelegend
519fb61cc7 fix(xlsx): Improve empty row/col handling (#9288) 2026-04-02 02:52:20 +00:00
acaprau
02671937fb chore(opensearch): Increase DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES to 500, disable profiling by default (#9844)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-04-02 00:49:21 +00:00
acaprau
1466158c1e fix(opensearch): Add Vespa server-side timeout for the migration (#9843) 2026-04-02 00:20:59 +00:00
Bo-Onyx
073cf11c42 feat(hook): update hook doc link and reference (#9841) 2026-04-02 00:12:04 +00:00
Justin Tahara
a2b0c15027 fix(db): remove unnecessary selectinload(User.memories) from auth paths (#9838) 2026-04-01 23:51:06 +00:00
Raunak Bhagat
a462678ddd refactor(opal): split SelectCard's sizeVariant prop into paddingVariant + roundingVariant (#9830) 2026-04-01 23:15:20 +00:00
Bo-Onyx
c50d2739b8 feat(hook): integrate document ingestion hook point (#9810) 2026-04-01 23:08:12 +00:00
89 changed files with 5834 additions and 2665 deletions

View File

@@ -1,25 +1,33 @@
# Overview of Context Management
This document reviews some design decisions around the main agent-loop powering Onyx's chat flow.
It is highly recommended for all engineers contributing to this flow to be familiar with the concepts here.
> Note: it is assumed the reader is familiar with the Onyx product and features such as Projects, User files, Citations, etc.
## System Prompt
The system prompt is a default prompt that comes packaged with the system. Users can edit the default prompt and it will be persisted in the database.
Some parts of the system prompt are dynamically updated / inserted:
- Datetime of the message sent
- Tools description of when to use certain tools depending on if the tool is available in that cycle
- If the user has just called a search related tool, then a section about citations is included
## Custom Agent Prompt
The custom agent is inserted as a user message above the most recent user message, it is dynamically moved in the history as the user sends more messages.
If the user has opted to completely replace the System Prompt, then this Custom Agent prompt replaces the system prompt and does not move along the history.
## How Files are handled
On upload, Files are processed for tokens, if too many tokens to fit in the context, its considered a failed inclusion. This is done using the LLM tokenizer.
- In many cases, there is not a known tokenizer for each LLM so there is a default tokenizer used as a catchall.
- File upload happens in 2 parts - the actual upload + token counting.
- Files are added into chat context as a “point in time” inclusion and move up the context window as the conversation progresses.
Every file knows how many tokens it is (model agnostic), image files have some assumed number of tokens.
Every file knows how many tokens it is (model agnostic), image files have some assumed number of tokens.
Image files are attached to User Messages also as point in time inclusions.
@@ -27,8 +35,8 @@ Image files are attached to User Messages also as point in time inclusions.
Files selected from the search results are also counted as “point in time” inclusions. Files that are too large cannot be selected.
For these files, the "entire file" does not exist for most connectors, it's pieced back together from the search engine.
## Projects
If a Project contains few enough files that it all fits in the model context, we keep it close enough in the history to ensure it is easy for the LLM to
access. Note that the project documents are assumed to be quite useful and that they should 1. never be dropped from context, 2. is not just a needle in
a haystack type search with a strong keyword to make the LLM attend to it.
@@ -36,11 +44,12 @@ a haystack type search with a strong keyword to make the LLM attend to it.
Project files are vectorized and stored in the Search Engine so that if the user chooses a model with less context than the number of tokens in the project,
the system can RAG over the project files.
## How documents are represented
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix to make the
context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of message
rather than a user message.
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix string to
make the context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of
message rather than a user message.
```
Here are some documents provided for context, they may not all be relevant:
{
@@ -50,33 +59,37 @@ Here are some documents provided for context, they may not all be relevant:
]
}
```
Documents are represented with document so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
Documents are represented with the `document` key so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
translate this into links and other UI elements. What the LLM sees is far simpler to reduce noise/hallucinations.
Note that documents included in a single turn should be collapsed into a single user message.
Search tools give URLs to the LLM though so that open_url (a separate tool) can be called on them.
Search tools also give URLs to the LLM so that open_url (a separate tool) can be called on them.
## Reminders
To ensure the LLM follows certain specific instructions, instructions are added at the very end of the chat context as a user message. If a search related
tool is used, a citation reminder is always added. Otherwise, by default there is no reminder. If the user configures reminders, those are added to the
final message. If a search related tool just ran and the user has reminders, both appear in a single message.
If a search related tool is called at any point during the turn, the reminder will remain at the end until the turn is over and the agent has responded.
## Tool Calls
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are today replaced with a hardcoded
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are current replaced with a hardcoded
string saying it is no longer available. Tool Call details like the search query and other arguments are kept in the history as this is information
rich and generally very few tokens.
> Note: in the Internal Search flow with query expansion, the Tool Call which was actually run differs from what the LLM provided as arguments.
> What the LLM sees in the history (to be most informative for future calls) is the full set of expanded queries.
**Possible Future Extension**:
Instead of dropping the Tool Call response, we might summarize it using an LLM so that it is just 1-2 sentences and captures the main points. That said,
this is questionable value add because anything relevant and useful should be already captured in the Agent response.
## Examples
```
S -> System Message
CA -> Custom Agent as a User Message
@@ -98,15 +111,15 @@ Flow with Project and File Upload
S, CA, P, F, U1, A1 -- user sends another message -> S, F, U1, A1, CA, P, U2, A2
- File stays in place, above the user message
- Project files move along the chain as new messages are sent
- Custom Agent prompt comes before project files which comes before user uploaded files in each turn
- Custom Agent prompt comes before project files which come before user uploaded files in each turn
Reminders during a single Turn
S, U1, TC, TR, R -- agent calls another tool -> S, U1, TC, TR, TC, TR, R, A1
- Reminder moved to the end
```
## Product considerations
Project files are important to the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
those files. The LLM is much better at referencing documents close to the end of the context window so keeping it there for ease of access.
@@ -117,9 +130,9 @@ User Message further away. This tradeoff is accepted for Projects because of the
Reminder are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. It is less detailed than the system prompt
and should be very targetted for it to work reliably and also not interfere with the last user message.
## Reasons / Experiments
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrade performance of the system especially when the instructions
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrades performance of the system especially when the instructions
are orthogonal (or even possibly contradictory) to the system prompt. For weaker models, it causes strange artifacts in tool calls and final responses
that completely ruins the user experience. Empirically, this way works better across a range of models especially when the history gets longer.
Having the Custom Agent instructions not move means it fades more as the chat gets long which is also not ok from a UX perspective.
@@ -146,10 +159,10 @@ In a similar concept, LLM instructions in the system prompt are structured speci
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
even just a paragraph away like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
rate even just moving the same statement a few sentences.
rate by even just moving the same statement a few sentences.
## Other related pointers
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.
---
@@ -160,32 +173,38 @@ rate even just moving the same statement a few sentences.
Turn: User sends a message and AI does some set of things and responds
Step/Cycle: 1 single LLM inference given some context and some tools
## 1. Top Level (process_message function):
This function can be thought of as the set-up and validation layer. It ensures that the database is in a valid state, reads the
messages in the session and sets up all the necessary items to run the chat loop and state containers. The major things it does
are:
- Validates the request
- Builds the chat history for the session
- Fetches any additional context such as files and images
- Prepares all of the tools for the LLM
- Creates the state container objects for use in the loop
### Wrapper (run_chat_loop_with_state_containers function):
This wrapper is used to run the LLM flow in a background thread and monitor the emitter for stop signals. This means the top
level is as isolated from the LLM flow as possible and can continue to yield packets as soon as they are available from the lower
levels. This also means that if the lower levels fail, the top level will still guarantee a reasonable response to the user.
All of the saving and database operations are abstracted away from the lower levels.
### Execution (`_run_models` function):
Each model runs in its own worker thread inside a `ThreadPoolExecutor`. Workers write packets to a shared
`merged_queue` via an `Emitter`; the main thread drains the queue and yields packets in arrival order. This
means the top level is isolated from the LLM flow and can yield packets as soon as they are produced. If a
worker fails, the main thread yields a `StreamingError` for that model and keeps the other models running.
All saving and database operations are handled by the main thread after the workers complete (or by the
workers themselves via self-completion if the drain loop exits early).
### Emitter
The emitter is designed to be an object queue so that lower levels do not need to yield objects all the way back to the top.
This way the functions can be better designed (not everything as a generator) and more easily tested. The wrapper around the
LLM flow (run_chat_loop_with_state_containers) is used to monitor the emitter and handle packets as soon as they are available
from the lower levels. Both the emitter and the state container are mutating state objects and only used to accumulate state.
There should be no logic dependent on the states of these objects, especially in the lower levels. The emitter should only take
packets and should not be used for other things.
The emitter is an object that lower levels use to send packets without needing to yield them all the way back
up the call stack. Each `Emitter` tags every packet with a `model_index` and places it on the shared
`merged_queue` as a `(model_idx, packet)` tuple. The drain loop in `_run_models` consumes these tuples and
yields the packets to the caller. Both the emitter and the state container are mutating state objects used
only to accumulate state. There should be no logic dependent on the states of these objects, especially in
the lower levels. The emitter should only take packets and should not be used for other things.
### State Container
The state container is used to accumulate state during the LLM flow. Similar to the emitter, it should not be used for logic,
only for accumulating state. It is used to gather all of the necessary information for saving the chat turn into the database.
So it will accumulate answer tokens, reasoning tokens, tool calls, citation info, etc. This is used at the end of the flow once
@@ -193,35 +212,40 @@ the lower level is completed whether on its own or stopped by the user. At that
the database. The state container can be added to by any of the underlying layers, this is fine.
### Stopping Generation
A stop signal is checked every 300ms by the wrapper around the LLM flow. The signal itself
is stored in Redis and is set by the user calling the stop endpoint. The wrapper ensures that no matter what the lower level is
doing at the time, the thread can be killed by the top level. It does not require a cooperative cancellation from the lower level
and in fact the lower level does not know about the stop signal at all.
The drain loop in `_run_models` checks `check_is_connected()` every 50 ms (on queue timeout). The signal itself
is stored in Redis and is set by the user calling the stop endpoint. On disconnect, the drain loop saves
partial state for every model, yields an `OverallStop(stop_reason="user_cancelled")` packet, and returns.
A `drain_done` event signals emitters to stop blocking so worker threads can exit quickly. Workers that
already completed successfully will self-complete (persist their response) if the drain loop exited before
reaching the normal completion path.
## 2. LLM Loop (run_llm_loop function)
This function handles the logic of the Turn. It's essentially a while loop where context is added and modified (according what
is outlined in the first half of this doc). Its main functionality is:
- Translate and truncate the context for the LLM inference
- Add context modifiers like reminders, updates to the system prompts, etc.
- Run tool calls and gather results
- Build some of the objects stored in the state container.
## 3. LLM Step (run_llm_step function)
This function is a single inference of the LLM. It's a wrapper around the LLM stream function which handles packet translations
so that the Emitter can emit individual tokens as soon as they arrive. It also keeps track of the different sections since they
do not all come at once (reasoning, answers, tool calls are all built up token by token). This layer also tracks the different
tool calls and returns that to the LLM Loop to execute.
## Things to know
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
- There are 3 representations of "message". The first is the database model ChatMessage, this one should be translated away and
not used deep into the flow. The second is ChatMessageSimple which is the data model which should be used throughout the code
as much as possible. If modifications/additions are needed, it should be to this object. This is the rich representation of a
message for the code. Finally there is the LanguageModelInput representation of a message. This one is for the LLM interface
layer and is as stripped down as possible so that the LLM interface can be clean and easy to maintain/extend.
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
- There are 3 representations of a message, each scoped to a different layer:
1. **ChatMessage** — The database model. Should be converted into ChatMessageSimple early and never passed deep into the flow.
2. **ChatMessageSimple** — The canonical data model used throughout the codebase. This is the rich, full-featured representation
of a message. Any modifications or additions to message structure should be made here.
3. **LanguageModelInput** — The LLM-facing representation. Intentionally minimal so the LLM interface layer stays clean and
easy to maintain/extend.

View File

@@ -1,19 +1,28 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from dataclasses import dataclass
from uuid import UUID
from pydantic import BaseModel
from onyx.cache.interface import CacheBackend
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import SearchParams
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.db.memory import UserMemoryContext
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.tools.models import ChatFile
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -161,112 +170,45 @@ class ChatStateContainer:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
class AvailableFiles(BaseModel):
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
# IDs from the ``user_file`` table (project / persona-attached files).
user_file_ids: list[UUID] = []
# IDs from the ``file_record`` table (chat-attached files).
chat_file_ids: list[UUID] = []
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
@dataclass(frozen=True)
class ChatTurnSetup:
"""Immutable context produced by ``build_chat_turn`` and consumed by ``_run_models``."""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)
new_msg_req: SendMessageRequest
chat_session: ChatSession
persona: Persona
user_message: ChatMessage
user_identity: LLMUserIdentity
llms: list[LLM] # length 1 for single-model, N for multi-model
model_display_names: list[str] # parallel to llms
simple_chat_history: list[ChatMessageSimple]
extracted_context_files: ExtractedContextFiles
reserved_messages: list[ChatMessage] # length 1 for single, N for multi
reserved_token_count: int
search_params: SearchParams
all_injected_file_metadata: dict[str, FileToolMetadata]
available_files: AvailableFiles
tool_id_to_name_map: dict[int, str]
forced_tool_id: int | None
files: list[ChatLoadedFile]
chat_files_for_tools: list[ChatFile]
custom_agent_prompt: str | None
user_memory_context: UserMemoryContext
# For deep research: was the last assistant message a clarification request?
skip_clarification: bool
check_is_connected: Callable[[], bool]
cache: CacheBackend
# Execution params forwarded to per-model tool construction
bypass_acl: bool
slack_context: SlackContext | None
custom_tool_additional_headers: dict[str, str] | None
mcp_headers: dict[str, str] | None

View File

@@ -1,19 +1,40 @@
import threading
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Use this inside tools to emit arbitrary UI progress."""
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
def __init__(self, bus: Queue):
self.bus = bus
Tags every packet with ``model_index`` and places it on ``merged_queue``
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
Args:
merged_queue: Shared queue owned by ``_run_models``.
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
drain_done: Optional event set by ``_run_models`` when the drain loop
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
immediately so worker threads can exit fast.
"""
def __init__(
self,
merged_queue: Queue[tuple[int, Packet | Exception | object]],
model_idx: int = 0,
drain_done: threading.Event | None = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
self._drain_done = drain_done
def emit(self, packet: Packet) -> None:
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter
if self._drain_done is not None and self._drain_done.is_set():
return
base = packet.placement or Placement(turn_index=0)
tagged = Packet(
placement=base.model_copy(update={"model_index": self._model_idx}),
obj=packet.obj,
)
self._merged_queue.put((self._model_idx, tagged))

File diff suppressed because it is too large Load Diff

View File

@@ -286,11 +286,9 @@ USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
# Profiling adds some overhead to OpenSearch operations. This overhead is
# unknown right now. It is enabled by default so we can get useful logs for
# investigating slow queries. We may never disable it if the overhead is
# minimal.
# unknown right now. Defaults to True.
OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "true").lower() == "true"
)
# Whether to disable match highlights for OpenSearch. Defaults to True for now
# as we investigate query performance.
@@ -942,9 +940,20 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
)
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
# This is the timeout for the client side of the Vespa migration task. When
# exceeded, an exception is raised in our code. This value should be higher than
# VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT.
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
)
# This is the timeout Vespa uses on the server side to know when to wrap up its
# traversal and try to report partial results. This differs from the client
# timeout above which raises an exception in our code when exceeded. This
# timeout allows Vespa to return gracefully. This value should be lower than
# VESPA_MIGRATION_REQUEST_TIMEOUT_S. Formatted as <number of seconds>s.
VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT = os.environ.get(
"VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT", "110s"
)
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")

View File

@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
from onyx.connectors.google_drive.file_retrieval import (
get_files_by_web_view_links_batch,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
@@ -73,13 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import Resolver
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
@@ -207,9 +202,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
Resolver,
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1672,82 +1665,6 @@ class GoogleDriveConnector(
start, end, checkpoint, include_permissions=True
)
@override
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
if self._creds is None or self._primary_admin_email is None:
raise RuntimeError(
"Credentials missing, should not call this method before calling load_credentials"
)
logger.info(f"Resolving {len(errors)} errors")
doc_ids = [
failure.failed_document.document_id
for failure in errors
if failure.failed_document
]
service = get_drive_service(self.creds, self.primary_admin_email)
field_type = (
DriveFileFieldType.WITH_PERMISSIONS
if include_permissions or self.exclude_domain_link_only
else DriveFileFieldType.STANDARD
)
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
for doc_id, error in batch_result.errors.items():
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=doc_id,
),
failure_message=f"Failed to retrieve file during error resolution: {error}",
exception=error,
)
permission_sync_context = (
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
if include_permissions
else None
)
retrieved_files = [
RetrievedDriveFile(
drive_file=file,
user_email=self.primary_admin_email,
completion_stage=DriveRetrievalStage.DONE,
)
for file in batch_result.files.values()
]
yield from self._get_new_ancestors_for_files(
files=retrieved_files,
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
permission_sync_context=permission_sync_context,
add_prefix=True,
)
func_with_args = [
(
self._convert_retrieved_file_to_document,
(rf, permission_sync_context),
)
for rf in retrieved_files
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
for result in results:
if result is not None:
yield result
def _extract_slim_docs_from_google_drive(
self,
checkpoint: GoogleDriveCheckpoint,

View File

@@ -9,7 +9,6 @@ from urllib.parse import urlparse
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import BatchHttpRequest # type: ignore
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
MAX_BATCH_SIZE = 100
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().list() based on the field type enum."""
"""Get the appropriate fields string based on the field type enum"""
if field_type == DriveFileFieldType.SLIM:
return SLIM_FILE_FIELDS
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return FILE_FIELDS
def _extract_single_file_fields(list_fields: str) -> str:
"""Convert a files().list() fields string to one suitable for files().get().
List fields look like "nextPageToken, files(field1, field2, ...)"
Single-file fields should be just "field1, field2, ..."
"""
start = list_fields.find("files(")
if start == -1:
return list_fields
inner_start = start + len("files(")
inner_end = list_fields.rfind(")")
return list_fields[inner_start:inner_end]
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().get() based on the field type enum."""
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
def _get_files_in_parent(
service: Resource,
parent_id: str,
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
)
.execute()
)
class BatchRetrievalResult:
"""Result of a batch file retrieval, separating successes from errors."""
def __init__(self) -> None:
self.files: dict[str, GoogleDriveFileType] = {}
self.errors: dict[str, Exception] = {}
def get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
field_type: DriveFileFieldType,
) -> BatchRetrievalResult:
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
Returns a BatchRetrievalResult containing successful file retrievals
and errors for any files that could not be fetched.
Automatically splits into chunks of MAX_BATCH_SIZE.
"""
fields = _get_single_file_fields(field_type)
if len(web_view_links) <= MAX_BATCH_SIZE:
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
combined = BatchRetrievalResult()
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
combined.files.update(chunk_result.files)
combined.errors.update(chunk_result.errors)
return combined
def _get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
fields: str,
) -> BatchRetrievalResult:
"""Single-batch implementation."""
result = BatchRetrievalResult()
def callback(
request_id: str,
response: GoogleDriveFileType,
exception: Exception | None,
) -> None:
if exception:
logger.warning(f"Error retrieving file {request_id}: {exception}")
result.errors[request_id] = exception
else:
result.files[request_id] = response
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
for web_view_link in web_view_links:
try:
file_id = _extract_file_id_from_web_view_link(web_view_link)
request = service.files().get(
fileId=file_id,
supportsAllDrives=True,
fields=fields,
)
batch.add(request, request_id=web_view_link)
except ValueError as e:
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
result.errors[web_view_link] = e
batch.execute()
return result

View File

@@ -298,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
raise NotImplementedError
class Resolver(BaseConnector):
@abc.abstractmethod
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
If include_permissions is True, the documents will have permissions synced.
May also yield HierarchyNode objects for ancestor folders of resolved documents.
"""
raise NotImplementedError
class HierarchyConnector(BaseConnector):
@abc.abstractmethod
def load_hierarchy(

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -631,6 +631,91 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -853,6 +938,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -47,7 +46,6 @@ async def fetch_user_for_pat(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
)
if not user:
return None

View File

@@ -229,7 +229,9 @@ def get_memories_for_user(
user_id: UUID,
db_session: Session,
) -> Sequence[Memory]:
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
return db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
).all()
def update_user_pinned_assistants(

View File

@@ -37,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
# we have a much higher chance of all 10 of the final desired docs showing up
# and getting scored. In worse situations, the final 10 docs don't even show up
# as the final 10 (worse than just a miss at the reranking step).
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
# poor search performance.
# Defaults to 500 for now. Initially this defaulted to 750 but we were seeing
# poor search performance; bumped from 100 to 500 to improve recall.
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 500)
)
# Number of vectors to examine to decide the top k neighbors for the HNSW

View File

@@ -20,6 +20,7 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
from onyx.configs.app_configs import VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.document_index.interfaces import VespaChunkRequest
@@ -335,6 +336,11 @@ def get_all_chunks_paginated(
"format.tensors": "short-value",
"slices": total_slices,
"sliceId": slice_id,
# When exceeded, Vespa should return gracefully with partial
# results. Even if no hits are returned, Vespa should still return a
# new continuation token representing a new spot in the linear
# traversal.
"timeout": VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT,
}
if continuation_token is not None:
params["continuation"] = continuation_token
@@ -343,6 +349,9 @@ def get_all_chunks_paginated(
start_time = time.monotonic()
try:
with get_vespa_http_client(
# When exceeded, an exception is raised in our code. No progress
# is saved, and the task will retry this spot in the traversal
# later.
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as http_client:
response = http_client.get(url, params=params)

View File

@@ -1,3 +1,4 @@
import csv
import gc
import io
import json
@@ -19,6 +20,7 @@ from zipfile import BadZipFile
import chardet
import openpyxl
from openpyxl.worksheet.worksheet import Worksheet
from PIL import Image
from onyx.configs.constants import ONYX_METADATA_FILENAME
@@ -353,6 +355,94 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
return presentation.markdown
def _worksheet_to_matrix(
worksheet: Worksheet,
) -> list[list[str]]:
"""
Converts a singular worksheet to a matrix of values
"""
rows: list[list[str]] = []
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
row = ["" if cell is None else str(cell) for cell in worksheet_row]
rows.append(row)
return rows
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
"""
Cleans a worksheet matrix by removing rows if there are N consecutive empty
rows and removing cols if there are M consecutive empty columns
"""
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
MAX_EMPTY_COLS = 2
# Row cleanup
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
if not matrix:
return matrix
# Column cleanup — determine which columns to keep without transposing.
num_cols = len(matrix[0])
keep_cols = _columns_to_keep(matrix, num_cols, max_empty=MAX_EMPTY_COLS)
if len(keep_cols) < num_cols:
matrix = [[row[c] for c in keep_cols] for row in matrix]
return matrix
def _columns_to_keep(
matrix: list[list[str]], num_cols: int, max_empty: int
) -> list[int]:
"""Return the indices of columns to keep after removing empty-column runs.
Uses the same logic as ``_remove_empty_runs`` but operates on column
indices so no transpose is needed.
"""
kept: list[int] = []
empty_buffer: list[int] = []
for col_idx in range(num_cols):
col_is_empty = all(not row[col_idx] for row in matrix)
if col_is_empty:
empty_buffer.append(col_idx)
else:
kept.extend(empty_buffer[:max_empty])
kept.append(col_idx)
empty_buffer = []
return kept
def _remove_empty_runs(
rows: list[list[str]],
max_empty: int,
) -> list[list[str]]:
"""Removes entire runs of empty rows when the run length exceeds max_empty.
Leading empty runs are capped to max_empty, just like interior runs.
Trailing empty rows are always dropped since there is no subsequent
non-empty row to flush them.
"""
result: list[list[str]] = []
empty_buffer: list[list[str]] = []
for row in rows:
# Check if empty
if not any(row):
if len(empty_buffer) < max_empty:
empty_buffer.append(row)
else:
# Add upto max empty rows onto the result - that's what we allow
result.extend(empty_buffer[:max_empty])
# Add the new non-empty row
result.append(row)
empty_buffer = []
return result
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
# TODO: switch back to this approach in a few months when markitdown
# fixes their handling of excel files
@@ -391,30 +481,15 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise e
raise
text_content = []
for sheet in workbook.worksheets:
rows = []
num_empty_consecutive_rows = 0
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell or "") for cell in row)
# Only add the row if there are any values in the cells
if len(row_str) >= len(row):
rows.append(row_str)
num_empty_consecutive_rows = 0
else:
num_empty_consecutive_rows += 1
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
)
break
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
writer.writerows(sheet_matrix)
text_content.append(buf.getvalue().rstrip("\n"))
return TEXT_SECTION_SEPARATOR.join(text_content)

View File

@@ -1,33 +1,114 @@
from pydantic import BaseModel
from pydantic import Field
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
# TODO(@Bo-Onyx): define payload and response fields
class DocumentIngestionSection(BaseModel):
"""Represents a single section of a document — either text or image, not both.
Text section: set `text`, leave `image_file_id` null.
Image section: set `image_file_id`, leave `text` null.
"""
text: str | None = Field(
default=None,
description="Text content of this section. Set for text sections, null for image sections.",
)
link: str | None = Field(
default=None,
description="Optional URL associated with this section. Preserve the original link from the payload if you want it retained.",
)
image_file_id: str | None = Field(
default=None,
description=(
"Opaque identifier for an image stored in the file store. "
"The image content is not included — this field signals that the section is an image. "
"Hooks can use its presence to reorder or drop image sections, but cannot read or modify the image itself."
),
)
class DocumentIngestionOwner(BaseModel):
display_name: str | None = Field(
default=None,
description="Human-readable name of the owner.",
)
email: str | None = Field(
default=None,
description="Email address of the owner.",
)
class DocumentIngestionPayload(BaseModel):
pass
document_id: str = Field(
description="Unique identifier for the document. Read-only — changes are ignored."
)
title: str | None = Field(description="Title of the document.")
semantic_identifier: str = Field(
description="Human-readable identifier used for display (e.g. file name, page title)."
)
source: str = Field(
description=(
"Connector source type (e.g. confluence, slack, google_drive). "
"Read-only — changes are ignored. "
"Full list of values: https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/configs/constants.py#L195"
)
)
sections: list[DocumentIngestionSection] = Field(
description="Sections of the document. Includes both text sections (text set, image_file_id null) and image sections (image_file_id set, text null)."
)
metadata: dict[str, list[str]] = Field(
description="Key-value metadata attached to the document. Values are always a list of strings."
)
doc_updated_at: str | None = Field(
description="ISO 8601 UTC timestamp of the last update at the source, or null if unknown. Example: '2024-03-15T10:30:00+00:00'."
)
primary_owners: list[DocumentIngestionOwner] | None = Field(
description="Primary owners of the document, or null if not available."
)
secondary_owners: list[DocumentIngestionOwner] | None = Field(
description="Secondary owners of the document, or null if not available."
)
class DocumentIngestionResponse(BaseModel):
pass
# Intentionally permissive — customer endpoints may return extra fields.
sections: list[DocumentIngestionSection] | None = Field(
description="The sections to index, in the desired order. Reorder, drop, or modify sections freely. Null or empty list drops the document."
)
rejection_reason: str | None = Field(
default=None,
description="Logged when sections is null or empty. Falls back to a generic message if omitted.",
)
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
"""Hook point that runs on every document before it enters the indexing pipeline.
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
Call site: immediately after Onyx's internal validation and before the
indexing pipeline begins — no partial writes have occurred yet.
If a Document Ingestion hook is configured, it takes precedence —
Document Ingestion Light will not run. Configure only one per deployment.
Supported use cases:
- Document filtering: drop documents based on content or metadata
- Content rewriting: redact PII or normalize text before indexing
"""
hook_point = HookPoint.DOCUMENT_INGESTION
display_name = "Document Ingestion"
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
description = (
"Runs on every document before it enters the indexing pipeline. "
"Allows filtering, rewriting, or dropping documents."
)
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
docs_url = "https://docs.onyx.app/admins/advanced_configs/hook_extensions#document-ingestion"
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -65,8 +65,9 @@ class QueryProcessingSpec(HookPointSpec):
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
docs_url = (
"https://docs.onyx.app/admins/advanced_configs/hook_extensions#query-processing"
)
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -33,6 +33,7 @@ from onyx.connectors.models import TextSection
from onyx.db.document import get_documents_by_ids
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.enums import HookPoint
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
@@ -47,6 +48,13 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.document_ingestion import DocumentIngestionOwner
from onyx.hooks.points.document_ingestion import DocumentIngestionPayload
from onyx.hooks.points.document_ingestion import DocumentIngestionResponse
from onyx.hooks.points.document_ingestion import DocumentIngestionSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
@@ -297,6 +305,7 @@ def index_doc_batch_with_handler(
document_batch: list[Document],
request_id: str | None,
tenant_id: str,
db_session: Session,
adapter: IndexingBatchAdapter,
ignore_time_skip: bool = False,
enable_contextual_rag: bool = False,
@@ -310,6 +319,7 @@ def index_doc_batch_with_handler(
document_batch=document_batch,
request_id=request_id,
tenant_id=tenant_id,
db_session=db_session,
adapter=adapter,
ignore_time_skip=ignore_time_skip,
enable_contextual_rag=enable_contextual_rag,
@@ -785,6 +795,132 @@ def _verify_indexing_completeness(
)
def _apply_document_ingestion_hook(
documents: list[Document],
db_session: Session,
) -> list[Document]:
"""Apply the Document Ingestion hook to each document in the batch.
- HookSkipped / HookSoftFailed → document passes through unchanged.
- Response with sections=None → document is dropped (logged).
- Response with sections → document sections are replaced with the hook's output.
"""
def _build_payload(doc: Document) -> DocumentIngestionPayload:
return DocumentIngestionPayload(
document_id=doc.id or "",
title=doc.title,
semantic_identifier=doc.semantic_identifier,
source=doc.source.value if doc.source is not None else "",
sections=[
DocumentIngestionSection(
text=s.text if isinstance(s, TextSection) else None,
link=s.link,
image_file_id=(
s.image_file_id if isinstance(s, ImageSection) else None
),
)
for s in doc.sections
],
metadata={
k: v if isinstance(v, list) else [v] for k, v in doc.metadata.items()
},
doc_updated_at=(
doc.doc_updated_at.isoformat() if doc.doc_updated_at else None
),
primary_owners=(
[
DocumentIngestionOwner(
display_name=o.get_semantic_name() or None,
email=o.email,
)
for o in doc.primary_owners
]
if doc.primary_owners
else None
),
secondary_owners=(
[
DocumentIngestionOwner(
display_name=o.get_semantic_name() or None,
email=o.email,
)
for o in doc.secondary_owners
]
if doc.secondary_owners
else None
),
)
def _apply_result(
doc: Document,
hook_result: DocumentIngestionResponse | HookSkipped | HookSoftFailed,
) -> Document | None:
"""Return the modified doc, original doc (skip/soft-fail), or None (drop)."""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return doc
if not hook_result.sections:
reason = hook_result.rejection_reason or "Document rejected by hook"
logger.info(
f"Document ingestion hook dropped document doc_id={doc.id!r}: {reason}"
)
return None
new_sections: list[TextSection | ImageSection] = []
for s in hook_result.sections:
if s.image_file_id is not None:
new_sections.append(
ImageSection(image_file_id=s.image_file_id, link=s.link)
)
elif s.text is not None:
new_sections.append(TextSection(text=s.text, link=s.link))
else:
logger.warning(
f"Document ingestion hook returned a section with neither text nor "
f"image_file_id for doc_id={doc.id!r} — skipping section."
)
if not new_sections:
logger.info(
f"Document ingestion hook produced no valid sections for doc_id={doc.id!r} — dropping document."
)
return None
return doc.model_copy(update={"sections": new_sections})
if not documents:
return documents
# Run the hook for the first document. If it returns HookSkipped the hook
# is not configured — skip the remaining N-1 DB lookups.
first_doc = documents[0]
first_payload = _build_payload(first_doc).model_dump()
first_hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.DOCUMENT_INGESTION,
payload=first_payload,
response_type=DocumentIngestionResponse,
)
if isinstance(first_hook_result, HookSkipped):
return documents
result: list[Document] = []
first_applied = _apply_result(first_doc, first_hook_result)
if first_applied is not None:
result.append(first_applied)
for doc in documents[1:]:
payload = _build_payload(doc).model_dump()
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.DOCUMENT_INGESTION,
payload=payload,
response_type=DocumentIngestionResponse,
)
applied = _apply_result(doc, hook_result)
if applied is not None:
result.append(applied)
return result
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -794,6 +930,7 @@ def index_doc_batch(
document_indices: list[DocumentIndex],
request_id: str | None,
tenant_id: str,
db_session: Session,
adapter: IndexingBatchAdapter,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
@@ -818,6 +955,7 @@ def index_doc_batch(
)
filtered_documents = filter_fnc(document_batch)
filtered_documents = _apply_document_ingestion_hook(filtered_documents, db_session)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult.empty(len(filtered_documents))
@@ -1005,6 +1143,7 @@ def run_indexing_pipeline(
document_batch=document_batch,
request_id=request_id,
tenant_id=tenant_id,
db_session=db_session,
adapter=adapter,
enable_contextual_rag=enable_contextual_rag,
llm=llm,

View File

@@ -175,6 +175,28 @@ def _strip_tool_content_from_messages(
return result
def _fix_tool_user_message_ordering(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Insert a synthetic assistant message between tool and user messages.
Some models (e.g. Mistral on Azure) require strict message ordering where
a user message cannot immediately follow a tool message. This function
inserts a minimal assistant message to bridge the gap.
"""
if len(messages) < 2:
return messages
result: list[dict[str, Any]] = [messages[0]]
for msg in messages[1:]:
prev_role = result[-1].get("role")
curr_role = msg.get("role")
if prev_role == "tool" and curr_role == "user":
result.append({"role": "assistant", "content": "Noted. Continuing."})
result.append(msg)
return result
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
"""Check if any messages contain tool-related content blocks."""
for msg in messages:
@@ -576,6 +598,18 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Some models (e.g. Mistral) reject a user message
# immediately after a tool message. Insert a synthetic
# assistant bridge message to satisfy the ordering
# constraint. Check both the provider and the deployment/
# model name to catch Mistral hosted on Azure.
model_or_deployment = (
self._deployment_name or self._model_version or ""
).lower()
is_mistral_model = is_mistral or "mistral" in model_or_deployment
if is_mistral_model:
messages = _fix_tool_user_message_ordering(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:

View File

@@ -8,6 +8,24 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -147,6 +147,7 @@ class UserInfo(BaseModel):
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
memories: list[MemoryItem] | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -191,10 +192,7 @@ class UserInfo(BaseModel):
role=user.personal_role or "",
use_memories=user.use_memories,
enable_memory_tool=user.enable_memory_tool,
memories=[
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in (user.memories or [])
],
memories=memories or [],
user_preferences=user.user_preferences or "",
),
)

View File

@@ -57,6 +57,7 @@ from onyx.db.user_preferences import activate_user
from onyx.db.user_preferences import deactivate_user
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
from onyx.db.user_preferences import get_latest_access_token_for_user
from onyx.db.user_preferences import get_memories_for_user
from onyx.db.user_preferences import update_assistant_preferences
from onyx.db.user_preferences import update_user_assistant_visibility
from onyx.db.user_preferences import update_user_auto_scroll
@@ -823,6 +824,11 @@ def verify_user_logged_in(
[],
),
)
memories = [
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
]
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
@@ -833,6 +839,7 @@ def verify_user_logged_in(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
memories=memories,
)
return user_info
@@ -930,7 +937,8 @@ def update_user_personalization_api(
else user.enable_memory_tool
)
existing_memories = [
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
]
new_memories = (
request.memories if request.memories is not None else existing_memories

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,11 +2,25 @@ from pydantic import BaseModel
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to. ``0`` for single-model
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
for pre-LLM setup packets (e.g. message ID info) that are yielded
before any Emitter runs.
"""
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -1,3 +1,4 @@
import queue
import time
from collections.abc import Callable
from typing import Any
@@ -708,7 +709,6 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
emitter_queue: queue.Queue = queue.Queue()
emitter = Emitter(merged_queue=emitter_queue)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {bus.qsize()}")
print(f"Total packets emitted: {emitter_queue.qsize()}")

View File

@@ -1,5 +1,6 @@
import csv
import json
import queue
import uuid
from io import BytesIO
from io import StringIO
@@ -11,7 +12,6 @@ import requests
from requests import JSONDecodeError
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
# Use default emitter if none provided
# Use a discard emitter if none provided (packets go nowhere)
if emitter is None:
emitter = get_default_emitter()
emitter = Emitter(merged_queue=queue.Queue())
return [
CustomTool(
@@ -367,7 +367,7 @@ if __name__ == "__main__":
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=openapi_schema,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
dynamic_schema_info=None,
)

View File

@@ -458,6 +458,27 @@ def run_async_sync_no_cancel(coro: Awaitable[T]) -> T:
return future.result()
def run_multiple_in_background(
funcs: list[Callable[[], None]],
thread_name_prefix: str = "worker",
) -> ThreadPoolExecutor:
"""Submit multiple callables to a ``ThreadPoolExecutor`` with context propagation.
Copies the current ``contextvars`` context once and runs every callable
inside that copy, which is important for preserving tenant IDs and other
context-local state across threads.
Returns the executor so the caller can ``shutdown()`` when done.
"""
ctx = contextvars.copy_context()
executor = ThreadPoolExecutor(
max_workers=len(funcs), thread_name_prefix=thread_name_prefix
)
for func in funcs:
executor.submit(ctx.run, func)
return executor
class TimeoutThread(threading.Thread, Generic[R]):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any

View File

@@ -1,239 +0,0 @@
"""Tests for GoogleDriveConnector.resolve_errors against real Google Drive."""
import json
import os
from collections.abc import Callable
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import (
ALL_EXPECTED_HIERARCHY_NODES,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
_DRIVE_ID_MAPPING_PATH = os.path.join(
os.path.dirname(__file__), "drive_id_mapping.json"
)
def _load_web_view_links(file_ids: list[int]) -> list[str]:
with open(_DRIVE_ID_MAPPING_PATH) as f:
mapping: dict[str, str] = json.load(f)
return [mapping[str(fid)] for fid in file_ids]
def _build_failures(web_view_links: list[str]) -> list[ConnectorFailure]:
return [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=link,
document_link=link,
),
failure_message=f"Synthetic failure for {link}",
)
for link in web_view_links
]
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_single_file(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolve a single known file and verify we get back exactly one Document."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
web_view_links = _load_web_view_links([0])
failures = _build_failures(web_view_links)
results = list(connector.resolve_errors(failures))
docs = [r for r in results if isinstance(r, Document)]
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
assert len(docs) == 1
assert len(new_failures) == 0
assert docs[0].semantic_identifier == "file_0.txt"
# Should yield at least one hierarchy node (the file's parent folder chain)
assert len(hierarchy_nodes) > 0
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_multiple_files(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolve multiple files across different folders via batch API."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
# Pick files from different folders: admin files (0-4), shared drive 1 (20-24), folder_2 (45-49)
file_ids = [0, 1, 20, 21, 45]
web_view_links = _load_web_view_links(file_ids)
failures = _build_failures(web_view_links)
results = list(connector.resolve_errors(failures))
docs = [r for r in results if isinstance(r, Document)]
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
assert len(new_failures) == 0
retrieved_names = {doc.semantic_identifier for doc in docs}
expected_names = {f"file_{fid}.txt" for fid in file_ids}
assert expected_names == retrieved_names
# Files span multiple folders, so we should get hierarchy nodes
assert len(hierarchy_nodes) > 0
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_hierarchy_nodes_are_valid(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Verify that hierarchy nodes from resolve_errors match expected structure."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
# File in folder_1 (inside shared_drive_1) — should walk up to shared_drive_1 root
web_view_links = _load_web_view_links([25])
failures = _build_failures(web_view_links)
results = list(connector.resolve_errors(failures))
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
node_ids = {node.raw_node_id for node in hierarchy_nodes}
# File 25 is in folder_1 which is inside shared_drive_1.
# The parent walk must yield at least these two ancestors.
assert (
FOLDER_1_ID in node_ids
), f"Expected folder_1 ({FOLDER_1_ID}) in hierarchy nodes, got: {node_ids}"
assert (
SHARED_DRIVE_1_ID in node_ids
), f"Expected shared_drive_1 ({SHARED_DRIVE_1_ID}) in hierarchy nodes, got: {node_ids}"
for node in hierarchy_nodes:
if node.raw_node_id not in ALL_EXPECTED_HIERARCHY_NODES:
continue
expected = ALL_EXPECTED_HIERARCHY_NODES[node.raw_node_id]
assert node.display_name == expected.display_name, (
f"Display name mismatch for {node.raw_node_id}: "
f"expected '{expected.display_name}', got '{node.display_name}'"
)
assert node.node_type == expected.node_type, (
f"Node type mismatch for {node.raw_node_id}: "
f"expected '{expected.node_type}', got '{node.node_type}'"
)
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_with_invalid_link(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolve with a mix of valid and invalid links — invalid ones yield ConnectorFailure."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
valid_links = _load_web_view_links([0])
invalid_link = "https://drive.google.com/file/d/NONEXISTENT_FILE_ID_12345"
failures = _build_failures(valid_links + [invalid_link])
results = list(connector.resolve_errors(failures))
docs = [r for r in results if isinstance(r, Document)]
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
assert len(docs) == 1
assert docs[0].semantic_identifier == "file_0.txt"
assert len(new_failures) == 1
assert new_failures[0].failed_document is not None
assert new_failures[0].failed_document.document_id == invalid_link
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_empty_errors(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolving an empty error list should yield nothing."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
results = list(connector.resolve_errors([]))
assert len(results) == 0
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_entity_failures_are_skipped(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Entity failures (not document failures) should be skipped by resolve_errors."""
from onyx.connectors.models import EntityFailure
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
entity_failure = ConnectorFailure(
failed_entity=EntityFailure(entity_id="some_stage"),
failure_message="retrieval failure",
)
results = list(connector.resolve_errors([entity_failure]))
assert len(results) == 0

View File

@@ -27,11 +27,13 @@ def create_placement(
turn_index: int,
tab_index: int = 0,
sub_turn_index: int | None = None,
model_index: int | None = 0,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
model_index=model_index,
)

View File

@@ -13,6 +13,7 @@ This test:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import patch
from uuid import uuid4
@@ -20,7 +21,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),

View File

@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
@@ -16,7 +17,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)

View File

@@ -0,0 +1,173 @@
"""Unit tests for the Emitter class.
All tests use the streaming mode (merged_queue required). Emitter has a single
code path — no standalone bus.
"""
import queue
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
"""Return (emitter, queue) wired together."""
mq: queue.Queue = queue.Queue()
return Emitter(merged_queue=mq, model_idx=model_idx), mq
# ---------------------------------------------------------------------------
# Queue routing
# ---------------------------------------------------------------------------
class TestEmitterQueueRouting:
def test_emit_lands_on_merged_queue(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet())
assert not mq.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
def test_multiple_packets_delivered_fifo(self) -> None:
emitter, mq = _make_emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
_, t1 = mq.get_nowait()
_, t2 = mq.get_nowait()
assert t1.placement.turn_index == 0
assert t2.placement.turn_index == 1
# ---------------------------------------------------------------------------
# model_index tagging
# ---------------------------------------------------------------------------
class TestEmitterModelIndexTagging:
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
"""N=1: default model_idx=0, so packet gets model_index=0."""
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# ---------------------------------------------------------------------------
# Queue key
# ---------------------------------------------------------------------------
class TestEmitterQueueKey:
def test_key_equals_model_idx(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_n1_key_is_zero(self) -> None:
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# ---------------------------------------------------------------------------
# Placement field preservation
# ---------------------------------------------------------------------------
class TestEmitterPlacementPreservation:
def test_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
emitter, mq = _make_emitter()
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
emitter, mq = _make_emitter()
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)

View File

@@ -0,0 +1,754 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.utils.variable_functionality import global_version
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
"""Reset EE global state after each test.
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
(via the celery client import chain). Without this fixture, the EE flag stays
True for the rest of the session and breaks unrelated tests that mock Confluence
or other connectors and assume EE is disabled.
"""
original = global_version._is_ee
yield
global_version._is_ee = original
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
"""Return the first item yielded by handle_multi_model_stream."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
return next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_yields_error(self) -> None:
"""Exactly 1 override is not multi-model — yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [_make_override()])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_four_overrides_yields_error(self) -> None:
"""4 overrides exceeds maximum — yields StreamingError."""
req = _make_request()
result = _first_from_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_zero_overrides_yields_error(self) -> None:
"""Empty override list yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_deep_research_yields_error(self) -> None:
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
req = _make_request(deep_research=True)
result = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
assert isinstance(result, StreamingError)
assert "not supported" in result.error
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override yields error, 2 overrides passes validation."""
req = _make_request()
# 1 override must yield a StreamingError
result = _first_from_stream(req, [_make_override()])
assert isinstance(
result, StreamingError
), "1 override should yield StreamingError"
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
# missing session, that's OK — validation itself passed)
try:
result2 = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
if isinstance(result2, StreamingError) and "2-3" in result2.error:
pytest.fail(
f"2 overrides should pass validation, got StreamingError: {result2.error}"
)
except Exception:
pass # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 0
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
setup = _make_setup(n_models=2)
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
if kwargs["llm"] is setup.llms[0]:
raise RuntimeError("model 0 failed")
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_stop_button_calls_completion_for_all_models(self) -> None:
"""llm_loop_completion_handle must be called for all models when the stop button fires.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers main-thread completion.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator. The finally block sets drain_done (signalling
emitters to stop blocking), waits for workers via executor.shutdown(wait=True),
then calls llm_loop_completion_handle for each successful model from the main
thread.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
completion_called = threading.Event()
def emit_then_block_until_drain(**kwargs: Any) -> None:
"""Emit one packet (to give the drain loop a yield point), then block
until drain_done is set — simulating a mid-stream LLM call that exits
promptly once the emitter signals shutdown.
"""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
# Block until drain_done is set by gen.close(). The Emitter's _drain_done
# is the same Event that _run_models sets, so this unblocks promptly.
emitter._drain_done.wait(timeout=5)
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_block_until_drain,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# gen.close() → GeneratorExit → finally → drain_done.set() →
# executor.shutdown(wait=True) → main thread completes models.
gen.close()
assert (
completion_called.is_set()
), "main thread must call completion for the successful model"
assert mock_handle.call_count == 1
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
"""B1 regression: model finishes BEFORE GeneratorExit fires.
The worker exits _run_model before drain_done is set. When gen.close()
fires afterward, the finally block sets drain_done, waits for workers
(already done), then the main thread calls llm_loop_completion_handle.
Contrast with test_http_disconnect_completion_via_generator_exit, which
tests the opposite ordering (worker finishes AFTER disconnect).
"""
import threading
import time
completion_called = threading.Event()
def emit_and_return_immediately(**kwargs: Any) -> None:
# Emit one packet so the drain loop has something to yield, then return
# immediately — no blocking. The worker will be done in microseconds.
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_and_return_immediately,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
first = next(gen)
assert isinstance(first, Packet)
# Give the worker thread time to finish completely (emit + return +
# finally + self-completion check). It does almost no work, so 100 ms
# is far more than enough while still keeping the test fast.
time.sleep(0.1)
# Now close — worker is already done, so else-branch handles completion.
gen.close()
assert completion_called.wait(
timeout=5
), "disconnect handler must call completion for a model that already finished"
assert mock_handle.call_count == 1, "completion must be called exactly once"
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
"""B2 regression: stop-button must NOT call completion for an errored model.
When model 0 raises an exception, its reserved ChatMessage must not be
saved with 'stopped by user' — that message is wrong for a model that
errored. llm_loop_completion_handle must only be called for non-errored
models when the stop button fires.
"""
def fail_model_0(**kwargs: Any) -> None:
if kwargs["llm"] is setup.llms[0]:
raise RuntimeError("model 0 errored")
# Model 1: run forever (stop button fires before it finishes)
time.sleep(10)
setup = _make_setup(n_models=2)
# Return False immediately so the stop-button path fires while model 1
# is still sleeping (model 0 has already errored by then).
setup.check_is_connected = lambda: False
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Completion must NOT be called for model 0 (it errored).
# It MAY be called for model 1 (still in-flight when stop fired).
for call in mock_handle.call_args_list:
assert (
call.kwargs.get("llm") is not setup.llms[0]
), "llm_loop_completion_handle must not be called for the errored model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external

View File

@@ -0,0 +1,198 @@
import io
from typing import cast
import openpyxl
from openpyxl.worksheet.worksheet import Worksheet
from onyx.file_processing.extract_file_text import xlsx_to_text
def _make_xlsx(sheets: dict[str, list[list[str]]]) -> io.BytesIO:
"""Create an in-memory xlsx file from a dict of sheet_name -> matrix of strings."""
wb = openpyxl.Workbook()
if wb.active is not None:
wb.remove(cast(Worksheet, wb.active))
for sheet_name, rows in sheets.items():
ws = wb.create_sheet(title=sheet_name)
for row in rows:
ws.append(row)
buf = io.BytesIO()
wb.save(buf)
buf.seek(0)
return buf
class TestXlsxToText:
def test_single_sheet_basic(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [
["Name", "Age"],
["Alice", "30"],
["Bob", "25"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 3
assert "Name" in lines[0]
assert "Age" in lines[0]
assert "Alice" in lines[1]
assert "30" in lines[1]
assert "Bob" in lines[2]
def test_multiple_sheets_separated(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [["a", "b"]],
"Sheet2": [["c", "d"]],
}
)
result = xlsx_to_text(xlsx)
# TEXT_SECTION_SEPARATOR is "\n\n"
assert "\n\n" in result
parts = result.split("\n\n")
assert any("a" in p for p in parts)
assert any("c" in p for p in parts)
def test_empty_cells(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "", "b"],
["", "c", ""],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 2
def test_commas_in_cells_are_quoted(self) -> None:
"""Cells containing commas should be quoted in CSV output."""
xlsx = _make_xlsx(
{
"Sheet1": [
["hello, world", "normal"],
]
}
)
result = xlsx_to_text(xlsx)
assert '"hello, world"' in result
def test_empty_workbook(self) -> None:
xlsx = _make_xlsx({"Sheet1": []})
result = xlsx_to_text(xlsx)
assert result.strip() == ""
def test_long_empty_row_run_capped(self) -> None:
"""Runs of >2 empty rows should be capped to 2."""
xlsx = _make_xlsx(
{
"Sheet1": [
["header"],
[""],
[""],
[""],
[""],
["data"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# 4 empty rows capped to 2, so: header + 2 empty + data = 4 lines
assert len(lines) == 4
assert "header" in lines[0]
assert "data" in lines[-1]
def test_long_empty_col_run_capped(self) -> None:
"""Runs of >2 empty columns should be capped to 2."""
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "", "", "", "b"],
["c", "", "", "", "d"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 2
# Each row should have 4 fields (a + 2 empty + b), not 5
# csv format: a,,,b (3 commas = 4 fields)
first_line = lines[0].strip()
# Count commas to verify column reduction
assert first_line.count(",") == 3
def test_short_empty_runs_kept(self) -> None:
"""Runs of <=2 empty rows/cols should be preserved."""
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "b"],
["", ""],
["", ""],
["c", "d"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# All 4 rows preserved (2 empty rows <= threshold)
assert len(lines) == 4
def test_bad_zip_file_returns_empty(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
result = xlsx_to_text(bad_file, file_name="test.xlsx")
assert result == ""
def test_bad_zip_tilde_file_returns_empty(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
result = xlsx_to_text(bad_file, file_name="~$temp.xlsx")
assert result == ""
def test_large_sparse_sheet(self) -> None:
"""A sheet with data, a big empty gap, and more data — gap is capped to 2."""
rows: list[list[str]] = [["row1_data"]]
rows.extend([[""] for _ in range(10)])
rows.append(["row2_data"])
xlsx = _make_xlsx({"Sheet1": rows})
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# 10 empty rows capped to 2: row1_data + 2 empty + row2_data = 4
assert len(lines) == 4
assert "row1_data" in lines[0]
assert "row2_data" in lines[-1]
def test_quotes_in_cells(self) -> None:
"""Cells containing quotes should be properly escaped."""
xlsx = _make_xlsx(
{
"Sheet1": [
['say "hello"', "normal"],
]
}
)
result = xlsx_to_text(xlsx)
# csv.writer escapes quotes by doubling them
assert '""hello""' in result
def test_each_row_is_separate_line(self) -> None:
"""Each row should produce its own line (regression for writerow vs writerows)."""
xlsx = _make_xlsx(
{
"Sheet1": [
["r1c1", "r1c2"],
["r2c1", "r2c2"],
["r3c1", "r3c2"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 3
assert "r1c1" in lines[0] and "r1c2" in lines[0]
assert "r2c1" in lines[1] and "r2c2" in lines[1]
assert "r3c1" in lines[2] and "r3c2" in lines[2]

View File

@@ -2,6 +2,7 @@ import threading
from typing import Any
from typing import cast
from typing import List
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch
@@ -12,8 +13,13 @@ from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.document_ingestion import DocumentIngestionResponse
from onyx.hooks.points.document_ingestion import DocumentIngestionSection
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import _apply_document_ingestion_hook
from onyx.indexing.indexing_pipeline import add_contextual_summaries
from onyx.indexing.indexing_pipeline import filter_documents
from onyx.indexing.indexing_pipeline import process_image_sections
@@ -223,3 +229,148 @@ def test_contextual_rag(
count += 1
assert chunk.doc_summary == doc_summary
assert chunk.chunk_context == chunk_context
# ---------------------------------------------------------------------------
# _apply_document_ingestion_hook
# ---------------------------------------------------------------------------
_PATCH_EXECUTE_HOOK = "onyx.indexing.indexing_pipeline.execute_hook"
def _make_doc(
doc_id: str = "doc1",
sections: list[TextSection | ImageSection] | None = None,
) -> Document:
if sections is None:
sections = [TextSection(text="Hello", link="http://example.com")]
return Document(
id=doc_id,
title="Test Doc",
semantic_identifier="test-doc",
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.FILE,
metadata={},
)
def test_document_ingestion_hook_skipped_passes_through() -> None:
doc = _make_doc()
with patch(_PATCH_EXECUTE_HOOK, return_value=HookSkipped()):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert result == [doc]
def test_document_ingestion_hook_soft_failed_passes_through() -> None:
doc = _make_doc()
with patch(_PATCH_EXECUTE_HOOK, return_value=HookSoftFailed()):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert result == [doc]
def test_document_ingestion_hook_none_sections_drops_document() -> None:
doc = _make_doc()
with patch(
_PATCH_EXECUTE_HOOK,
return_value=DocumentIngestionResponse(
sections=None, rejection_reason="PII detected"
),
):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert result == []
def test_document_ingestion_hook_all_invalid_sections_drops_document() -> None:
"""A non-empty list where every section has neither text nor image_file_id drops the doc."""
doc = _make_doc()
with patch(
_PATCH_EXECUTE_HOOK,
return_value=DocumentIngestionResponse(sections=[DocumentIngestionSection()]),
):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert result == []
def test_document_ingestion_hook_empty_sections_drops_document() -> None:
doc = _make_doc()
with patch(
_PATCH_EXECUTE_HOOK,
return_value=DocumentIngestionResponse(sections=[]),
):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert result == []
def test_document_ingestion_hook_rewrites_text_sections() -> None:
doc = _make_doc(sections=[TextSection(text="original", link="http://a.com")])
with patch(
_PATCH_EXECUTE_HOOK,
return_value=DocumentIngestionResponse(
sections=[DocumentIngestionSection(text="rewritten", link="http://b.com")]
),
):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert len(result) == 1
assert len(result[0].sections) == 1
section = result[0].sections[0]
assert isinstance(section, TextSection)
assert section.text == "rewritten"
assert section.link == "http://b.com"
def test_document_ingestion_hook_preserves_image_section_order() -> None:
"""Hook receives all sections including images and controls final ordering."""
image = ImageSection(image_file_id="img-1", link=None)
doc = _make_doc(
sections=cast(
list[TextSection | ImageSection],
[TextSection(text="original", link=None), image],
)
)
# Hook moves the image before the text section
with patch(
_PATCH_EXECUTE_HOOK,
return_value=DocumentIngestionResponse(
sections=[
DocumentIngestionSection(image_file_id="img-1", link=None),
DocumentIngestionSection(text="rewritten", link=None),
]
),
):
result = _apply_document_ingestion_hook([doc], MagicMock())
assert len(result) == 1
sections = result[0].sections
assert len(sections) == 2
assert (
isinstance(sections[0], ImageSection) and sections[0].image_file_id == "img-1"
)
assert isinstance(sections[1], TextSection) and sections[1].text == "rewritten"
def test_document_ingestion_hook_mixed_batch() -> None:
"""Drop one doc, rewrite another, pass through a third."""
doc_drop = _make_doc(doc_id="drop")
doc_rewrite = _make_doc(doc_id="rewrite")
doc_skip = _make_doc(doc_id="skip")
def _side_effect(**kwargs: Any) -> Any:
doc_id = kwargs["payload"]["document_id"]
if doc_id == "drop":
return DocumentIngestionResponse(sections=None)
if doc_id == "rewrite":
return DocumentIngestionResponse(
sections=[DocumentIngestionSection(text="new text", link=None)]
)
return HookSkipped()
with patch(_PATCH_EXECUTE_HOOK, side_effect=_side_effect):
result = _apply_document_ingestion_hook(
[doc_drop, doc_rewrite, doc_skip], MagicMock()
)
assert len(result) == 2
ids = {d.id for d in result}
assert ids == {"rewrite", "skip"}
rewritten = next(d for d in result if d.id == "rewrite")
assert isinstance(rewritten.sections[0], TextSection)
assert rewritten.sections[0].text == "new text"

View File

@@ -1,6 +1,6 @@
"""Tests for memory tool streaming packet emissions."""
from queue import Queue
import queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter() -> Emitter:
bus: Queue = Queue()
return Emitter(bus)
def emitter_queue() -> queue.Queue:
return queue.Queue()
@pytest.fixture
def emitter(emitter_queue: queue.Queue) -> Emitter:
return Emitter(merged_queue=emitter_queue)
@pytest.fixture
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
def test_emit_start_emits_memory_tool_start_packet(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
) -> None:
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolStart)
assert packet.placement == placement
assert packet.placement is not None
assert packet.placement.turn_index == placement.turn_index
assert packet.placement.tab_index == placement.tab_index
assert packet.placement.model_index == 0 # emitter stamps model_index=0
def test_emit_start_with_different_placement(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
) -> None:
placement = Placement(turn_index=2, tab_index=1)
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert packet.placement.turn_index == 2
assert packet.placement.tab_index == 1
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
memory="User prefers Python",
)
# The delta packet should be in the queue
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers Python"
assert packet.obj.operation == "add"
assert packet.obj.memory_id is None
assert packet.obj.index is None
assert packet.placement == placement
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
def test_run_emits_delta_for_update_operation(
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
memory="User prefers light mode",
)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers light mode"
assert packet.obj.operation == "update"

View File

@@ -231,6 +231,23 @@ import { Hoverable } from "@opal/core";
# Best Practices
## 0. Size Variant Defaults
**When using `SizeVariants` (or any subset like `PaddingVariants`, `RoundingVariants`) as a prop
type, always default to `"md"`.**
**Reason:** `"md"` is the standard middle-of-the-road preset across the design system. Consistent
defaults make components predictable — callers only need to specify a size when they want something
other than the norm.
```typescript
// ✅ Good — default to "md"
function MyCard({ padding = "md", rounding = "md" }: MyCardProps) { ... }
// ❌ Bad — arbitrary or inconsistent defaults
function MyCard({ padding = "sm", rounding = "lg" }: MyCardProps) { ... }
```
## 1. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**

View File

@@ -29,7 +29,7 @@ export const BackgroundVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{BACKGROUND_VARIANTS.map((bg) => (
<Card key={bg} backgroundVariant={bg} borderVariant="solid">
<Card key={bg} background={bg} border="solid">
<p>backgroundVariant: {bg}</p>
</Card>
))}
@@ -41,7 +41,7 @@ export const BorderVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{BORDER_VARIANTS.map((border) => (
<Card key={border} borderVariant={border}>
<Card key={border} border={border}>
<p>borderVariant: {border}</p>
</Card>
))}
@@ -53,7 +53,7 @@ export const PaddingVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{PADDING_VARIANTS.map((padding) => (
<Card key={padding} paddingVariant={padding} borderVariant="solid">
<Card key={padding} padding={padding} border="solid">
<p>paddingVariant: {padding}</p>
</Card>
))}
@@ -65,7 +65,7 @@ export const RoundingVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{ROUNDING_VARIANTS.map((rounding) => (
<Card key={rounding} roundingVariant={rounding} borderVariant="solid">
<Card key={rounding} rounding={rounding} border="solid">
<p>roundingVariant: {rounding}</p>
</Card>
))}
@@ -84,9 +84,9 @@ export const AllCombinations: Story = {
BORDER_VARIANTS.map((border) => (
<Card
key={`${padding}-${bg}-${border}`}
paddingVariant={padding}
backgroundVariant={bg}
borderVariant={border}
padding={padding}
background={bg}
border={border}
>
<p className="text-xs">
bg: {bg}, border: {border}

View File

@@ -8,30 +8,30 @@ A plain container component with configurable background, border, padding, and r
Padding and rounding are controlled independently:
| `paddingVariant` | Class |
|------------------|---------|
| `"lg"` | `p-6` |
| `"md"` | `p-4` |
| `"sm"` | `p-2` |
| `"xs"` | `p-1` |
| `"2xs"` | `p-0.5` |
| `"fit"` | `p-0` |
| `padding` | Class |
|-----------|---------|
| `"lg"` | `p-6` |
| `"md"` | `p-4` |
| `"sm"` | `p-2` |
| `"xs"` | `p-1` |
| `"2xs"` | `p-0.5` |
| `"fit"` | `p-0` |
| `roundingVariant` | Class |
|-------------------|--------------|
| `"xs"` | `rounded-04` |
| `"sm"` | `rounded-08` |
| `"md"` | `rounded-12` |
| `"lg"` | `rounded-16` |
| `rounding` | Class |
|------------|--------------|
| `"xs"` | `rounded-04` |
| `"sm"` | `rounded-08` |
| `"md"` | `rounded-12` |
| `"lg"` | `rounded-16` |
## Props
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `paddingVariant` | `PaddingVariants` | `"sm"` | Padding preset |
| `roundingVariant` | `RoundingVariants` | `"md"` | Border-radius preset |
| `backgroundVariant` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
| `borderVariant` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
| `padding` | `PaddingVariants` | `"sm"` | Padding preset |
| `rounding` | `RoundingVariants` | `"md"` | Border-radius preset |
| `background` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
| `border` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| `children` | `React.ReactNode` | — | Card content |
@@ -47,17 +47,17 @@ import { Card } from "@opal/components";
</Card>
// Large padding + rounding with solid border
<Card paddingVariant="lg" roundingVariant="lg" borderVariant="solid">
<Card padding="lg" rounding="lg" border="solid">
<p>Spacious card</p>
</Card>
// Compact card with solid border
<Card paddingVariant="xs" roundingVariant="sm" borderVariant="solid">
<Card padding="xs" rounding="sm" border="solid">
<p>Compact card</p>
</Card>
// Empty state card
<Card backgroundVariant="none" borderVariant="dashed">
<Card background="none" border="dashed">
<p>No items yet</p>
</Card>
```

View File

@@ -1,5 +1,6 @@
import "@opal/components/cards/card/styles.css";
import type { PaddingVariants, RoundingVariants } from "@opal/types";
import { cardPaddingVariants, cardRoundingVariants } from "@opal/shared";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
@@ -22,9 +23,9 @@ type CardProps = {
* | `"2xs"` | `p-0.5` |
* | `"fit"` | `p-0` |
*
* @default "sm"
* @default "md"
*/
paddingVariant?: PaddingVariants;
padding?: PaddingVariants;
/**
* Border-radius preset.
@@ -38,7 +39,7 @@ type CardProps = {
*
* @default "md"
*/
roundingVariant?: RoundingVariants;
rounding?: RoundingVariants;
/**
* Background fill intensity.
@@ -48,7 +49,7 @@ type CardProps = {
*
* @default "light"
*/
backgroundVariant?: BackgroundVariant;
background?: BackgroundVariant;
/**
* Border style.
@@ -58,7 +59,7 @@ type CardProps = {
*
* @default "none"
*/
borderVariant?: BorderVariant;
border?: BorderVariant;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
@@ -66,47 +67,27 @@ type CardProps = {
children?: React.ReactNode;
};
// ---------------------------------------------------------------------------
// Mappings
// ---------------------------------------------------------------------------
const paddingForVariant: Record<PaddingVariants, string> = {
lg: "p-6",
md: "p-4",
sm: "p-2",
xs: "p-1",
"2xs": "p-0.5",
fit: "p-0",
};
const roundingForVariant: Record<RoundingVariants, string> = {
lg: "rounded-16",
md: "rounded-12",
sm: "rounded-08",
xs: "rounded-04",
};
// ---------------------------------------------------------------------------
// Card
// ---------------------------------------------------------------------------
function Card({
paddingVariant = "sm",
roundingVariant = "md",
backgroundVariant = "light",
borderVariant = "none",
padding: paddingProp = "md",
rounding: roundingProp = "md",
background = "light",
border = "none",
ref,
children,
}: CardProps) {
const padding = paddingForVariant[paddingVariant];
const rounding = roundingForVariant[roundingVariant];
const padding = cardPaddingVariants[paddingProp];
const rounding = cardRoundingVariants[roundingProp];
return (
<div
ref={ref}
className={cn("opal-card", padding, rounding)}
data-background={backgroundVariant}
data-border={borderVariant}
data-background={background}
data-border={border}
>
{children}
</div>

View File

@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react";
import { EmptyMessageCard } from "@opal/components";
import { SvgSparkle, SvgUsers } from "@opal/icons";
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
const PADDING_VARIANTS = ["fit", "2xs", "xs", "sm", "md", "lg"] as const;
const meta: Meta<typeof EmptyMessageCard> = {
title: "opal/components/EmptyMessageCard",
@@ -26,14 +26,14 @@ export const WithCustomIcon: Story = {
},
};
export const SizeVariants: Story = {
export const PaddingVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{SIZE_VARIANTS.map((size) => (
{PADDING_VARIANTS.map((padding) => (
<EmptyMessageCard
key={size}
sizeVariant={size}
title={`sizeVariant: ${size}`}
key={padding}
padding={padding}
title={`padding: ${padding}`}
/>
))}
</div>

View File

@@ -6,12 +6,12 @@ A pre-configured Card for empty states. Renders a transparent card with a dashed
## Props
| Prop | Type | Default | Description |
| ----------------- | --------------------------- | ---------- | ------------------------------------------------ |
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
| `title` | `string` | — | Primary message text (required) |
| `paddingVariant` | `PaddingVariants` | `"sm"` | Padding preset for the card |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| Prop | Type | Default | Description |
| --------- | --------------------------- | ---------- | -------------------------------- |
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
| `title` | `string` | — | Primary message text (required) |
| `padding` | `PaddingVariants` | `"sm"` | Padding preset for the card |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
## Usage
@@ -26,5 +26,5 @@ import { SvgSparkle, SvgFileText } from "@opal/icons";
<EmptyMessageCard icon={SvgSparkle} title="No agents selected." />
// With custom padding
<EmptyMessageCard paddingVariant="xs" icon={SvgFileText} title="No documents available." />
<EmptyMessageCard padding="xs" icon={SvgFileText} title="No documents available." />
```

View File

@@ -14,8 +14,8 @@ type EmptyMessageCardProps = {
/** Primary message text. */
title: string;
/** Padding preset for the card. */
paddingVariant?: PaddingVariants;
/** Padding preset for the card. @default "md" */
padding?: PaddingVariants;
/** Ref forwarded to the root Card div. */
ref?: React.Ref<HTMLDivElement>;
@@ -28,15 +28,16 @@ type EmptyMessageCardProps = {
function EmptyMessageCard({
icon = SvgEmpty,
title,
paddingVariant = "sm",
padding = "md",
ref,
}: EmptyMessageCardProps) {
return (
<Card
ref={ref}
backgroundVariant="none"
borderVariant="dashed"
paddingVariant={paddingVariant}
background="none"
border="dashed"
padding={padding}
rounding="md"
>
<Content
icon={icon}

View File

@@ -2,11 +2,11 @@
**Import:** `import { SelectCard, type SelectCardProps } from "@opal/components";`
A stateful interactive card — the card counterpart to [`SelectButton`](../../buttons/select-button/README.md). Built on `Interactive.Stateful` (Slot) with a structural `<div>` that owns padding, rounding, border, and overflow.
A stateful interactive card — the card counterpart to [`SelectButton`](../../buttons/select-button/README.md). Built on `Interactive.Stateful` (Slot) with a structural `<div>` that owns padding, rounding, border, and overflow. Always uses the `select-card` Interactive.Stateful variant internally.
## Relationship to Card
`Card` is a plain, non-interactive container. `SelectCard` adds stateful interactivity (hover, active, disabled, state-driven colors) by wrapping its root div with `Interactive.Stateful`. The relationship mirrors `Button` (stateless) vs `SelectButton` (stateful).
`Card` is a plain, non-interactive container. `SelectCard` adds stateful interactivity (hover, active, disabled, state-driven colors) by wrapping its root div with `Interactive.Stateful`. Both share the same independent `padding` / `rounding` API.
## Relationship to SelectButton
@@ -18,15 +18,15 @@ Interactive.Stateful → structural element → content
The key differences:
- SelectCard renders a `<div>` (not `Interactive.Container`) — cards have their own rounding scale (one notch larger than buttons) and don't need Container's height/min-width.
- SelectCard renders a `<div>` (not `Interactive.Container`) — cards have their own rounding scale and don't need Container's height/min-width.
- SelectCard has no `foldable` prop — use `Interactive.Foldable` directly inside children.
- SelectCard's children are fully composable — use `CardHeaderLayout`, `ContentAction`, `Content`, buttons, etc. inside.
## Architecture
```
Interactive.Stateful <- variant, state, interaction, disabled, onClick
└─ div.opal-select-card <- padding, rounding, border, overflow
Interactive.Stateful (variant="select-card") <- state, interaction, disabled, onClick
└─ div.opal-select-card <- padding, rounding, border, overflow
└─ children (composable)
```
@@ -34,28 +34,36 @@ The `Interactive.Stateful` Slot merges onto the div, producing a single DOM elem
## Props
Inherits **all** props from `InteractiveStatefulProps` (variant, state, interaction, onClick, href, etc.) plus:
Inherits **all** props from `InteractiveStatefulProps` (except `variant`, which is hardcoded to `select-card`) plus:
| Prop | Type | Default | Description |
|---|---|---|---|
| `sizeVariant` | `ContainerSizeVariants` | `"lg"` | Controls padding and border-radius |
| `padding` | `PaddingVariants` | `"sm"` | Padding preset |
| `rounding` | `RoundingVariants` | `"lg"` | Border-radius preset |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| `children` | `React.ReactNode` | — | Card content |
### Padding scale
| `padding` | Class |
|-----------|---------|
| `"lg"` | `p-6` |
| `"md"` | `p-4` |
| `"sm"` | `p-2` |
| `"xs"` | `p-1` |
| `"2xs"` | `p-0.5` |
| `"fit"` | `p-0` |
### Rounding scale
Cards use a bumped-up rounding scale compared to buttons:
| `rounding` | Class |
|------------|--------------|
| `"xs"` | `rounded-04` |
| `"sm"` | `rounded-08` |
| `"md"` | `rounded-12` |
| `"lg"` | `rounded-16` |
| Size | Rounding | Effective radius |
|---|---|---|
| `lg` | `rounded-16` | 1rem (16px) |
| `md``sm` | `rounded-12` | 0.75rem (12px) |
| `xs``2xs` | `rounded-08` | 0.5rem (8px) |
| `fit` | `rounded-16` | 1rem (16px) |
### Recommended variant: `select-card`
The `select-card` Interactive.Stateful variant is specifically designed for cards. Unlike `select-heavy` (which only changes foreground color between empty and filled), `select-card` gives the filled state a visible background — important on larger surfaces where background carries more of the visual distinction.
### State colors (`select-card` variant)
| State | Rest background | Rest foreground |
|---|---|---|
@@ -82,7 +90,7 @@ All background and foreground colors come from the Interactive.Stateful CSS, not
import { SelectCard } from "@opal/components";
import { CardHeaderLayout } from "@opal/layouts";
<SelectCard variant="select-card" state="selected" onClick={handleClick}>
<SelectCard state="selected" onClick={handleClick}>
<CardHeaderLayout
icon={SvgGlobe}
title="Google"
@@ -100,7 +108,7 @@ import { CardHeaderLayout } from "@opal/layouts";
### Disconnected state (clickable)
```tsx
<SelectCard variant="select-card" state="empty" onClick={handleConnect}>
<SelectCard state="empty" onClick={handleConnect}>
<CardHeaderLayout
icon={SvgCloud}
title="OpenAI"
@@ -115,7 +123,7 @@ import { CardHeaderLayout } from "@opal/layouts";
### With foldable hover-reveal
```tsx
<SelectCard variant="select-card" state="filled">
<SelectCard state="filled">
<CardHeaderLayout
icon={SvgCloud}
title="OpenAI"

View File

@@ -21,7 +21,8 @@ const withTooltipProvider: Decorator = (Story) => (
);
const STATES = ["empty", "filled", "selected"] as const;
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
const PADDING_VARIANTS = ["fit", "2xs", "xs", "sm", "md", "lg"] as const;
const ROUNDING_VARIANTS = ["xs", "sm", "md", "lg"] as const;
const meta = {
title: "opal/components/SelectCard",
@@ -44,7 +45,7 @@ type Story = StoryObj<typeof meta>;
export const Default: Story = {
render: () => (
<div className="w-96">
<SelectCard variant="select-card" state="empty">
<SelectCard state="empty">
<div className="p-2">
<Content
sizePreset="main-ui"
@@ -63,7 +64,7 @@ export const AllStates: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{STATES.map((state) => (
<SelectCard key={state} variant="select-card" state={state}>
<SelectCard key={state} state={state}>
<div className="p-2">
<Content
sizePreset="main-ui"
@@ -82,11 +83,7 @@ export const AllStates: Story = {
export const Clickable: Story = {
render: () => (
<div className="w-96">
<SelectCard
variant="select-card"
state="empty"
onClick={() => alert("Card clicked")}
>
<SelectCard state="empty" onClick={() => alert("Card clicked")}>
<div className="p-2">
<Content
sizePreset="main-ui"
@@ -105,7 +102,7 @@ export const WithActions: Story = {
render: () => (
<div className="flex flex-col gap-4 w-[28rem]">
{/* Disconnected */}
<SelectCard variant="select-card" state="empty" onClick={() => {}}>
<SelectCard state="empty" onClick={() => {}}>
<div className="flex flex-row items-stretch w-full">
<div className="flex-1 p-2">
<Content
@@ -125,7 +122,7 @@ export const WithActions: Story = {
</SelectCard>
{/* Connected with foldable */}
<SelectCard variant="select-card" state="filled">
<SelectCard state="filled">
<div className="flex flex-row items-stretch w-full">
<div className="flex-1 p-2">
<Content
@@ -163,7 +160,7 @@ export const WithActions: Story = {
</SelectCard>
{/* Selected */}
<SelectCard variant="select-card" state="selected">
<SelectCard state="selected">
<div className="flex flex-row items-stretch w-full">
<div className="flex-1 p-2">
<Content
@@ -203,22 +200,17 @@ export const WithActions: Story = {
),
};
export const SizeVariants: Story = {
export const PaddingVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{SIZE_VARIANTS.map((size) => (
<SelectCard
key={size}
variant="select-card"
state="filled"
sizeVariant={size}
>
{PADDING_VARIANTS.map((padding) => (
<SelectCard key={padding} state="filled" padding={padding}>
<Content
sizePreset="main-ui"
variant="section"
icon={SvgGlobe}
title={`sizeVariant: ${size}`}
description="Shows padding and rounding differences."
title={`paddingVariant: ${padding}`}
description="Shows padding differences."
/>
</SelectCard>
))}
@@ -226,20 +218,18 @@ export const SizeVariants: Story = {
),
};
export const SelectHeavyVariant: Story = {
export const RoundingVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{STATES.map((state) => (
<SelectCard key={state} variant="select-heavy" state={state}>
<div className="p-2">
<Content
sizePreset="main-ui"
variant="section"
icon={SvgGlobe}
title={`select-heavy / ${state}`}
description="For comparison with select-card variant."
/>
</div>
{ROUNDING_VARIANTS.map((rounding) => (
<SelectCard key={rounding} state="filled" rounding={rounding}>
<Content
sizePreset="main-ui"
variant="section"
icon={SvgGlobe}
title={`roundingVariant: ${rounding}`}
description="Shows rounding differences."
/>
</SelectCard>
))}
</div>

View File

@@ -1,6 +1,6 @@
import "@opal/components/cards/select-card/styles.css";
import type { ContainerSizeVariants } from "@opal/types";
import { containerSizeVariants } from "@opal/shared";
import type { PaddingVariants, RoundingVariants } from "@opal/types";
import { cardPaddingVariants, cardRoundingVariants } from "@opal/shared";
import { cn } from "@opal/utils";
import { Interactive, type InteractiveStatefulProps } from "@opal/core";
@@ -8,23 +8,36 @@ import { Interactive, type InteractiveStatefulProps } from "@opal/core";
// Types
// ---------------------------------------------------------------------------
type SelectCardProps = InteractiveStatefulProps & {
type SelectCardProps = Omit<InteractiveStatefulProps, "variant"> & {
/**
* Size preset — controls padding and border-radius.
* Padding preset.
*
* Padding comes from the shared size scale. Rounding follows the same
* mapping as `Card` / `Button` / `Interactive.Container`:
* | Value | Class |
* |---------|---------|
* | `"lg"` | `p-6` |
* | `"md"` | `p-4` |
* | `"sm"` | `p-2` |
* | `"xs"` | `p-1` |
* | `"2xs"` | `p-0.5` |
* | `"fit"` | `p-0` |
*
* | Size | Rounding |
* |------------|--------------|
* | `lg` | `rounded-16` |
* | `md``sm` | `rounded-12` |
* | `xs``2xs` | `rounded-08` |
* | `fit` | `rounded-16` |
*
* @default "lg"
* @default "md"
*/
sizeVariant?: ContainerSizeVariants;
padding?: PaddingVariants;
/**
* Border-radius preset.
*
* | Value | Class |
* |--------|--------------|
* | `"xs"` | `rounded-04` |
* | `"sm"` | `rounded-08` |
* | `"md"` | `rounded-12` |
* | `"lg"` | `rounded-16` |
*
* @default "md"
*/
rounding?: RoundingVariants;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
@@ -32,19 +45,6 @@ type SelectCardProps = InteractiveStatefulProps & {
children?: React.ReactNode;
};
// ---------------------------------------------------------------------------
// Rounding
// ---------------------------------------------------------------------------
const roundingForSize: Record<ContainerSizeVariants, string> = {
lg: "rounded-16",
md: "rounded-12",
sm: "rounded-12",
xs: "rounded-08",
"2xs": "rounded-08",
fit: "rounded-16",
};
// ---------------------------------------------------------------------------
// SelectCard
// ---------------------------------------------------------------------------
@@ -61,7 +61,7 @@ const roundingForSize: Record<ContainerSizeVariants, string> = {
*
* @example
* ```tsx
* <SelectCard variant="select-card" state="selected" onClick={handleClick}>
* <SelectCard state="selected" onClick={handleClick}>
* <ContentAction
* icon={SvgGlobe}
* title="Google"
@@ -72,16 +72,17 @@ const roundingForSize: Record<ContainerSizeVariants, string> = {
* ```
*/
function SelectCard({
sizeVariant = "lg",
padding: paddingProp = "md",
rounding: roundingProp = "md",
ref,
children,
...statefulProps
}: SelectCardProps) {
const { padding } = containerSizeVariants[sizeVariant];
const rounding = roundingForSize[sizeVariant];
const padding = cardPaddingVariants[paddingProp];
const rounding = cardRoundingVariants[roundingProp];
return (
<Interactive.Stateful {...statefulProps}>
<Interactive.Stateful {...statefulProps} variant="select-card">
<div ref={ref} className={cn("opal-select-card", padding, rounding)}>
{children}
</div>

View File

@@ -2,6 +2,7 @@ import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
// ---------------------------------------------------------------------------
// Types
@@ -91,7 +92,7 @@ function InteractiveSimple({
? href
? (e: React.MouseEvent) => e.preventDefault()
: undefined
: onClick
: guardPortalClick(onClick)
}
/>
);

View File

@@ -4,6 +4,7 @@ import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
import type { ButtonType, WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
@@ -153,7 +154,7 @@ function InteractiveStateful({
? href
? (e: React.MouseEvent) => e.preventDefault()
: undefined
: onClick
: guardPortalClick(onClick)
}
/>
);

View File

@@ -4,6 +4,7 @@ import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
import type { ButtonType, WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
@@ -137,7 +138,7 @@ function InteractiveStateless({
? href
? (e: React.MouseEvent) => e.preventDefault()
: undefined
: onClick
: guardPortalClick(onClick)
}
/>
);

View File

@@ -0,0 +1,28 @@
import type React from "react";
/**
* Guards an onClick handler against React synthetic event bubbling from
* portalled children (e.g. Radix Dialog overlays).
*
* React bubbles synthetic events through the **fiber tree** (component
* hierarchy), not the DOM tree. This means a click on a portalled modal
* overlay will bubble to a parent component's onClick even though the
* overlay is not a DOM descendant. This guard checks that the click
* target is actually inside the handler's DOM element before firing.
*/
function guardPortalClick<E extends React.MouseEvent>(
onClick: ((e: E) => void) | undefined
): ((e: E) => void) | undefined {
if (!onClick) return undefined;
return (e: E) => {
if (
e.currentTarget instanceof Node &&
e.target instanceof Node &&
e.currentTarget.contains(e.target)
) {
onClick(e);
}
};
}
export { guardPortalClick };

View File

@@ -92,7 +92,7 @@ export { default as SvgHashSmall } from "@opal/icons/hash-small";
export { default as SvgHash } from "@opal/icons/hash";
export { default as SvgHeadsetMic } from "@opal/icons/headset-mic";
export { default as SvgHistory } from "@opal/icons/history";
export { default as SvgHookNodes } from "@opal/icons/hook-nodes";
export { default as SvgShareWebhook } from "@opal/icons/share-webhook";
export { default as SvgHourglass } from "@opal/icons/hourglass";
export { default as SvgImage } from "@opal/icons/image";
export { default as SvgImageSmall } from "@opal/icons/image-small";

View File

@@ -1,6 +1,6 @@
import type { IconProps } from "@opal/types";
const SvgHookNodes = ({ size, ...props }: IconProps) => (
const SvgShareWebhook = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
@@ -18,4 +18,4 @@ const SvgHookNodes = ({ size, ...props }: IconProps) => (
/>
</svg>
);
export default SvgHookNodes;
export default SvgShareWebhook;

View File

@@ -11,6 +11,8 @@ import type {
OverridableExtremaSizeVariants,
ContainerSizeVariants,
ExtremaSizeVariants,
PaddingVariants,
RoundingVariants,
} from "@opal/types";
/**
@@ -88,12 +90,40 @@ const heightVariants: Record<ExtremaSizeVariants, string> = {
full: "h-full",
} as const;
// ---------------------------------------------------------------------------
// Card Variants
//
// Shared padding and rounding scales for card components (Card, SelectCard).
//
// Consumers:
// - Card (paddingVariant, roundingVariant)
// - SelectCard (paddingVariant, roundingVariant)
// ---------------------------------------------------------------------------
const cardPaddingVariants: Record<PaddingVariants, string> = {
lg: "p-6",
md: "p-4",
sm: "p-2",
xs: "p-1",
"2xs": "p-0.5",
fit: "p-0",
};
const cardRoundingVariants: Record<RoundingVariants, string> = {
lg: "rounded-16",
md: "rounded-12",
sm: "rounded-08",
xs: "rounded-04",
};
export {
type ExtremaSizeVariants,
type ContainerSizeVariants,
type OverridableExtremaSizeVariants,
type SizeVariants,
containerSizeVariants,
cardPaddingVariants,
cardRoundingVariants,
widthVariants,
heightVariants,
};

View File

@@ -7,6 +7,7 @@ import {
SvgGlobe,
SvgHardDrive,
SvgHeadsetMic,
SvgShareWebhook,
SvgKey,
SvgLock,
SvgPaintBrush,
@@ -63,6 +64,7 @@ const BUSINESS_FEATURES: PlanFeature[] = [
{ icon: SvgKey, text: "Service Account API Keys" },
{ icon: SvgHardDrive, text: "Self-hosting (Optional)" },
{ icon: SvgPaintBrush, text: "Custom Theming" },
{ icon: SvgShareWebhook, text: "Hook Extensions" },
];
const ENTERPRISE_FEATURES: PlanFeature[] = [

View File

@@ -159,6 +159,10 @@ export interface Message {
overridden_model?: string;
stopReason?: StreamStopReason | null;
// Multi-model answer generation
preferredResponseId?: number | null;
modelDisplayName?: string | null;
// new gen
packets: Packet[];
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
@@ -231,13 +235,28 @@ export interface BackendMessage {
parentMessageId: number | null;
refined_answer_improvement: boolean | null;
is_agentic: boolean | null;
// Multi-model answer generation
preferred_response_id: number | null;
model_display_name: string | null;
}
export interface MessageResponseIDInfo {
type: "message_id_info";
user_message_id: number | null;
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
}
export interface ModelResponseSlot {
message_id: number;
model_name: string;
}
export interface MultiModelMessageResponseIDInfo {
type: "multi_model_message_id_info";
user_message_id: number | null;
responses: ModelResponseSlot[];
}
export interface UserKnowledgeFilePacket {
user_files: FileDescriptor[];
}

View File

@@ -0,0 +1,126 @@
"use client";
import { useCallback } from "react";
import { Button } from "@opal/components";
import { Text } from "@opal/components";
import { ContentAction } from "@opal/layouts";
import { SvgEyeOff, SvgX } from "@opal/icons";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import AgentMessage, {
AgentMessageProps,
} from "@/app/app/message/messageComponents/AgentMessage";
import { cn } from "@/lib/utils";
import { markdown } from "@opal/utils";
export interface MultiModelPanelProps {
/** Provider name for icon lookup */
provider: string;
/** Model name for icon lookup and display */
modelName: string;
/** Display-friendly model name */
displayName: string;
/** Whether this panel is the preferred/selected response */
isPreferred: boolean;
/** Whether this panel is currently hidden */
isHidden: boolean;
/** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
isNonPreferredInSelection: boolean;
/** Callback when user clicks this panel to select as preferred */
onSelect: () => void;
/** Callback to hide/show this panel */
onToggleVisibility: () => void;
/** Props to pass through to AgentMessage */
agentMessageProps: AgentMessageProps;
}
/**
* A single model's response panel within the multi-model view.
*
* Renders in two states:
* - **Hidden** — compact header strip only (provider icon + strikethrough name + show button).
* - **Visible** — full header plus `AgentMessage` body. Clicking anywhere on a
* visible non-preferred panel marks it as preferred.
*
* The `isNonPreferredInSelection` flag disables pointer events on the body and
* hides the footer so the panel acts as a passive comparison surface.
*/
export default function MultiModelPanel({
provider,
modelName,
displayName,
isPreferred,
isHidden,
isNonPreferredInSelection,
onSelect,
onToggleVisibility,
agentMessageProps,
}: MultiModelPanelProps) {
const ProviderIcon = getProviderIcon(provider, modelName);
const handlePanelClick = useCallback(() => {
if (!isHidden) onSelect();
}, [isHidden, onSelect]);
const header = (
<div
className={cn(
"rounded-12",
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
)}
>
<ContentAction
sizePreset="main-ui"
variant="body"
paddingVariant="lg"
icon={ProviderIcon}
title={isHidden ? markdown(`~~${displayName}~~`) : displayName}
rightChildren={
<div className="flex items-center gap-1 px-2">
{isPreferred && (
<span className="text-action-link-05 shrink-0">
<Text font="secondary-body" color="inherit" nowrap>
Preferred Response
</Text>
</span>
)}
{!isPreferred && (
<Button
prominence="tertiary"
icon={isHidden ? SvgEyeOff : SvgX}
size="md"
onClick={(e) => {
e.stopPropagation();
onToggleVisibility();
}}
tooltip={isHidden ? "Show response" : "Hide response"}
/>
)}
</div>
}
/>
</div>
);
// Hidden/collapsed panel — just the header row
if (isHidden) {
return header;
}
return (
<div
className={cn(
"flex flex-col gap-3 min-w-0 cursor-pointer rounded-16 transition-colors",
!isPreferred && "hover:bg-background-tint-02"
)}
onClick={handlePanelClick}
>
{header}
<div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
<AgentMessage
{...agentMessageProps}
hideFooter={isNonPreferredInSelection}
/>
</div>
</div>
);
}

View File

@@ -0,0 +1,374 @@
"use client";
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { Message } from "@/app/app/interfaces";
import { LlmManager } from "@/lib/hooks";
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
import { MultiModelResponse } from "@/app/app/message/interfaces";
import { cn } from "@/lib/utils";
export interface MultiModelResponseViewProps {
responses: MultiModelResponse[];
chatState: FullChatState;
llmManager: LlmManager | null;
onRegenerate?: RegenerationFactory;
parentMessage?: Message | null;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (nodeId: number) => void;
/** Called whenever the set of hidden panel indices changes */
onHiddenPanelsChange?: (hidden: Set<number>) => void;
}
// How many pixels of a non-preferred panel are visible at the viewport edge
const PEEK_W = 64;
// Uniform panel width used in the selection-mode carousel
const SELECTION_PANEL_W = 400;
// Compact width for hidden panels in the carousel track
const HIDDEN_PANEL_W = 220;
// Generation-mode panel widths (from Figma)
const GEN_PANEL_W_2 = 640; // 2 panels side-by-side
const GEN_PANEL_W_3 = 436; // 3 panels side-by-side
// Gap between panels — matches CSS gap-6 (24px)
const PANEL_GAP = 24;
// Minimum panel width before horizontal scroll kicks in
const MIN_PANEL_W = 300;
/**
* Renders N model responses side-by-side with two layout modes:
*
* **Generation mode** — equal-width panels in a horizontally-scrollable row.
* Panel width is determined by the number of visible (non-hidden) panels.
*
* **Selection mode** — activated when the user clicks a panel to mark it as
* preferred. All panels (including hidden ones) sit in a fixed-width carousel
* track. A CSS `translateX` transform slides the track so the preferred panel
* is centered in the viewport; the other panels peek in from the edges through
* a mask gradient. Non-preferred visible panels are height-capped to the
* preferred panel's measured height, dimmed at 50% opacity, and receive a
* bottom fade-out overlay.
*
* Hidden panels render as a compact header-only strip at `HIDDEN_PANEL_W` in
* both modes and are excluded from layout width calculations.
*/
export default function MultiModelResponseView({
responses,
chatState,
llmManager,
onRegenerate,
parentMessage,
otherMessagesCanSwitchTo,
onMessageSelection,
onHiddenPanelsChange,
}: MultiModelResponseViewProps) {
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
// Controls animation: false = panels at start position, true = panels at peek position
const [selectionEntered, setSelectionEntered] = useState(false);
// Measures the overflow-hidden carousel container for responsive preferred-panel sizing.
const [trackContainerW, setTrackContainerW] = useState(0);
const roRef = useRef<ResizeObserver | null>(null);
const trackContainerRef = useCallback((el: HTMLDivElement | null) => {
carouselContainerRef.current = el;
if (roRef.current) {
roRef.current.disconnect();
roRef.current = null;
}
if (!el) return;
const ro = new ResizeObserver(([entry]) => {
setTrackContainerW(entry?.contentRect.width ?? 0);
});
ro.observe(el);
setTrackContainerW(el.offsetWidth);
roRef.current = ro;
}, []);
// Measures the preferred panel's height to cap non-preferred panels in selection mode.
const [preferredPanelHeight, setPreferredPanelHeight] = useState<
number | null
>(null);
const preferredRoRef = useRef<ResizeObserver | null>(null);
const trackRef = useRef<HTMLDivElement | null>(null);
const carouselContainerRef = useRef<HTMLDivElement | null>(null);
// Tracks which non-preferred panels overflow the preferred height cap
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
new Set()
);
const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
if (preferredRoRef.current) {
preferredRoRef.current.disconnect();
preferredRoRef.current = null;
}
if (!el) {
setPreferredPanelHeight(null);
return;
}
const ro = new ResizeObserver(([entry]) => {
setPreferredPanelHeight(entry?.contentRect.height ?? 0);
});
ro.observe(el);
setPreferredPanelHeight(el.offsetHeight);
preferredRoRef.current = ro;
}, []);
const isGenerating = useMemo(
() => responses.some((r) => r.isGenerating),
[responses]
);
// Non-hidden responses — used for layout width decisions and selection-mode gating
const visibleResponses = useMemo(
() => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
[responses, hiddenPanels]
);
const toggleVisibility = useCallback(
(modelIndex: number) => {
setHiddenPanels((prev) => {
const next = new Set(prev);
if (next.has(modelIndex)) {
next.delete(modelIndex);
} else {
// Don't hide the last visible panel
const visibleCount = responses.length - next.size;
if (visibleCount <= 1) return prev;
next.add(modelIndex);
}
onHiddenPanelsChange?.(next);
return next;
});
},
[responses.length, onHiddenPanelsChange]
);
const handleSelectPreferred = useCallback(
(modelIndex: number) => {
if (isGenerating) return;
setPreferredIndex(modelIndex);
const response = responses[modelIndex];
if (!response) return;
if (onMessageSelection) {
onMessageSelection(response.nodeId);
}
},
[isGenerating, responses, onMessageSelection]
);
// Clear preferred selection when generation starts
useEffect(() => {
if (isGenerating) {
setPreferredIndex(null);
}
}, [isGenerating]);
// Selection mode when preferred is set, not generating, and at least 2 visible panels
const showSelectionMode =
preferredIndex !== null && !isGenerating && visibleResponses.length > 1;
// Trigger the slide-out animation one frame after entering selection mode
useEffect(() => {
if (!showSelectionMode) {
setSelectionEntered(false);
return;
}
const raf = requestAnimationFrame(() => setSelectionEntered(true));
return () => cancelAnimationFrame(raf);
}, [showSelectionMode]);
// Build panel props — isHidden reflects actual hidden state
const buildPanelProps = useCallback(
(response: MultiModelResponse, isNonPreferred: boolean) => ({
modelIndex: response.modelIndex,
provider: response.provider,
modelName: response.modelName,
displayName: response.displayName,
isPreferred: preferredIndex === response.modelIndex,
isHidden: hiddenPanels.has(response.modelIndex),
isNonPreferredInSelection: isNonPreferred,
onSelect: () => handleSelectPreferred(response.modelIndex),
onToggleVisibility: () => toggleVisibility(response.modelIndex),
agentMessageProps: {
rawPackets: response.packets,
packetCount: response.packetCount,
chatState,
nodeId: response.nodeId,
messageId: response.messageId,
currentFeedback: response.currentFeedback,
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
onRegenerate,
parentMessage,
},
}),
[
preferredIndex,
hiddenPanels,
handleSelectPreferred,
toggleVisibility,
chatState,
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
onRegenerate,
parentMessage,
]
);
if (showSelectionMode) {
// ── Selection Layout (transform-based carousel) ──
//
// All panels (including hidden) sit in the track at their original A/B/C positions.
// Hidden panels use HIDDEN_PANEL_W; non-preferred use SELECTION_PANEL_W;
// preferred uses dynamicPrefW (up to GEN_PANEL_W_2).
const preferredIdx = responses.findIndex(
(r) => r.modelIndex === preferredIndex
);
const n = responses.length;
const dynamicPrefW =
trackContainerW > 0
? Math.min(trackContainerW - 2 * (PEEK_W + PANEL_GAP), GEN_PANEL_W_2)
: GEN_PANEL_W_2;
const selectionWidths = responses.map((r, i) => {
if (hiddenPanels.has(r.modelIndex)) return HIDDEN_PANEL_W;
if (i === preferredIdx) return dynamicPrefW;
return SELECTION_PANEL_W;
});
const panelLeftEdges = selectionWidths.reduce<number[]>((acc, w, i) => {
acc.push(i === 0 ? 0 : acc[i - 1]! + selectionWidths[i - 1]! + PANEL_GAP);
return acc;
}, []);
const preferredCenterInTrack =
panelLeftEdges[preferredIdx]! + selectionWidths[preferredIdx]! / 2;
// Start position: hidden panels at HIDDEN_PANEL_W, visible at SELECTION_PANEL_W
const uniformTrackW =
responses.reduce(
(sum, r) =>
sum +
(hiddenPanels.has(r.modelIndex) ? HIDDEN_PANEL_W : SELECTION_PANEL_W),
0
) +
(n - 1) * PANEL_GAP;
const trackTransform = selectionEntered
? `translateX(${trackContainerW / 2 - preferredCenterInTrack}px)`
: `translateX(${(trackContainerW - uniformTrackW) / 2}px)`;
return (
<div
ref={trackContainerRef}
className="w-full overflow-hidden"
style={{
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
}}
>
<div
ref={trackRef}
className="flex items-start"
style={{
gap: `${PANEL_GAP}px`,
transition: selectionEntered
? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
: "none",
transform: trackTransform,
}}
>
{responses.map((r, i) => {
const isHidden = hiddenPanels.has(r.modelIndex);
const isPref = r.modelIndex === preferredIndex;
const isNonPref = !isHidden && !isPref;
const finalW = selectionWidths[i]!;
const startW = isHidden ? HIDDEN_PANEL_W : SELECTION_PANEL_W;
const capped = isNonPref && preferredPanelHeight != null;
const overflows = capped && overflowingPanels.has(r.modelIndex);
return (
<div
key={r.modelIndex}
ref={(el) => {
if (isPref) preferredPanelRef(el);
if (capped && el) {
// Check if content overflows the preferred height cap
const doesOverflow = el.scrollHeight > el.clientHeight;
setOverflowingPanels((prev) => {
const had = prev.has(r.modelIndex);
if (doesOverflow === had) return prev;
const next = new Set(prev);
if (doesOverflow) next.add(r.modelIndex);
else next.delete(r.modelIndex);
return next;
});
}
}}
style={{
width: `${selectionEntered ? finalW : startW}px`,
flexShrink: 0,
transition: selectionEntered
? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
: "none",
maxHeight: capped ? preferredPanelHeight : undefined,
overflow: capped ? "hidden" : undefined,
position: capped ? "relative" : undefined,
}}
>
<div className={cn(isNonPref && "opacity-50")}>
<MultiModelPanel {...buildPanelProps(r, isNonPref)} />
</div>
{overflows && (
<div
className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
style={{
background:
"linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
}}
/>
)}
</div>
);
})}
</div>
</div>
);
}
// ── Generation Layout (equal panels side-by-side) ──
// Panel width based on number of visible (non-hidden) panels.
const panelWidth =
visibleResponses.length <= 2 ? GEN_PANEL_W_2 : GEN_PANEL_W_3;
return (
<div className="overflow-x-auto">
<div className="flex gap-6 items-start w-fit mx-auto">
{responses.map((r) => {
const isHidden = hiddenPanels.has(r.modelIndex);
return (
<div
key={r.modelIndex}
style={
isHidden
? {
width: HIDDEN_PANEL_W,
minWidth: HIDDEN_PANEL_W,
maxWidth: HIDDEN_PANEL_W,
flexShrink: 0,
overflow: "hidden" as const,
}
: { width: panelWidth, minWidth: MIN_PANEL_W }
}
>
<MultiModelPanel {...buildPanelProps(r, false)} />
</div>
);
})}
</div>
</div>
);
}

View File

@@ -0,0 +1,16 @@
import { Packet } from "@/app/app/services/streamingModels";
import { FeedbackType } from "@/app/app/interfaces";
export interface MultiModelResponse {
modelIndex: number;
provider: string;
modelName: string;
displayName: string;
packets: Packet[];
packetCount: number;
nodeId: number;
messageId?: number;
currentFeedback?: FeedbackType | null;
isGenerating?: boolean;
}

View File

@@ -49,6 +49,8 @@ export interface AgentMessageProps {
parentMessage?: Message | null;
// Duration in seconds for processing this message (agent messages only)
processingDurationSeconds?: number;
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
hideFooter?: boolean;
}
// TODO: Consider more robust comparisons:
@@ -76,7 +78,8 @@ function arePropsEqual(
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
prev.llmManager?.isLoadingProviders ===
next.llmManager?.isLoadingProviders &&
prev.processingDurationSeconds === next.processingDurationSeconds
prev.processingDurationSeconds === next.processingDurationSeconds &&
prev.hideFooter === next.hideFooter
// Skip: chatState.regenerate, chatState.setPresentingDocument,
// most of llmManager, onMessageSelection (function/object props)
);
@@ -95,6 +98,7 @@ const AgentMessage = React.memo(function AgentMessage({
onRegenerate,
parentMessage,
processingDurationSeconds,
hideFooter,
}: AgentMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const finalAnswerRef = useRef<HTMLDivElement>(null);
@@ -326,7 +330,7 @@ const AgentMessage = React.memo(function AgentMessage({
</div>
{/* Feedback buttons - only show when streaming and rendering complete */}
{isComplete && (
{isComplete && !hideFooter && (
<MessageToolbar
nodeId={nodeId}
messageId={messageId}

View File

@@ -12,6 +12,7 @@ import {
FileChatDisplay,
Message,
MessageResponseIDInfo,
MultiModelMessageResponseIDInfo,
ResearchType,
RetrievalType,
StreamingError,
@@ -96,6 +97,7 @@ export type PacketType =
| FileChatDisplay
| StreamingError
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamStopInfo
| UserKnowledgeFilePacket
| Packet;
@@ -109,6 +111,13 @@ export type MessageOrigin =
| "slackbot"
| "unknown";
export interface LLMOverride {
model_provider: string;
model_version: string;
temperature?: number;
display_name?: string;
}
export interface SendMessageParams {
message: string;
fileDescriptors?: FileDescriptor[];
@@ -124,6 +133,8 @@ export interface SendMessageParams {
modelProvider?: string;
modelVersion?: string;
temperature?: number;
// Multi-model: send multiple LLM overrides for parallel generation
llmOverrides?: LLMOverride[];
// Origin of the message for telemetry tracking
origin?: MessageOrigin;
// Additional context injected into the LLM call but not stored/shown in chat.
@@ -144,6 +155,7 @@ export async function* sendMessage({
modelProvider,
modelVersion,
temperature,
llmOverrides,
origin,
additionalContext,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
@@ -165,6 +177,8 @@ export async function* sendMessage({
model_version: modelVersion,
}
: null,
// Multi-model: list of LLM overrides for parallel generation
llm_overrides: llmOverrides ?? null,
// Default to "unknown" for consistency with backend; callers should set explicitly
origin: origin ?? "unknown",
additional_context: additionalContext ?? null,
@@ -182,12 +196,27 @@ export async function* sendMessage({
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
const data = await response.json().catch(() => ({}));
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
}
yield* handleSSEStream<PacketType>(response, signal);
}
export async function setPreferredResponse(
userMessageId: number,
preferredResponseId: number
): Promise<Response> {
return fetch("/api/chat/set-preferred-response", {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
user_message_id: userMessageId,
preferred_response_id: preferredResponseId,
}),
});
}
export async function nameChatSession(chatSessionId: string) {
const response = await fetch("/api/chat/rename-chat-session", {
method: "PUT",
@@ -357,6 +386,9 @@ export function processRawChatHistory(
overridden_model: messageInfo.overridden_model,
packets: packetsForMessage || [],
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
// Multi-model answer generation
preferredResponseId: messageInfo.preferred_response_id ?? null,
modelDisplayName: messageInfo.model_display_name ?? null,
};
messages.set(messageInfo.message_id, message);

View File

@@ -403,6 +403,7 @@ export interface Placement {
turn_index: number;
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
sub_turn_index?: number | null;
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
}
// Packet wrapper for streaming objects

View File

@@ -1,412 +0,0 @@
"use client";
import { useState } from "react";
import { toast } from "@/hooks/useToast";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { cn } from "@/lib/utils";
import { markdown } from "@opal/utils";
import { Content } from "@opal/layouts";
import Card from "@/refresh-components/cards/Card";
import Text from "@/refresh-components/texts/Text";
import { Section } from "@/layouts/general-layouts";
import {
SvgExternalLink,
SvgPlug,
SvgRefreshCw,
SvgSettings,
SvgTrash,
SvgUnplug,
} from "@opal/icons";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import type {
HookPointMeta,
HookResponse,
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
import {
activateHook,
deactivateHook,
deleteHook,
validateHook,
} from "@/ee/refresh-pages/admin/HooksPage/svc";
import { getHookPointIcon } from "@/ee/refresh-pages/admin/HooksPage/hookPointIcons";
import HookStatusPopover from "@/ee/refresh-pages/admin/HooksPage/HookStatusPopover";
// ---------------------------------------------------------------------------
// Sub-component: disconnect confirmation modal
// ---------------------------------------------------------------------------
interface DisconnectConfirmModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
hook: HookResponse;
onDisconnect: () => void;
onDisconnectAndDelete: () => void;
}
function DisconnectConfirmModal({
open,
onOpenChange,
hook,
onDisconnect,
onDisconnectAndDelete,
}: DisconnectConfirmModalProps) {
return (
<Modal open={open} onOpenChange={onOpenChange}>
<Modal.Content width="md" height="fit">
<Modal.Header
icon={(props) => (
<SvgUnplug {...props} className="text-action-danger-05" />
)}
title={`Disconnect ${hook.name}`}
onClose={() => onOpenChange(false)}
/>
<Modal.Body>
<div className="flex flex-col gap-4">
<Text mainUiBody text03>
Onyx will stop calling this endpoint for hook{" "}
<strong>
<em>{hook.name}</em>
</strong>
. In-flight requests will continue to run. The external endpoint
may still retain data previously sent to it. You can reconnect
this hook later if needed.
</Text>
<Text mainUiBody text03>
You can also delete this hook. Deletion cannot be undone.
</Text>
</div>
</Modal.Body>
<Modal.Footer>
<BasicModalFooter
cancel={
<Button
prominence="secondary"
onClick={() => onOpenChange(false)}
>
Cancel
</Button>
}
submit={
<div className="flex items-center gap-2">
<Button
variant="danger"
prominence="secondary"
onClick={onDisconnectAndDelete}
>
Disconnect &amp; Delete
</Button>
<Button
variant="danger"
prominence="primary"
onClick={onDisconnect}
>
Disconnect
</Button>
</div>
}
/>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}
// ---------------------------------------------------------------------------
// Sub-component: delete confirmation modal
// ---------------------------------------------------------------------------
interface DeleteConfirmModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
hook: HookResponse;
onDelete: () => void;
}
function DeleteConfirmModal({
open,
onOpenChange,
hook,
onDelete,
}: DeleteConfirmModalProps) {
return (
<Modal open={open} onOpenChange={onOpenChange}>
<Modal.Content width="md" height="fit">
<Modal.Header
icon={(props) => (
<SvgTrash {...props} className="text-action-danger-05" />
)}
title={`Delete ${hook.name}`}
onClose={() => onOpenChange(false)}
/>
<Modal.Body>
<div className="flex flex-col gap-4">
<Text mainUiBody text03>
Hook{" "}
<strong>
<em>{hook.name}</em>
</strong>{" "}
will be permanently removed from this hook point. The external
endpoint may still retain data previously sent to it.
</Text>
<Text mainUiBody text03>
Deletion cannot be undone.
</Text>
</div>
</Modal.Body>
<Modal.Footer>
<BasicModalFooter
cancel={
<Button
prominence="secondary"
onClick={() => onOpenChange(false)}
>
Cancel
</Button>
}
submit={
<Button variant="danger" prominence="primary" onClick={onDelete}>
Delete
</Button>
}
/>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}
// ---------------------------------------------------------------------------
// ConnectedHookCard
// ---------------------------------------------------------------------------
export interface ConnectedHookCardProps {
hook: HookResponse;
spec: HookPointMeta | undefined;
onEdit: () => void;
onDeleted: () => void;
onToggled: (updated: HookResponse) => void;
}
export default function ConnectedHookCard({
hook,
spec,
onEdit,
onDeleted,
onToggled,
}: ConnectedHookCardProps) {
const [isBusy, setIsBusy] = useState(false);
const [disconnectConfirmOpen, setDisconnectConfirmOpen] = useState(false);
const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false);
async function handleDelete() {
setDeleteConfirmOpen(false);
setIsBusy(true);
try {
await deleteHook(hook.id);
onDeleted();
} catch (err) {
console.error("Failed to delete hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to delete hook."
);
} finally {
setIsBusy(false);
}
}
async function handleActivate() {
setIsBusy(true);
try {
const updated = await activateHook(hook.id);
onToggled(updated);
} catch (err) {
console.error("Failed to reconnect hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to reconnect hook."
);
} finally {
setIsBusy(false);
}
}
async function handleDeactivate() {
setDisconnectConfirmOpen(false);
setIsBusy(true);
try {
const updated = await deactivateHook(hook.id);
onToggled(updated);
} catch (err) {
console.error("Failed to deactivate hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to deactivate hook."
);
} finally {
setIsBusy(false);
}
}
async function handleDisconnectAndDelete() {
setDisconnectConfirmOpen(false);
setIsBusy(true);
try {
const deactivated = await deactivateHook(hook.id);
onToggled(deactivated);
await deleteHook(hook.id);
onDeleted();
} catch (err) {
console.error("Failed to disconnect hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to disconnect hook."
);
} finally {
setIsBusy(false);
}
}
async function handleValidate() {
setIsBusy(true);
try {
const result = await validateHook(hook.id);
if (result.status === "passed") {
toast.success("Hook validated successfully.");
} else {
toast.error(
result.error_message ?? `Validation failed: ${result.status}`
);
}
} catch (err) {
console.error("Failed to validate hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to validate hook."
);
} finally {
setIsBusy(false);
}
}
const HookIcon = getHookPointIcon(hook.hook_point);
return (
<>
<DisconnectConfirmModal
open={disconnectConfirmOpen}
onOpenChange={setDisconnectConfirmOpen}
hook={hook}
onDisconnect={handleDeactivate}
onDisconnectAndDelete={handleDisconnectAndDelete}
/>
<DeleteConfirmModal
open={deleteConfirmOpen}
onOpenChange={setDeleteConfirmOpen}
hook={hook}
onDelete={handleDelete}
/>
<Card
variant="primary"
padding={0.5}
gap={0}
className={cn(
"hover:border-border-02",
!hook.is_active && "!bg-background-neutral-02"
)}
>
<div className="w-full flex flex-row">
<div className="flex-1 p-2">
<Content
sizePreset="main-ui"
variant="section"
icon={HookIcon}
title={!hook.is_active ? markdown(`~~${hook.name}~~`) : hook.name}
description={`Hook Point: ${
spec?.display_name ?? hook.hook_point
}`}
/>
{spec?.docs_url && (
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="pl-6 flex items-center gap-1 w-fit"
>
<span className="underline font-secondary-body text-text-03">
Documentation
</span>
<SvgExternalLink size={12} className="shrink-0" />
</a>
)}
</div>
<Section
flexDirection="column"
alignItems="end"
width="fit"
height="fit"
gap={0}
>
<div className="flex items-center gap-1">
{hook.is_active ? (
<HookStatusPopover hook={hook} spec={spec} isBusy={isBusy} />
) : (
<div
className={cn(
"flex items-center gap-1 p-2",
isBusy ? "opacity-50 pointer-events-none" : "cursor-pointer"
)}
onClick={handleActivate}
>
<Text mainUiAction text03>
Reconnect
</Text>
<SvgPlug size={16} className="text-text-03 shrink-0" />
</div>
)}
</div>
<Disabled disabled={isBusy}>
<div className="flex items-center gap-0.5 pl-1 pr-1 pb-1">
{hook.is_active ? (
<>
<Button
prominence="tertiary"
size="sm"
icon={SvgUnplug}
onClick={() => setDisconnectConfirmOpen(true)}
tooltip="Disconnect Hook"
aria-label="Deactivate hook"
/>
<Button
prominence="tertiary"
size="sm"
icon={SvgRefreshCw}
onClick={handleValidate}
tooltip="Test Connection"
aria-label="Re-validate hook"
/>
</>
) : (
<Button
prominence="tertiary"
size="sm"
icon={SvgTrash}
onClick={() => setDeleteConfirmOpen(true)}
tooltip="Delete"
aria-label="Delete hook"
/>
)}
<Button
prominence="tertiary"
size="sm"
icon={SvgSettings}
onClick={onEdit}
tooltip="Manage"
aria-label="Configure hook"
/>
</div>
</Disabled>
</Section>
</div>
</Card>
</>
);
}

View File

@@ -1,21 +1,23 @@
"use client";
import { useState } from "react";
import { Formik, Form, useFormikContext } from "formik";
import * as Yup from "yup";
import { Button, Text } from "@opal/components";
import { Disabled } from "@opal/core";
import {
SvgCheckCircle,
SvgHookNodes,
SvgShareWebhook,
SvgLoader,
SvgRevert,
} from "@opal/icons";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import { FormField } from "@/refresh-components/form/FormField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import { Section } from "@/layouts/general-layouts";
import { ContentAction } from "@opal/layouts";
import { Content, ContentAction } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
import {
createHook,
@@ -37,7 +39,6 @@ import type {
// ---------------------------------------------------------------------------
interface HookFormModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
/** When provided, the modal is in edit mode for this hook. */
hook?: HookResponse;
@@ -50,7 +51,12 @@ interface HookFormModalProps {
// Helpers
// ---------------------------------------------------------------------------
function buildInitialState(
const MAX_TIMEOUT_SECONDS = 600;
const SOFT_DESCRIPTION =
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
function buildInitialValues(
hook: HookResponse | undefined,
spec: HookPointMeta | undefined
): HookFormState {
@@ -72,172 +78,95 @@ function buildInitialState(
};
}
const SOFT_DESCRIPTION =
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
function buildValidationSchema(isEdit: boolean) {
return Yup.object().shape({
name: Yup.string().trim().required("Display name cannot be empty."),
endpoint_url: Yup.string().trim().required("Endpoint URL cannot be empty."),
api_key: isEdit
? Yup.string()
: Yup.string().trim().required("API key cannot be empty."),
timeout_seconds: Yup.string()
.required("Timeout is required.")
.test(
"valid-timeout",
`Must be greater than 0 and at most ${MAX_TIMEOUT_SECONDS} seconds.`,
(val) => {
const num = parseFloat(val ?? "");
return !isNaN(num) && num > 0 && num <= MAX_TIMEOUT_SECONDS;
}
),
});
}
const MAX_TIMEOUT_SECONDS = 600;
// ---------------------------------------------------------------------------
// Timeout field (needs access to spec for revert button)
// ---------------------------------------------------------------------------
interface TimeoutFieldProps {
spec: HookPointMeta | undefined;
}
function TimeoutField({ spec }: TimeoutFieldProps) {
const { values, setFieldValue, isSubmitting } =
useFormikContext<HookFormState>();
return (
<InputLayouts.Vertical
name="timeout_seconds"
title="Timeout"
suffix="(seconds)"
subDescription={`Maximum time Onyx will wait for the endpoint to respond before applying the fail strategy. Must be greater than 0 and at most ${MAX_TIMEOUT_SECONDS} seconds.`}
>
<div className="[&_input]:!font-main-ui-mono [&_input::placeholder]:!font-main-ui-mono [&_input]:![appearance:textfield] [&_input::-webkit-outer-spin-button]:!appearance-none [&_input::-webkit-inner-spin-button]:!appearance-none w-full">
<InputTypeInField
name="timeout_seconds"
type="number"
placeholder={spec ? String(spec.default_timeout_seconds) : undefined}
variant={isSubmitting ? "disabled" : undefined}
showClearButton={false}
rightSection={
spec?.default_timeout_seconds !== undefined &&
values.timeout_seconds !== String(spec.default_timeout_seconds) ? (
<Button
prominence="tertiary"
size="xs"
icon={SvgRevert}
tooltip="Revert to Default"
onClick={() =>
setFieldValue(
"timeout_seconds",
String(spec.default_timeout_seconds)
)
}
disabled={isSubmitting}
/>
) : undefined
}
/>
</div>
</InputLayouts.Vertical>
);
}
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
export default function HookFormModal({
open,
onOpenChange,
hook,
spec,
onSuccess,
}: HookFormModalProps) {
const isEdit = !!hook;
const [form, setForm] = useState<HookFormState>(() =>
buildInitialState(hook, spec)
);
const [isSubmitting, setIsSubmitting] = useState(false);
const [isConnected, setIsConnected] = useState(false);
// Tracks whether the user explicitly cleared the API key field in edit mode.
// - false + empty field → key unchanged (omitted from PATCH)
// - true + empty field → key cleared (api_key: null sent to backend)
// - false + non-empty → new key provided (new value sent to backend)
const [apiKeyCleared, setApiKeyCleared] = useState(false);
const [touched, setTouched] = useState({
name: false,
endpoint_url: false,
api_key: false,
});
const [apiKeyServerError, setApiKeyServerError] = useState(false);
const [endpointServerError, setEndpointServerError] = useState<string | null>(
null
);
const [timeoutServerError, setTimeoutServerError] = useState(false);
function touch(key: keyof typeof touched) {
setTouched((prev) => ({ ...prev, [key]: true }));
}
const initialValues = buildInitialValues(hook, spec);
const validationSchema = buildValidationSchema(isEdit);
function handleOpenChange(next: boolean) {
if (!next) {
if (isSubmitting) return;
setTimeout(() => {
setForm(buildInitialState(hook, spec));
setIsConnected(false);
setApiKeyCleared(false);
setTouched({ name: false, endpoint_url: false, api_key: false });
setApiKeyServerError(false);
setEndpointServerError(null);
setTimeoutServerError(false);
}, 200);
}
onOpenChange(next);
}
function set<K extends keyof HookFormState>(key: K, value: HookFormState[K]) {
setForm((prev) => ({ ...prev, [key]: value }));
}
const timeoutNum = parseFloat(form.timeout_seconds);
const isTimeoutValid =
!isNaN(timeoutNum) && timeoutNum > 0 && timeoutNum <= MAX_TIMEOUT_SECONDS;
const isValid =
form.name.trim().length > 0 &&
form.endpoint_url.trim().length > 0 &&
isTimeoutValid &&
(isEdit || form.api_key.trim().length > 0);
const nameError = touched.name && !form.name.trim();
const endpointEmptyError = touched.endpoint_url && !form.endpoint_url.trim();
const endpointFieldError = endpointEmptyError
? "Endpoint URL cannot be empty."
: endpointServerError ?? undefined;
const apiKeyEmptyError = !isEdit && touched.api_key && !form.api_key.trim();
const apiKeyFieldError = apiKeyEmptyError
? "API key cannot be empty."
: apiKeyServerError
? "Invalid API key."
: undefined;
function handleTimeoutBlur() {
if (!isTimeoutValid) {
const fallback = hook?.timeout_seconds ?? spec?.default_timeout_seconds;
if (fallback !== undefined) {
set("timeout_seconds", String(fallback));
if (timeoutServerError) setTimeoutServerError(false);
}
}
}
const hasChanges =
isEdit && hook
? form.name !== hook.name ||
form.endpoint_url !== (hook.endpoint_url ?? "") ||
form.fail_strategy !== hook.fail_strategy ||
timeoutNum !== hook.timeout_seconds ||
form.api_key.trim().length > 0 ||
apiKeyCleared
: true;
async function handleSubmit() {
if (!isValid) return;
setIsSubmitting(true);
try {
let result: HookResponse;
if (isEdit && hook) {
const req: HookUpdateRequest = {};
if (form.name !== hook.name) req.name = form.name;
if (form.endpoint_url !== (hook.endpoint_url ?? ""))
req.endpoint_url = form.endpoint_url;
if (form.fail_strategy !== hook.fail_strategy)
req.fail_strategy = form.fail_strategy;
if (timeoutNum !== hook.timeout_seconds)
req.timeout_seconds = timeoutNum;
if (form.api_key.trim().length > 0) {
req.api_key = form.api_key;
} else if (apiKeyCleared) {
req.api_key = null;
}
if (Object.keys(req).length === 0) {
setIsSubmitting(false);
handleOpenChange(false);
return;
}
result = await updateHook(hook.id, req);
} else {
if (!spec) {
toast.error("No hook point specified.");
setIsSubmitting(false);
return;
}
result = await createHook({
name: form.name,
hook_point: spec.hook_point,
endpoint_url: form.endpoint_url,
...(form.api_key ? { api_key: form.api_key } : {}),
fail_strategy: form.fail_strategy,
timeout_seconds: timeoutNum,
});
}
toast.success(isEdit ? "Hook updated." : "Hook created.");
onSuccess(result);
if (!isEdit) {
setIsConnected(true);
await new Promise((resolve) => setTimeout(resolve, 500));
}
setIsSubmitting(false);
handleOpenChange(false);
} catch (err) {
if (err instanceof HookAuthError) {
setApiKeyServerError(true);
} else if (err instanceof HookTimeoutError) {
setTimeoutServerError(true);
} else if (err instanceof HookConnectError) {
setEndpointServerError(err.message || "Could not connect to endpoint.");
} else {
toast.error(
err instanceof Error ? err.message : "Something went wrong."
);
}
setIsSubmitting(false);
}
function handleClose() {
onOpenChange(false);
}
const hookPointDisplayName =
@@ -245,314 +174,291 @@ export default function HookFormModal({
const hookPointDescription = spec?.description;
const docsUrl = spec?.docs_url;
const failStrategyDescription =
form.fail_strategy === "soft"
? SOFT_DESCRIPTION
: spec?.fail_hard_description;
return (
<Modal open={open} onOpenChange={handleOpenChange}>
<Modal open onOpenChange={(open) => !open && handleClose()}>
<Modal.Content width="md" height="fit">
<Modal.Header
icon={SvgHookNodes}
title={isEdit ? "Manage Hook Extension" : "Set Up Hook Extension"}
description={
isEdit
? undefined
: "Connect an external API endpoint to extend the hook point."
}
onClose={() => handleOpenChange(false)}
/>
<Modal.Body>
{/* Hook point section header */}
<ContentAction
sizePreset="main-ui"
variant="section"
paddingVariant="fit"
title={hookPointDisplayName}
description={hookPointDescription}
rightChildren={
<Section
flexDirection="column"
alignItems="end"
width="fit"
height="fit"
gap={0.25}
>
<div className="flex items-center gap-1">
<SvgHookNodes
style={{ width: "1rem", height: "1rem" }}
className="text-text-03 shrink-0"
/>
<Text font="secondary-body" color="text-03">
Hook Point
</Text>
</div>
{docsUrl && (
<a
href={docsUrl}
target="_blank"
rel="noopener noreferrer"
className="underline"
>
<Text font="secondary-body" color="text-03">
Documentation
</Text>
</a>
)}
</Section>
}
/>
<FormField className="w-full" state={nameError ? "error" : "idle"}>
<FormField.Label>Display Name</FormField.Label>
<FormField.Control>
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
<InputTypeIn
value={form.name}
onChange={(e) => set("name", e.target.value)}
onBlur={() => touch("name")}
placeholder="Name your extension at this hook point"
variant={
isSubmitting ? "disabled" : nameError ? "error" : undefined
}
/>
</div>
</FormField.Control>
<FormField.Message
messages={{ error: "Display name cannot be empty." }}
/>
</FormField>
<FormField className="w-full">
<FormField.Label>Fail Strategy</FormField.Label>
<FormField.Control>
<InputSelect
value={form.fail_strategy}
onValueChange={(v) =>
set("fail_strategy", v as HookFailStrategy)
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount
onSubmit={async (values, helpers) => {
try {
let result: HookResponse;
if (isEdit && hook) {
const req: HookUpdateRequest = {};
if (values.name !== hook.name) req.name = values.name;
if (values.endpoint_url !== (hook.endpoint_url ?? ""))
req.endpoint_url = values.endpoint_url;
if (values.fail_strategy !== hook.fail_strategy)
req.fail_strategy = values.fail_strategy;
const timeoutNum = parseFloat(values.timeout_seconds);
if (timeoutNum !== hook.timeout_seconds)
req.timeout_seconds = timeoutNum;
if (values.api_key.trim().length > 0) {
req.api_key = values.api_key;
} else if (apiKeyCleared) {
req.api_key = null;
}
disabled={isSubmitting}
>
<InputSelect.Trigger placeholder="Select strategy" />
<InputSelect.Content>
<InputSelect.Item value="soft">
Log Error and Continue
{spec?.default_fail_strategy === "soft" && (
<>
{" "}
<Text color="text-03">(Default)</Text>
</>
)}
</InputSelect.Item>
<InputSelect.Item value="hard">
Block Pipeline on Failure
{spec?.default_fail_strategy === "hard" && (
<>
{" "}
<Text color="text-03">(Default)</Text>
</>
)}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</FormField.Control>
<FormField.Description>
{failStrategyDescription}
</FormField.Description>
</FormField>
if (Object.keys(req).length === 0) {
handleClose();
return;
}
result = await updateHook(hook.id, req);
} else {
if (!spec) {
toast.error("No hook point specified.");
return;
}
result = await createHook({
name: values.name,
hook_point: spec.hook_point,
endpoint_url: values.endpoint_url,
...(values.api_key ? { api_key: values.api_key } : {}),
fail_strategy: values.fail_strategy,
timeout_seconds: parseFloat(values.timeout_seconds),
});
}
toast.success(isEdit ? "Hook updated." : "Hook created.");
onSuccess(result);
if (!isEdit) {
setIsConnected(true);
await new Promise((resolve) => setTimeout(resolve, 500));
}
handleClose();
} catch (err) {
if (err instanceof HookAuthError) {
helpers.setFieldError("api_key", "Invalid API key.");
} else if (err instanceof HookTimeoutError) {
helpers.setFieldError(
"timeout_seconds",
"Connection timed out. Try increasing the timeout."
);
} else if (err instanceof HookConnectError) {
helpers.setFieldError(
"endpoint_url",
err.message || "Could not connect to endpoint."
);
} else {
toast.error(
err instanceof Error ? err.message : "Something went wrong."
);
}
} finally {
helpers.setSubmitting(false);
}
}}
>
{({ values, setFieldValue, isSubmitting, isValid, dirty }) => {
const failStrategyDescription =
values.fail_strategy === "soft"
? SOFT_DESCRIPTION
: spec?.fail_hard_description;
<FormField
className="w-full"
state={timeoutServerError ? "error" : "idle"}
>
<FormField.Label>
Timeout{" "}
<Text font="main-ui-action" color="text-03">
(seconds)
</Text>
</FormField.Label>
<FormField.Control>
<div className="[&_input]:!font-main-ui-mono [&_input::placeholder]:!font-main-ui-mono [&_input]:![appearance:textfield] [&_input::-webkit-outer-spin-button]:!appearance-none [&_input::-webkit-inner-spin-button]:!appearance-none w-full">
<InputTypeIn
type="number"
value={form.timeout_seconds}
onChange={(e) => {
set("timeout_seconds", e.target.value);
if (timeoutServerError) setTimeoutServerError(false);
}}
onBlur={handleTimeoutBlur}
placeholder={
spec ? String(spec.default_timeout_seconds) : undefined
return (
<Form className="w-full overflow-visible">
<Modal.Header
icon={SvgShareWebhook}
title={
isEdit ? "Manage Hook Extension" : "Set Up Hook Extension"
}
variant={
isSubmitting
? "disabled"
: timeoutServerError
? "error"
: undefined
description={
isEdit
? undefined
: "Connect an external API endpoint to extend the hook point."
}
showClearButton={false}
rightSection={
spec?.default_timeout_seconds !== undefined &&
form.timeout_seconds !==
String(spec.default_timeout_seconds) ? (
<Button
prominence="tertiary"
size="xs"
icon={SvgRevert}
tooltip="Revert to Default"
onClick={() =>
set(
"timeout_seconds",
String(spec.default_timeout_seconds)
)
}
disabled={isSubmitting}
onClose={handleClose}
/>
<Modal.Body>
{/* Hook point section header */}
<ContentAction
sizePreset="main-ui"
variant="section"
paddingVariant="fit"
title={hookPointDisplayName}
description={hookPointDescription}
rightChildren={
<div className="flex flex-col items-end gap-1">
<Content
sizePreset="secondary"
variant="body"
icon={SvgShareWebhook}
title="Hook Point"
prominence="muted"
widthVariant="fit"
/>
{docsUrl && (
<a
href={docsUrl}
target="_blank"
rel="noopener noreferrer"
className="underline leading-none"
>
<Text font="secondary-body" color="text-03">
Documentation
</Text>
</a>
)}
</div>
}
/>
<InputLayouts.Vertical name="name" title="Display Name">
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
<InputTypeInField
name="name"
placeholder="Name your extension at this hook point"
variant={isSubmitting ? "disabled" : undefined}
/>
) : undefined
}
/>
</div>
</FormField.Control>
{!timeoutServerError && (
<FormField.Description>
Maximum time Onyx will wait for the endpoint to respond before
applying the fail strategy. Must be greater than 0 and at most{" "}
{MAX_TIMEOUT_SECONDS} seconds.
</FormField.Description>
)}
<FormField.Message
messages={{
error: "Connection timed out. Try increasing the timeout.",
}}
/>
</FormField>
</div>
</InputLayouts.Vertical>
<FormField
className="w-full"
state={endpointFieldError ? "error" : "idle"}
>
<FormField.Label>External API Endpoint URL</FormField.Label>
<FormField.Control>
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
<InputTypeIn
value={form.endpoint_url}
onChange={(e) => {
set("endpoint_url", e.target.value);
if (endpointServerError) setEndpointServerError(null);
}}
onBlur={() => touch("endpoint_url")}
placeholder="https://your-api-endpoint.com"
variant={
isSubmitting
? "disabled"
: endpointFieldError
? "error"
: undefined
}
/>
</div>
</FormField.Control>
{!endpointFieldError && (
<FormField.Description>
Only connect to servers you trust. You are responsible for
actions taken and data shared with this connection.
</FormField.Description>
)}
<FormField.Message messages={{ error: endpointFieldError }} />
</FormField>
<InputLayouts.Vertical
name="fail_strategy"
title="Fail Strategy"
nonInteractive
subDescription={failStrategyDescription}
>
<InputSelect
value={values.fail_strategy}
onValueChange={(v) =>
setFieldValue("fail_strategy", v as HookFailStrategy)
}
disabled={isSubmitting}
>
<InputSelect.Trigger placeholder="Select strategy" />
<InputSelect.Content>
<InputSelect.Item value="soft">
Log Error and Continue
{spec?.default_fail_strategy === "soft" && (
<>
{" "}
<Text color="text-03">(Default)</Text>
</>
)}
</InputSelect.Item>
<InputSelect.Item value="hard">
Block Pipeline on Failure
{spec?.default_fail_strategy === "hard" && (
<>
{" "}
<Text color="text-03">(Default)</Text>
</>
)}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</InputLayouts.Vertical>
<FormField
className="w-full"
state={apiKeyFieldError ? "error" : "idle"}
>
<FormField.Label>API Key</FormField.Label>
<FormField.Control>
<PasswordInputTypeIn
value={form.api_key}
onChange={(e) => {
set("api_key", e.target.value);
if (apiKeyServerError) setApiKeyServerError(false);
if (isEdit) {
setApiKeyCleared(
e.target.value === "" && !!hook?.api_key_masked
);
}
}}
onBlur={() => touch("api_key")}
placeholder={
isEdit
? hook?.api_key_masked ?? "Leave blank to keep current key"
: undefined
}
disabled={isSubmitting}
error={!!apiKeyFieldError}
/>
</FormField.Control>
{!apiKeyFieldError && (
<FormField.Description>
Onyx will use this key to authenticate with your API endpoint.
</FormField.Description>
)}
<FormField.Message messages={{ error: apiKeyFieldError }} />
</FormField>
<TimeoutField spec={spec} />
{!isEdit && (isSubmitting || isConnected) && (
<Section
flexDirection="row"
alignItems="center"
justifyContent="start"
height="fit"
gap={1}
className="px-0.5"
>
<div className="p-0.5 shrink-0">
{isConnected ? (
<SvgCheckCircle
size={16}
className="text-status-success-05"
<InputLayouts.Vertical
name="endpoint_url"
title="External API Endpoint URL"
subDescription="Only connect to servers you trust. You are responsible for actions taken and data shared with this connection."
>
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
<InputTypeInField
name="endpoint_url"
placeholder="https://your-api-endpoint.com"
variant={isSubmitting ? "disabled" : undefined}
/>
</div>
</InputLayouts.Vertical>
<InputLayouts.Vertical
name="api_key"
title="API Key"
subDescription="Onyx will use this key to authenticate with your API endpoint."
>
<PasswordInputTypeInField
name="api_key"
placeholder={
isEdit
? hook?.api_key_masked ??
"Leave blank to keep current key"
: undefined
}
disabled={isSubmitting}
onChange={(e) => {
if (isEdit && hook?.api_key_masked) {
setApiKeyCleared(e.target.value === "");
}
}}
/>
</InputLayouts.Vertical>
{!isEdit && (isSubmitting || isConnected) && (
<Section
flexDirection="row"
alignItems="center"
justifyContent="start"
height="fit"
gap={1}
className="px-0.5"
>
<div className="p-0.5 shrink-0">
{isConnected ? (
<SvgCheckCircle
size={16}
className="text-status-success-05"
/>
) : (
<SvgLoader
size={16}
className="animate-spin text-text-03"
/>
)}
</div>
<Text font="secondary-body" color="text-03">
{isConnected
? "Connection valid."
: "Verifying connection…"}
</Text>
</Section>
)}
</Modal.Body>
<Modal.Footer>
<BasicModalFooter
cancel={
<Disabled disabled={isSubmitting}>
<Button prominence="secondary" onClick={handleClose}>
Cancel
</Button>
</Disabled>
}
submit={
<Disabled
disabled={
isSubmitting ||
!isValid ||
(!dirty && !apiKeyCleared && isEdit)
}
>
<Button
type="submit"
icon={
isSubmitting && !isEdit
? () => (
<SvgLoader
size={16}
className="animate-spin"
/>
)
: undefined
}
>
{isEdit ? "Save Changes" : "Connect"}
</Button>
</Disabled>
}
/>
) : (
<SvgLoader size={16} className="animate-spin text-text-03" />
)}
</div>
<Text font="secondary-body" color="text-03">
{isConnected ? "Connection valid." : "Verifying connection…"}
</Text>
</Section>
)}
</Modal.Body>
<Modal.Footer>
<BasicModalFooter
cancel={
<Disabled disabled={isSubmitting}>
<Button
prominence="secondary"
onClick={() => handleOpenChange(false)}
>
Cancel
</Button>
</Disabled>
}
submit={
<Disabled disabled={isSubmitting || !isValid || !hasChanges}>
<Button
onClick={handleSubmit}
icon={
isSubmitting && !isEdit
? () => <SvgLoader size={16} className="animate-spin" />
: undefined
}
>
{isEdit ? "Save Changes" : "Connect"}
</Button>
</Disabled>
}
/>
</Modal.Footer>
</Modal.Footer>
</Form>
);
}}
</Formik>
</Modal.Content>
</Modal>
);

View File

@@ -14,15 +14,16 @@ import type {
HookPointMeta,
HookResponse,
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
import { useModalClose } from "@/refresh-components/contexts/ModalContext";
interface HookLogsModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
hook: HookResponse;
spec: HookPointMeta | undefined;
}
// Section header: "Past Hour ————" or "Older ————"
//
// TODO(@raunakab): replace this with a proper, opalified `Separator` component (when it lands).
function SectionHeader({ label }: { label: string }) {
return (
<Section
@@ -69,12 +70,9 @@ function LogRow({ log }: { log: HookExecutionRecord }) {
);
}
export default function HookLogsModal({
open,
onOpenChange,
hook,
spec,
}: HookLogsModalProps) {
export default function HookLogsModal({ hook, spec }: HookLogsModalProps) {
const onClose = useModalClose();
const { recentErrors, olderErrors, isLoading, error } = useHookExecutionLogs(
hook.id,
10
@@ -99,7 +97,7 @@ export default function HookLogsModal({
}
return (
<Modal open={open} onOpenChange={onOpenChange}>
<Modal open onOpenChange={onClose}>
<Modal.Content width="md" height="fit">
<Modal.Header
icon={(props) => <SvgTextLines {...props} />}
@@ -107,7 +105,7 @@ export default function HookLogsModal({
description={`Hook: ${hook.name} • Hook Point: ${
spec?.display_name ?? hook.hook_point
}`}
onClose={() => onOpenChange(false)}
onClose={onClose}
/>
<Modal.Body>
{isLoading ? (

View File

@@ -1,9 +1,11 @@
"use client";
import { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { noProp } from "@/lib/utils";
import { formatTimeOnly } from "@/lib/dateUtils";
import { Text } from "@opal/components";
import { Button, Text } from "@opal/components";
import { Content } from "@opal/layouts";
import LineItem from "@/refresh-components/buttons/LineItem";
import Popover from "@/refresh-components/Popover";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
@@ -22,6 +24,7 @@ import type {
HookPointMeta,
HookResponse,
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
import { cn } from "@opal/utils";
interface HookStatusPopoverProps {
hook: HookResponse;
@@ -34,7 +37,7 @@ export default function HookStatusPopover({
spec,
isBusy,
}: HookStatusPopoverProps) {
const [logsOpen, setLogsOpen] = useState(false);
const logsModal = useCreateModal();
const [open, setOpen] = useState(false);
// true = opened by click (stays until dismissed); false = opened by hover (closes after 1s)
const [clickOpened, setClickOpened] = useState(false);
@@ -113,39 +116,34 @@ export default function HookStatusPopover({
return (
<>
<HookLogsModal
open={logsOpen}
onOpenChange={setLogsOpen}
hook={hook}
spec={spec}
/>
<logsModal.Provider>
<HookLogsModal hook={hook} spec={spec} />
</logsModal.Provider>
<Popover open={open} onOpenChange={handleOpenChange}>
<Popover.Anchor asChild>
<div
<Button
prominence="tertiary"
rightIcon={({ className, ...props }) =>
hasRecentErrors ? (
<SvgAlertTriangle
{...props}
className={cn("text-status-warning-05", className)}
/>
) : (
<SvgCheckCircle
{...props}
className={cn("text-status-success-05", className)}
/>
)
}
onMouseEnter={handleTriggerMouseEnter}
onMouseLeave={handleTriggerMouseLeave}
onClick={handleTriggerClick}
className={cn(
"flex items-center gap-1 cursor-pointer rounded-xl p-2 transition-colors hover:bg-background-neutral-02",
isBusy && "opacity-50 pointer-events-none"
)}
onClick={noProp(handleTriggerClick)}
disabled={isBusy}
>
<Text font="main-ui-action" color="text-03">
Connected
</Text>
{hasRecentErrors ? (
<SvgAlertTriangle
size={16}
className="text-status-warning-05 shrink-0"
/>
) : (
<SvgCheckCircle
size={16}
className="text-status-success-05 shrink-0"
/>
)}
</div>
Connected
</Button>
</Popover.Anchor>
<Popover.Content
@@ -160,62 +158,36 @@ export default function HookStatusPopover({
alignItems="start"
height="fit"
width={hasRecentErrors ? 20 : 12.5}
padding={0.125}
gap={0.25}
>
{isLoading ? (
<Section justifyContent="center" height="fit" className="p-3">
<Section justifyContent="center">
<SimpleLoader />
</Section>
) : error ? (
<Section justifyContent="center" height="fit" className="p-3">
<Text font="secondary-body" color="text-03">
Failed to load logs.
</Text>
</Section>
<Text font="secondary-body" color="text-03">
Failed to load logs.
</Text>
) : hasRecentErrors ? (
// Errors state
<>
{/* Header: "N Errors" (≤3) or "Most Recent Errors" (>3) */}
<Section
flexDirection="row"
justifyContent="start"
alignItems="start"
gap={0.25}
padding={0.375}
height="fit"
className="rounded-lg"
>
<Section
justifyContent="center"
alignItems="center"
width={1.25}
height={1.25}
className="shrink-0"
>
<SvgXOctagon size={16} className="text-status-error-05" />
</Section>
<Section
flexDirection="column"
justifyContent="start"
alignItems="start"
width="fit"
height="fit"
gap={0}
className="px-0.5"
>
<Text font="main-ui-action" color="text-04">
{recentErrors.length <= 3
<div className="p-1">
<Content
sizePreset="secondary"
variant="section"
icon={SvgXOctagon}
title={
recentErrors.length <= 3
? `${recentErrors.length} ${
recentErrors.length === 1 ? "Error" : "Errors"
}`
: "Most Recent Errors"}
</Text>
<Text font="secondary-body" color="text-03">
in the past hour
</Text>
</Section>
</Section>
: "Most Recent Errors"
}
description="in the past hour"
/>
</div>
<Separator noPadding className="py-1" />
<Separator noPadding className="px-2" />
{/* Log rows — at most 3, timestamp first then error message */}
<Section
@@ -266,10 +238,10 @@ export default function HookStatusPopover({
<LineItem
muted
icon={SvgMaximize2}
onClick={() => {
onClick={noProp(() => {
handleOpenChange(false);
setLogsOpen(true);
}}
logsModal.toggle(true);
})}
>
View More Lines
</LineItem>
@@ -277,56 +249,26 @@ export default function HookStatusPopover({
) : (
// No errors state
<>
{/* No Error / in the past hour */}
<Section
flexDirection="row"
justifyContent="start"
alignItems="start"
gap={0.25}
padding={0.375}
height="fit"
className="rounded-lg"
>
<Section
justifyContent="center"
alignItems="center"
width={1.25}
height={1.25}
className="shrink-0"
>
<SvgCheckCircle
size={16}
className="text-status-success-05"
/>
</Section>
<Section
flexDirection="column"
justifyContent="start"
alignItems="start"
width="fit"
height="fit"
gap={0}
className="px-0.5"
>
<Text font="main-ui-action" color="text-04">
No Error
</Text>
<Text font="secondary-body" color="text-03">
in the past hour
</Text>
</Section>
</Section>
<div className="p-1">
<Content
sizePreset="secondary"
variant="section"
icon={SvgCheckCircle}
title="No Error"
description="in the past hour"
/>
</div>
<Separator noPadding className="py-1" />
<Separator noPadding className="px-2" />
{/* View Older Errors */}
<LineItem
muted
icon={SvgMaximize2}
onClick={() => {
onClick={noProp(() => {
handleOpenChange(false);
setLogsOpen(true);
}}
logsModal.toggle(true);
})}
>
View Older Errors
</LineItem>

View File

@@ -1,214 +0,0 @@
"use client";
import { useState } from "react";
import { useHookSpecs } from "@/ee/hooks/useHookSpecs";
import { useHooks } from "@/ee/hooks/useHooks";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { Button } from "@opal/components";
import { Content } from "@opal/layouts";
import InputSearch from "@/refresh-components/inputs/InputSearch";
import Card from "@/refresh-components/cards/Card";
import Text from "@/refresh-components/texts/Text";
import { SvgArrowExchange, SvgExternalLink } from "@opal/icons";
import HookFormModal from "@/ee/refresh-pages/admin/HooksPage/HookFormModal";
import ConnectedHookCard from "@/ee/refresh-pages/admin/HooksPage/ConnectedHookCard";
import { getHookPointIcon } from "@/ee/refresh-pages/admin/HooksPage/hookPointIcons";
import type {
HookPointMeta,
HookResponse,
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
import { markdown } from "@opal/utils";
// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------
export default function HooksContent() {
const [search, setSearch] = useState("");
const [connectSpec, setConnectSpec] = useState<HookPointMeta | null>(null);
const [editHook, setEditHook] = useState<HookResponse | null>(null);
const { specs, isLoading: specsLoading, error: specsError } = useHookSpecs();
const {
hooks,
isLoading: hooksLoading,
error: hooksError,
mutate,
} = useHooks();
if (specsLoading || hooksLoading) {
return <SimpleLoader />;
}
if (specsError || hooksError) {
return (
<Text text03 secondaryBody>
Failed to load{specsError ? " hook specifications" : " hooks"}. Please
refresh the page.
</Text>
);
}
const hooksByPoint: Record<string, HookResponse[]> = {};
for (const hook of hooks ?? []) {
(hooksByPoint[hook.hook_point] ??= []).push(hook);
}
const searchLower = search.toLowerCase();
// Connected hooks sorted alphabetically by hook name
const connectedHooks = (hooks ?? [])
.filter(
(hook) =>
!searchLower ||
hook.name.toLowerCase().includes(searchLower) ||
specs
?.find((s) => s.hook_point === hook.hook_point)
?.display_name.toLowerCase()
.includes(searchLower)
)
.sort((a, b) => a.name.localeCompare(b.name));
// Unconnected hook point specs sorted alphabetically
const unconnectedSpecs = (specs ?? [])
.filter(
(spec) =>
(hooksByPoint[spec.hook_point]?.length ?? 0) === 0 &&
(!searchLower ||
spec.display_name.toLowerCase().includes(searchLower) ||
spec.description.toLowerCase().includes(searchLower))
)
.sort((a, b) => a.display_name.localeCompare(b.display_name));
function handleHookSuccess(updated: HookResponse) {
mutate((prev) => {
if (!prev) return [updated];
const idx = prev.findIndex((h) => h.id === updated.id);
if (idx >= 0) {
const next = [...prev];
next[idx] = updated;
return next;
}
return [...prev, updated];
});
}
function handleHookDeleted(id: number) {
mutate((prev) => prev?.filter((h) => h.id !== id));
}
const connectSpec_ =
connectSpec ??
(editHook
? specs?.find((s) => s.hook_point === editHook.hook_point)
: undefined);
return (
<>
<div className="flex flex-col gap-6">
<InputSearch
placeholder="Search hooks..."
value={search}
onChange={(e) => setSearch(e.target.value)}
/>
<div className="flex flex-col gap-2">
{connectedHooks.length === 0 && unconnectedSpecs.length === 0 ? (
<Text text03 secondaryBody>
{search
? "No hooks match your search."
: "No hook points are available."}
</Text>
) : (
<>
{connectedHooks.map((hook) => {
const spec = specs?.find(
(s) => s.hook_point === hook.hook_point
);
return (
<ConnectedHookCard
key={hook.id}
hook={hook}
spec={spec}
onEdit={() => setEditHook(hook)}
onDeleted={() => handleHookDeleted(hook.id)}
onToggled={handleHookSuccess}
/>
);
})}
{unconnectedSpecs.map((spec) => {
const UnconnectedIcon = getHookPointIcon(spec.hook_point);
return (
<Card
key={spec.hook_point}
variant="secondary"
padding={0.5}
gap={0}
className="hover:border-border-02"
>
<div className="w-full flex flex-row">
<div className="flex-1 p-2">
<Content
sizePreset="main-ui"
variant="section"
icon={UnconnectedIcon}
title={spec.display_name}
description={spec.description}
/>
{spec.docs_url && (
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="pl-6 flex items-center gap-1"
>
<span className="underline font-secondary-body text-text-03">
Documentation
</span>
<SvgExternalLink size={12} className="shrink-0" />
</a>
)}
</div>
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={() => setConnectSpec(spec)}
>
Connect
</Button>
</div>
</Card>
);
})}
</>
)}
</div>
</div>
{/* Create modal */}
<HookFormModal
key={connectSpec?.hook_point ?? "create"}
open={!!connectSpec}
onOpenChange={(open) => {
if (!open) setConnectSpec(null);
}}
spec={connectSpec ?? undefined}
onSuccess={handleHookSuccess}
/>
{/* Edit modal */}
<HookFormModal
key={editHook?.id ?? "edit"}
open={!!editHook}
onOpenChange={(open) => {
if (!open) setEditHook(null);
}}
hook={editHook ?? undefined}
spec={connectSpec_ ?? undefined}
onSuccess={handleHookSuccess}
/>
</>
);
}

View File

@@ -1,13 +0,0 @@
import { SvgBubbleText, SvgFileBroadcast, SvgHookNodes } from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
const HOOK_POINT_ICONS: Record<string, IconFunctionComponent> = {
document_ingestion: SvgFileBroadcast,
query_processing: SvgBubbleText,
};
function getHookPointIcon(hookPoint: string): IconFunctionComponent {
return HOOK_POINT_ICONS[hookPoint] ?? SvgHookNodes;
}
export { HOOK_POINT_ICONS, getHookPointIcon };

View File

@@ -1,22 +1,509 @@
"use client";
import { useEffect } from "react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useRouter } from "next/navigation";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useHookSpecs } from "@/ee/hooks/useHookSpecs";
import { useHooks } from "@/ee/hooks/useHooks";
import useFilter from "@/hooks/useFilter";
import { toast } from "@/hooks/useToast";
import {
useCreateModal,
useModalClose,
} from "@/refresh-components/contexts/ModalContext";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import HooksContent from "./HooksContent";
import { Button, SelectCard, Text } from "@opal/components";
import { Disabled, Hoverable } from "@opal/core";
import { markdown } from "@opal/utils";
import { Content, IllustrationContent } from "@opal/layouts";
import Modal from "@/refresh-components/Modal";
import {
SvgArrowExchange,
SvgBubbleText,
SvgExternalLink,
SvgFileBroadcast,
SvgShareWebhook,
SvgPlug,
SvgRefreshCw,
SvgSettings,
SvgTrash,
SvgUnplug,
} from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
import { SvgNoResult, SvgEmpty } from "@opal/illustrations";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import HookFormModal from "@/ee/refresh-pages/admin/HooksPage/HookFormModal";
import HookStatusPopover from "@/ee/refresh-pages/admin/HooksPage/HookStatusPopover";
import {
activateHook,
deactivateHook,
deleteHook,
validateHook,
} from "@/ee/refresh-pages/admin/HooksPage/svc";
import type {
HookPointMeta,
HookResponse,
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
import { noProp } from "@/lib/utils";
const route = ADMIN_ROUTES.HOOKS;
const HOOK_POINT_ICONS: Record<string, IconFunctionComponent> = {
document_ingestion: SvgFileBroadcast,
query_processing: SvgBubbleText,
};
function getHookPointIcon(hookPoint: string): IconFunctionComponent {
return HOOK_POINT_ICONS[hookPoint] ?? SvgShareWebhook;
}
// ---------------------------------------------------------------------------
// Disconnect confirmation modal
// ---------------------------------------------------------------------------
interface DisconnectConfirmModalProps {
hook: HookResponse;
onDisconnect: () => void;
onDisconnectAndDelete: () => void;
}
function DisconnectConfirmModal({
hook,
onDisconnect,
onDisconnectAndDelete,
}: DisconnectConfirmModalProps) {
const onClose = useModalClose();
return (
<Modal open onOpenChange={onClose}>
<Modal.Content width="md" height="fit">
<Modal.Header
// TODO(@raunakab): replace the colour of this SVG with red.
icon={SvgUnplug}
title={markdown(`Disconnect *${hook.name}*`)}
onClose={onClose}
/>
<Modal.Body>
<div className="flex flex-col gap-2">
<Text font="main-ui-body" color="text-03">
{markdown(
`Onyx will stop calling this endpoint for hook ***${hook.name}***. In-flight requests will continue to run. The external endpoint may still retain data previously sent to it. You can reconnect this hook later if needed.`
)}
</Text>
<Text font="main-ui-body" color="text-03">
You can also delete this hook. Deletion cannot be undone.
</Text>
</div>
</Modal.Body>
<Modal.Footer>
<Button prominence="secondary" onClick={onClose}>
Cancel
</Button>
<Button
variant="danger"
prominence="secondary"
onClick={onDisconnectAndDelete}
>
Disconnect &amp; Delete
</Button>
<Button variant="danger" prominence="primary" onClick={onDisconnect}>
Disconnect
</Button>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}
// ---------------------------------------------------------------------------
// Delete confirmation modal
// ---------------------------------------------------------------------------
interface DeleteConfirmModalProps {
hook: HookResponse;
onDelete: () => void;
}
function DeleteConfirmModal({ hook, onDelete }: DeleteConfirmModalProps) {
const onClose = useModalClose();
return (
<Modal open onOpenChange={onClose}>
<Modal.Content width="md" height="fit">
<Modal.Header
// TODO(@raunakab): replace the colour of this SVG with red.
icon={SvgTrash}
title={`Delete ${hook.name}`}
onClose={onClose}
/>
<Modal.Body>
<div className="flex flex-col gap-2">
<Text font="main-ui-body" color="text-03">
{markdown(
`Hook ***${hook.name}*** will be permanently removed from this hook point. The external endpoint may still retain data previously sent to it.`
)}
</Text>
<Text font="main-ui-body" color="text-03">
Deletion cannot be undone.
</Text>
</div>
</Modal.Body>
<Modal.Footer>
<Button prominence="secondary" onClick={onClose}>
Cancel
</Button>
<Button variant="danger" prominence="primary" onClick={onDelete}>
Delete
</Button>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}
// ---------------------------------------------------------------------------
// Unconnected hook card
// ---------------------------------------------------------------------------
interface UnconnectedHookCardProps {
spec: HookPointMeta;
onConnect: () => void;
}
function UnconnectedHookCard({ spec, onConnect }: UnconnectedHookCardProps) {
const Icon = getHookPointIcon(spec.hook_point);
return (
<SelectCard state="empty" padding="sm" rounding="lg" onClick={onConnect}>
<div className="w-full flex flex-row">
<div className="flex-1 p-2">
<Content
sizePreset="main-ui"
variant="section"
icon={Icon}
title={spec.display_name}
description={spec.description}
/>
{spec.docs_url && (
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="ml-6 flex items-center gap-1 w-min"
>
<span className="underline font-secondary-body text-text-03">
Documentation
</span>
<SvgExternalLink size={12} className="shrink-0" />
</a>
)}
</div>
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={noProp(onConnect)}
>
Connect
</Button>
</div>
</SelectCard>
);
}
// ---------------------------------------------------------------------------
// Connected hook card
// ---------------------------------------------------------------------------
interface ConnectedHookCardProps {
hook: HookResponse;
spec: HookPointMeta | undefined;
onEdit: () => void;
onDeleted: () => void;
onToggled: (updated: HookResponse) => void;
}
function ConnectedHookCard({
hook,
spec,
onEdit,
onDeleted,
onToggled,
}: ConnectedHookCardProps) {
const [isBusy, setIsBusy] = useState(false);
const disconnectModal = useCreateModal();
const deleteModal = useCreateModal();
async function handleDelete() {
deleteModal.toggle(false);
setIsBusy(true);
try {
await deleteHook(hook.id);
onDeleted();
} catch (err) {
console.error("Failed to delete hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to delete hook."
);
} finally {
setIsBusy(false);
}
}
async function handleActivate() {
setIsBusy(true);
try {
const updated = await activateHook(hook.id);
onToggled(updated);
} catch (err) {
console.error("Failed to reconnect hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to reconnect hook."
);
} finally {
setIsBusy(false);
}
}
async function handleDeactivate() {
disconnectModal.toggle(false);
setIsBusy(true);
try {
const updated = await deactivateHook(hook.id);
onToggled(updated);
} catch (err) {
console.error("Failed to deactivate hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to deactivate hook."
);
} finally {
setIsBusy(false);
}
}
async function handleDisconnectAndDelete() {
disconnectModal.toggle(false);
setIsBusy(true);
try {
const deactivated = await deactivateHook(hook.id);
onToggled(deactivated);
await deleteHook(hook.id);
onDeleted();
} catch (err) {
console.error("Failed to disconnect hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to disconnect hook."
);
} finally {
setIsBusy(false);
}
}
async function handleValidate() {
setIsBusy(true);
try {
const result = await validateHook(hook.id);
if (result.status === "passed") {
toast.success("Hook validated successfully.");
} else {
toast.error(
result.error_message ?? `Validation failed: ${result.status}`
);
}
} catch (err) {
console.error("Failed to validate hook:", err);
toast.error(
err instanceof Error ? err.message : "Failed to validate hook."
);
} finally {
setIsBusy(false);
}
}
const HookIcon = getHookPointIcon(hook.hook_point);
return (
<>
<disconnectModal.Provider>
<DisconnectConfirmModal
hook={hook}
onDisconnect={handleDeactivate}
onDisconnectAndDelete={handleDisconnectAndDelete}
/>
</disconnectModal.Provider>
<deleteModal.Provider>
<DeleteConfirmModal hook={hook} onDelete={handleDelete} />
</deleteModal.Provider>
<Hoverable.Root group="connected-hook-card">
<SelectCard state="filled" padding="sm" rounding="lg" onClick={onEdit}>
<div className="w-full flex flex-row">
<div className="flex-1 p-2">
<Content
sizePreset="main-ui"
variant="section"
icon={HookIcon}
title={
!hook.is_active ? markdown(`~~${hook.name}~~`) : hook.name
}
description={`Hook Point: ${
spec?.display_name ?? hook.hook_point
}`}
/>
{spec?.docs_url && (
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="ml-6 flex items-center gap-1 w-min"
>
<span className="underline font-secondary-body text-text-03">
Documentation
</span>
<SvgExternalLink size={12} className="shrink-0" />
</a>
)}
</div>
<div className="flex flex-col items-end shrink-0">
<div className="flex items-center gap-1">
{hook.is_active ? (
<HookStatusPopover hook={hook} spec={spec} isBusy={isBusy} />
) : (
<Button
prominence="tertiary"
rightIcon={SvgPlug}
onClick={noProp(handleActivate)}
disabled={isBusy}
>
Reconnect
</Button>
)}
</div>
<Disabled disabled={isBusy}>
<div className="flex items-center pb-1 px-1 gap-1">
{hook.is_active ? (
<>
<Hoverable.Item
group="connected-hook-card"
variant="opacity-on-hover"
>
<Button
prominence="tertiary"
size="md"
icon={SvgUnplug}
onClick={noProp(() => disconnectModal.toggle(true))}
tooltip="Disconnect Hook"
aria-label="Deactivate hook"
/>
</Hoverable.Item>
<Button
prominence="tertiary"
size="md"
icon={SvgRefreshCw}
onClick={noProp(handleValidate)}
tooltip="Test Connection"
aria-label="Re-validate hook"
/>
</>
) : (
<Button
prominence="tertiary"
size="md"
icon={SvgTrash}
onClick={noProp(() => deleteModal.toggle(true))}
tooltip="Delete"
aria-label="Delete hook"
/>
)}
<Button
prominence="tertiary"
size="md"
icon={SvgSettings}
onClick={noProp(onEdit)}
tooltip="Manage"
aria-label="Configure hook"
/>
</div>
</Disabled>
</div>
</div>
</SelectCard>
</Hoverable.Root>
</>
);
}
// ---------------------------------------------------------------------------
// Page
// ---------------------------------------------------------------------------
export default function HooksPage() {
const router = useRouter();
const { settings, settingsLoading } = useSettingsContext();
const isEE = usePaidEnterpriseFeaturesEnabled();
const [connectSpec, setConnectSpec] = useState<HookPointMeta | null>(null);
const [editHook, setEditHook] = useState<HookResponse | null>(null);
const { specs, isLoading: specsLoading, error: specsError } = useHookSpecs();
const {
hooks,
isLoading: hooksLoading,
error: hooksError,
mutate,
} = useHooks();
const hookExtractor = useCallback(
(hook: HookResponse) =>
`${hook.name} ${
specs?.find((s: HookPointMeta) => s.hook_point === hook.hook_point)
?.display_name ?? ""
}`,
[specs]
);
const sortedHooks = useMemo(
() => [...(hooks ?? [])].sort((a, b) => a.name.localeCompare(b.name)),
[hooks]
);
const {
query: search,
setQuery: setSearch,
filtered: connectedHooks,
} = useFilter(sortedHooks, hookExtractor);
const hooksByPoint = useMemo(() => {
const map: Record<string, HookResponse[]> = {};
for (const hook of hooks ?? []) {
(map[hook.hook_point] ??= []).push(hook);
}
return map;
}, [hooks]);
const unconnectedSpecs = useMemo(() => {
const searchLower = search.toLowerCase();
return (specs ?? [])
.filter(
(spec: HookPointMeta) =>
(hooksByPoint[spec.hook_point]?.length ?? 0) === 0 &&
(!searchLower ||
spec.display_name.toLowerCase().includes(searchLower) ||
spec.description.toLowerCase().includes(searchLower))
)
.sort((a: HookPointMeta, b: HookPointMeta) =>
a.display_name.localeCompare(b.display_name)
);
}, [specs, hooksByPoint, search]);
useEffect(() => {
if (settingsLoading) return;
if (!isEE) {
@@ -32,17 +519,132 @@ export default function HooksPage() {
return <SimpleLoader />;
}
const isLoading = specsLoading || hooksLoading;
function handleHookSuccess(updated: HookResponse) {
mutate((prev: HookResponse[] | undefined) => {
if (!prev) return [updated];
const idx = prev.findIndex((h: HookResponse) => h.id === updated.id);
if (idx >= 0) {
const next = [...prev];
next[idx] = updated;
return next;
}
return [...prev, updated];
});
}
function handleHookDeleted(id: number) {
mutate(
(prev: HookResponse[] | undefined) =>
prev?.filter((h: HookResponse) => h.id !== id)
);
}
const connectSpec_ =
connectSpec ??
(editHook
? specs?.find((s: HookPointMeta) => s.hook_point === editHook.hook_point)
: undefined);
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
description="Extend Onyx pipelines by registering external API endpoints as callbacks at predefined hook points."
separator
/>
<SettingsLayouts.Body>
<HooksContent />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
<>
{/* Create modal */}
{!!connectSpec && (
<HookFormModal
key={connectSpec?.hook_point ?? "create"}
onOpenChange={(open: boolean) => {
if (!open) setConnectSpec(null);
}}
spec={connectSpec ?? undefined}
onSuccess={handleHookSuccess}
/>
)}
{/* Edit modal */}
{!!editHook && (
<HookFormModal
key={editHook?.id ?? "edit"}
onOpenChange={(open: boolean) => {
if (!open) setEditHook(null);
}}
hook={editHook ?? undefined}
spec={connectSpec_ ?? undefined}
onSuccess={handleHookSuccess}
/>
)}
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
description="Extend Onyx pipelines by registering external API endpoints as callbacks at predefined hook points."
separator
/>
<SettingsLayouts.Body>
{isLoading ? (
<SimpleLoader />
) : specsError || hooksError ? (
<Text font="secondary-body" color="text-03">
{`Failed to load${
specsError ? " hook specifications" : " hooks"
}. Please refresh the page.`}
</Text>
) : (
<div className="flex flex-col gap-3 h-full">
<div className="pb-3">
<InputTypeIn
placeholder="Search hooks..."
value={search}
variant="internal"
leftSearchIcon
onChange={(e) => setSearch(e.target.value)}
/>
</div>
{connectedHooks.length === 0 && unconnectedSpecs.length === 0 ? (
<div>
<IllustrationContent
title={
search ? "No results found" : "No hook points available"
}
description={
search ? "Try using a different search term." : undefined
}
illustration={search ? SvgNoResult : SvgEmpty}
/>
</div>
) : (
<div className="flex flex-col gap-2">
{connectedHooks.map((hook) => {
const spec = specs?.find(
(s: HookPointMeta) => s.hook_point === hook.hook_point
);
return (
<ConnectedHookCard
key={hook.id}
hook={hook}
spec={spec}
onEdit={() => setEditHook(hook)}
onDeleted={() => handleHookDeleted(hook.id)}
onToggled={handleHookSuccess}
/>
);
})}
{unconnectedSpecs.map((spec: HookPointMeta) => (
<UnconnectedHookCard
key={spec.hook_point}
spec={spec}
onConnect={() => setConnectSpec(spec)}
/>
))}
</div>
)}
</div>
)}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
</>
);
}

View File

@@ -901,6 +901,11 @@ export default function useChatController({
});
}
}
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
// catch block replaces the thinking placeholder with an error message.
if (stack.error) {
throw new Error(stack.error);
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;

View File

@@ -0,0 +1,232 @@
import { renderHook, act } from "@tests/setup/test-utils";
import useMultiModelChat from "@/hooks/useMultiModelChat";
import { LlmManager } from "@/lib/hooks";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
// Mock buildLlmOptions — hook uses it internally for initialization.
// Tests here focus on CRUD operations, not the initialization side-effect.
jest.mock("@/refresh-components/popovers/LLMPopover", () => ({
buildLlmOptions: jest.fn(() => []),
}));
const makeLlmManager = (): LlmManager =>
({
llmProviders: [],
currentLlm: { modelName: null, provider: null },
isLoadingProviders: false,
}) as unknown as LlmManager;
const makeModel = (provider: string, modelName: string): SelectedModel => ({
name: provider,
provider,
modelName,
displayName: `${provider}/${modelName}`,
});
const GPT4 = makeModel("openai", "gpt-4");
const CLAUDE = makeModel("anthropic", "claude-opus-4-6");
const GEMINI = makeModel("google", "gemini-pro");
const GPT4_TURBO = makeModel("openai", "gpt-4-turbo");
// ---------------------------------------------------------------------------
// addModel
// ---------------------------------------------------------------------------
describe("addModel", () => {
it("adds a model to an empty selection", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
expect(result.current.selectedModels).toHaveLength(1);
expect(result.current.selectedModels[0]).toEqual(GPT4);
});
it("does not add a duplicate model", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(GPT4); // duplicate
});
expect(result.current.selectedModels).toHaveLength(1);
});
it("enforces MAX_MODELS (3) cap", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
result.current.addModel(GEMINI);
result.current.addModel(GPT4_TURBO); // should be ignored
});
expect(result.current.selectedModels).toHaveLength(3);
});
});
// ---------------------------------------------------------------------------
// removeModel
// ---------------------------------------------------------------------------
describe("removeModel", () => {
it("removes a model by index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.removeModel(0); // remove GPT4
});
expect(result.current.selectedModels).toHaveLength(1);
expect(result.current.selectedModels[0]).toEqual(CLAUDE);
});
it("handles out-of-range index gracefully", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
act(() => {
result.current.removeModel(99); // no-op
});
expect(result.current.selectedModels).toHaveLength(1);
});
});
// ---------------------------------------------------------------------------
// replaceModel
// ---------------------------------------------------------------------------
describe("replaceModel", () => {
it("replaces the model at the given index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.replaceModel(0, GEMINI);
});
expect(result.current.selectedModels[0]).toEqual(GEMINI);
expect(result.current.selectedModels[1]).toEqual(CLAUDE);
});
it("does not replace with a model already selected at another index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.replaceModel(0, CLAUDE); // CLAUDE is already at index 1
});
// Should be a no-op — GPT4 stays at index 0
expect(result.current.selectedModels[0]).toEqual(GPT4);
});
});
// ---------------------------------------------------------------------------
// isMultiModelActive
// ---------------------------------------------------------------------------
describe("isMultiModelActive", () => {
it("is false with zero models", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
expect(result.current.isMultiModelActive).toBe(false);
});
it("is false with exactly one model", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
expect(result.current.isMultiModelActive).toBe(false);
});
it("is true with two or more models", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
expect(result.current.isMultiModelActive).toBe(true);
});
});
// ---------------------------------------------------------------------------
// buildLlmOverrides
// ---------------------------------------------------------------------------
describe("buildLlmOverrides", () => {
it("returns empty array when no models selected", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
expect(result.current.buildLlmOverrides()).toEqual([]);
});
it("maps selectedModels to LLMOverride format", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
const overrides = result.current.buildLlmOverrides();
expect(overrides).toHaveLength(2);
expect(overrides[0]).toEqual({
model_provider: "openai",
model_version: "gpt-4",
display_name: "openai/gpt-4",
});
expect(overrides[1]).toEqual({
model_provider: "anthropic",
model_version: "claude-opus-4-6",
display_name: "anthropic/claude-opus-4-6",
});
});
});
// ---------------------------------------------------------------------------
// clearModels
// ---------------------------------------------------------------------------
describe("clearModels", () => {
it("empties the selection", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.clearModels();
});
expect(result.current.selectedModels).toHaveLength(0);
expect(result.current.isMultiModelActive).toBe(false);
});
});

View File

@@ -0,0 +1,192 @@
"use client";
import { useState, useCallback, useEffect, useMemo } from "react";
import {
MAX_MODELS,
SelectedModel,
} from "@/refresh-components/popovers/ModelSelector";
import { LLMOverride } from "@/app/app/services/lib";
import { LlmManager } from "@/lib/hooks";
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
export interface UseMultiModelChatReturn {
/** Currently selected models for multi-model comparison. */
selectedModels: SelectedModel[];
/** Whether multi-model mode is active (>1 model selected). */
isMultiModelActive: boolean;
/** Add a model to the selection. */
addModel: (model: SelectedModel) => void;
/** Remove a model by index. */
removeModel: (index: number) => void;
/** Replace a model at a specific index with a new one. */
replaceModel: (index: number, model: SelectedModel) => void;
/** Clear all selected models. */
clearModels: () => void;
/** Build the LLMOverride[] array from selectedModels. */
buildLlmOverrides: () => LLMOverride[];
/**
* Restore multi-model selection from model version strings (e.g. from chat history).
* Matches against available llmOptions to reconstruct full SelectedModel objects.
*/
restoreFromModelNames: (modelNames: string[]) => void;
/**
* Switch to a single model by name (after user picks a preferred response).
* Matches against llmOptions to find the full SelectedModel.
*/
selectSingleModel: (modelName: string) => void;
}
export default function useMultiModelChat(
llmManager: LlmManager
): UseMultiModelChatReturn {
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
const [defaultInitialized, setDefaultInitialized] = useState(false);
// Initialize with the default model from llmManager once providers load
const llmOptions = useMemo(
() =>
llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
[llmManager.llmProviders]
);
useEffect(() => {
if (defaultInitialized) return;
if (llmOptions.length === 0) return;
const { currentLlm } = llmManager;
// Don't initialize if currentLlm hasn't loaded yet
if (!currentLlm.modelName) return;
const match = llmOptions.find(
(opt) =>
opt.provider === currentLlm.provider &&
opt.modelName === currentLlm.modelName
);
if (match) {
setSelectedModels([
{
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
},
]);
setDefaultInitialized(true);
}
}, [llmOptions, llmManager.currentLlm, defaultInitialized]);
const isMultiModelActive = selectedModels.length > 1;
const addModel = useCallback((model: SelectedModel) => {
setSelectedModels((prev) => {
if (prev.length >= MAX_MODELS) return prev;
if (
prev.some(
(m) =>
m.provider === model.provider && m.modelName === model.modelName
)
) {
return prev;
}
return [...prev, model];
});
}, []);
const removeModel = useCallback((index: number) => {
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
}, []);
const replaceModel = useCallback((index: number, model: SelectedModel) => {
setSelectedModels((prev) => {
// Don't replace with a model that's already selected elsewhere
if (
prev.some(
(m, i) =>
i !== index &&
m.provider === model.provider &&
m.modelName === model.modelName
)
) {
return prev;
}
const next = [...prev];
next[index] = model;
return next;
});
}, []);
const clearModels = useCallback(() => {
setSelectedModels([]);
}, []);
const restoreFromModelNames = useCallback(
(modelNames: string[]) => {
if (modelNames.length < 2 || llmOptions.length === 0) return;
const restored: SelectedModel[] = [];
for (const name of modelNames) {
// Try matching by modelName (raw version string like "claude-opus-4-6")
// or by displayName (friendly name like "Claude Opus 4.6")
const match = llmOptions.find(
(opt) =>
opt.modelName === name ||
opt.displayName === name ||
opt.name === name
);
if (match) {
restored.push({
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
});
}
}
if (restored.length >= 2) {
setSelectedModels(restored.slice(0, MAX_MODELS));
setDefaultInitialized(true);
}
},
[llmOptions]
);
const selectSingleModel = useCallback(
(modelName: string) => {
if (llmOptions.length === 0) return;
const match = llmOptions.find(
(opt) =>
opt.modelName === modelName ||
opt.displayName === modelName ||
opt.name === modelName
);
if (match) {
setSelectedModels([
{
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
},
]);
}
},
[llmOptions]
);
const buildLlmOverrides = useCallback((): LLMOverride[] => {
return selectedModels.map((m) => ({
model_provider: m.provider,
model_version: m.modelName,
display_name: m.displayName,
}));
}, [selectedModels]);
return {
selectedModels,
isMultiModelActive,
addModel,
removeModel,
replaceModel,
clearModels,
buildLlmOverrides,
restoreFromModelNames,
selectSingleModel,
};
}

View File

@@ -12,6 +12,7 @@ export enum LLMProviderName {
OPENROUTER = "openrouter",
VERTEX_AI = "vertex_ai",
BEDROCK = "bedrock",
LITELLM = "litellm",
LITELLM_PROXY = "litellm_proxy",
BIFROST = "bifrost",
CUSTOM = "custom",

View File

@@ -4,7 +4,7 @@ import {
SvgActivity,
SvgArrowExchange,
SvgAudio,
SvgHookNodes,
SvgShareWebhook,
SvgBarChart,
SvgBookOpen,
SvgBubbleText,
@@ -230,7 +230,7 @@ export const ADMIN_ROUTES = {
},
HOOKS: {
path: "/admin/hooks",
icon: SvgHookNodes,
icon: SvgShareWebhook,
title: "Hook Extensions",
sidebarLabel: "Hook Extensions",
},

View File

@@ -5,7 +5,6 @@ import {
SvgOpenai,
SvgClaude,
SvgOllama,
SvgCloud,
SvgAws,
SvgOpenrouter,
SvgServer,
@@ -22,7 +21,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
[LLMProviderName.VERTEX_AI]: SvgGemini,
[LLMProviderName.BEDROCK]: SvgAws,
[LLMProviderName.AZURE]: SvgAzure,
litellm: SvgLitellm,
[LLMProviderName.LITELLM]: SvgLitellm,
[LLMProviderName.LITELLM_PROXY]: SvgLitellm,
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
@@ -39,7 +38,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
[LLMProviderName.VERTEX_AI]: "Gemini",
[LLMProviderName.BEDROCK]: "Amazon Bedrock",
[LLMProviderName.AZURE]: "Azure OpenAI",
litellm: "LiteLLM",
[LLMProviderName.LITELLM]: "LiteLLM",
[LLMProviderName.LITELLM_PROXY]: "LiteLLM Proxy",
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
[LLMProviderName.OPENROUTER]: "OpenRouter",
@@ -56,7 +55,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
[LLMProviderName.VERTEX_AI]: "Google Cloud Vertex AI",
[LLMProviderName.BEDROCK]: "AWS",
[LLMProviderName.AZURE]: "Microsoft Azure",
litellm: "LiteLLM",
[LLMProviderName.LITELLM]: "LiteLLM",
[LLMProviderName.LITELLM_PROXY]: "LiteLLM Proxy",
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
[LLMProviderName.OPENROUTER]: "OpenRouter",

View File

@@ -0,0 +1,159 @@
"use client";
import { useState, useMemo, useRef } from "react";
import { PopoverMenu } from "@/refresh-components/Popover";
import Separator from "@/refresh-components/Separator";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { Text } from "@opal/components";
import { SvgCheck } from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import { LLMOption } from "./interfaces";
import { buildLlmOptions, groupLlmOptions } from "./LLMPopover";
import LineItem from "@/refresh-components/buttons/LineItem";
import { LLMProviderDescriptor } from "@/interfaces/llm";
export interface ModelListContentProps {
llmProviders: LLMProviderDescriptor[] | undefined;
currentModelName?: string;
requiresImageInput?: boolean;
onSelect: (option: LLMOption) => void;
isSelected: (option: LLMOption) => boolean;
isDisabled?: (option: LLMOption) => boolean;
scrollContainerRef?: React.RefObject<HTMLDivElement | null>;
isLoading?: boolean;
footer?: React.ReactNode;
}
export default function ModelListContent({
llmProviders,
currentModelName,
requiresImageInput,
onSelect,
isSelected,
isDisabled,
scrollContainerRef: externalScrollRef,
isLoading,
footer,
}: ModelListContentProps) {
const [searchQuery, setSearchQuery] = useState("");
const internalScrollRef = useRef<HTMLDivElement>(null);
const scrollContainerRef = externalScrollRef ?? internalScrollRef;
const llmOptions = useMemo(
() => buildLlmOptions(llmProviders, currentModelName),
[llmProviders, currentModelName]
);
const filteredOptions = useMemo(() => {
let result = llmOptions;
if (requiresImageInput) {
result = result.filter((opt) => opt.supportsImageInput);
}
if (searchQuery.trim()) {
const query = searchQuery.toLowerCase();
result = result.filter(
(opt) =>
opt.displayName.toLowerCase().includes(query) ||
opt.modelName.toLowerCase().includes(query) ||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
);
}
return result;
}, [llmOptions, searchQuery, requiresImageInput]);
const groupedOptions = useMemo(
() => groupLlmOptions(filteredOptions),
[filteredOptions]
);
const renderModelItem = (option: LLMOption) => {
const selected = isSelected(option);
const disabled = isDisabled?.(option) ?? false;
const capabilities: string[] = [];
if (option.supportsReasoning) capabilities.push("Reasoning");
if (option.supportsImageInput) capabilities.push("Vision");
const description =
capabilities.length > 0 ? capabilities.join(", ") : undefined;
return (
<LineItem
key={`${option.provider}:${option.modelName}`}
selected={selected}
disabled={disabled}
description={description}
onClick={() => onSelect(option)}
rightChildren={
selected ? (
<SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
) : null
}
>
{option.displayName}
</LineItem>
);
};
return (
<Section gap={0.5}>
<InputTypeIn
leftSearchIcon
variant="internal"
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
placeholder="Search models..."
/>
<PopoverMenu scrollContainerRef={scrollContainerRef}>
{isLoading
? [
<div key="loading" className="py-3 px-2">
<Text font="secondary-body" color="text-03">
Loading models...
</Text>
</div>,
]
: groupedOptions.length === 0
? [
<div key="empty" className="py-3 px-2">
<Text font="secondary-body" color="text-03">
No models found
</Text>
</div>,
]
: groupedOptions.map((group) => (
<div key={group.key}>
{groupedOptions.length > 1 && (
<Section
flexDirection="row"
gap={0.25}
padding={0}
height="auto"
alignItems="center"
justifyContent="start"
className="px-2 pt-2 pb-1"
>
<div className="flex items-center gap-1 shrink-0">
<group.Icon size={16} />
<Text font="secondary-body" color="text-03" nowrap>
{group.displayName}
</Text>
</div>
<Separator noPadding className="flex-1" />
</Section>
)}
<Section
gap={0.25}
alignItems="stretch"
justifyContent="start"
>
{group.options.map(renderModelItem)}
</Section>
</div>
))}
</PopoverMenu>
{footer}
</Section>
);
}

View File

@@ -0,0 +1,230 @@
"use client";
import { useState, useMemo, useRef } from "react";
import Popover from "@/refresh-components/Popover";
import { LlmManager } from "@/lib/hooks";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import { Button, SelectButton, OpenButton } from "@opal/components";
import { SvgPlusCircle, SvgX } from "@opal/icons";
import { LLMOption } from "@/refresh-components/popovers/interfaces";
import ModelListContent from "@/refresh-components/popovers/ModelListContent";
import Separator from "@/refresh-components/Separator";
export const MAX_MODELS = 3;
export interface SelectedModel {
name: string;
provider: string;
modelName: string;
displayName: string;
}
export interface ModelSelectorProps {
llmManager: LlmManager;
selectedModels: SelectedModel[];
onAdd: (model: SelectedModel) => void;
onRemove: (index: number) => void;
onReplace: (index: number, model: SelectedModel) => void;
}
function modelKey(provider: string, modelName: string): string {
return `${provider}:${modelName}`;
}
export default function ModelSelector({
llmManager,
selectedModels,
onAdd,
onRemove,
onReplace,
}: ModelSelectorProps) {
const [open, setOpen] = useState(false);
// null = add mode (via + button), number = replace mode (via pill click)
const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
// Virtual anchor ref — points to the clicked pill so the popover positions above it
const anchorRef = useRef<HTMLElement | null>(null);
const isMultiModel = selectedModels.length > 1;
const atMax = selectedModels.length >= MAX_MODELS;
const selectedKeys = useMemo(
() => new Set(selectedModels.map((m) => modelKey(m.provider, m.modelName))),
[selectedModels]
);
const otherSelectedKeys = useMemo(() => {
if (replacingIndex === null) return new Set<string>();
return new Set(
selectedModels
.filter((_, i) => i !== replacingIndex)
.map((m) => modelKey(m.provider, m.modelName))
);
}, [selectedModels, replacingIndex]);
const replacingKey =
replacingIndex !== null
? (() => {
const m = selectedModels[replacingIndex];
return m ? modelKey(m.provider, m.modelName) : null;
})()
: null;
const isSelected = (option: LLMOption) => {
const key = modelKey(option.provider, option.modelName);
if (replacingIndex !== null) return key === replacingKey;
return selectedKeys.has(key);
};
const isDisabled = (option: LLMOption) => {
const key = modelKey(option.provider, option.modelName);
if (replacingIndex !== null) return otherSelectedKeys.has(key);
return !selectedKeys.has(key) && atMax;
};
const handleSelect = (option: LLMOption) => {
const model: SelectedModel = {
name: option.name,
provider: option.provider,
modelName: option.modelName,
displayName: option.displayName,
};
if (replacingIndex !== null) {
onReplace(replacingIndex, model);
setOpen(false);
setReplacingIndex(null);
return;
}
const key = modelKey(option.provider, option.modelName);
const existingIndex = selectedModels.findIndex(
(m) => modelKey(m.provider, m.modelName) === key
);
if (existingIndex >= 0) {
onRemove(existingIndex);
} else if (!atMax) {
onAdd(model);
}
};
const handleOpenChange = (nextOpen: boolean) => {
setOpen(nextOpen);
if (!nextOpen) setReplacingIndex(null);
};
const handlePillClick = (index: number, element: HTMLElement) => {
anchorRef.current = element;
setReplacingIndex(index);
setOpen(true);
};
return (
<Popover open={open} onOpenChange={handleOpenChange}>
<div className="flex items-center justify-end gap-1 p-1">
{!atMax && (
<Button
prominence="tertiary"
icon={SvgPlusCircle}
size="sm"
tooltip="Add Model"
onClick={(e: React.MouseEvent) => {
anchorRef.current = e.currentTarget as HTMLElement;
setReplacingIndex(null);
setOpen(true);
}}
/>
)}
<Popover.Anchor
virtualRef={anchorRef as React.RefObject<HTMLElement>}
/>
{selectedModels.length > 0 && (
<>
{!atMax && (
<Separator
orientation="vertical"
paddingXRem={0.5}
paddingYRem={0.5}
/>
)}
<div className="flex items-center">
{selectedModels.map((model, index) => {
const ProviderIcon = getProviderIcon(
model.provider,
model.modelName
);
if (!isMultiModel) {
return (
<OpenButton
key={modelKey(model.provider, model.modelName)}
icon={ProviderIcon}
onClick={(e: React.MouseEvent) =>
handlePillClick(index, e.currentTarget as HTMLElement)
}
>
{model.displayName}
</OpenButton>
);
}
return (
<div
key={modelKey(model.provider, model.modelName)}
className="flex items-center"
>
{index > 0 && (
<Separator
orientation="vertical"
paddingXRem={0.5}
className="h-5"
/>
)}
<SelectButton
icon={ProviderIcon}
rightIcon={SvgX}
state="empty"
variant="select-tinted"
interaction="hover"
size="lg"
onClick={(e: React.MouseEvent) => {
const target = e.target as HTMLElement;
const btn = e.currentTarget as HTMLElement;
const icons = btn.querySelectorAll(
".interactive-foreground-icon"
);
const lastIcon = icons[icons.length - 1];
if (lastIcon && lastIcon.contains(target)) {
onRemove(index);
} else {
handlePillClick(index, btn);
}
}}
>
{model.displayName}
</SelectButton>
</div>
);
})}
</div>
</>
)}
</div>
<Popover.Content
side="top"
align="start"
width="lg"
avoidCollisions={false}
>
<ModelListContent
llmProviders={llmManager.llmProviders}
isLoading={llmManager.isLoadingProviders}
onSelect={handleSelect}
isSelected={isSelected}
isDisabled={isDisabled}
/>
</Popover.Content>
</Popover>
);
}

View File

@@ -1008,7 +1008,7 @@ function ChatPreferencesForm() {
)}
</Text>
</Section>
<OpalCard backgroundVariant="none" borderVariant="solid">
<OpalCard background="none" border="solid" padding="sm">
<Content
sizePreset="main-ui"
icon={SvgAlertCircle}

View File

@@ -112,7 +112,7 @@ export default function CodeInterpreterPage() {
<SettingsLayouts.Body>
{isEnabled || isLoading ? (
<Hoverable.Root group="code-interpreter/Card">
<SelectCard variant="select-card" state="filled" sizeVariant="lg">
<SelectCard state="filled" padding="sm" rounding="lg">
<CardHeaderLayout
sizePreset="main-ui"
variant="section"
@@ -157,9 +157,9 @@ export default function CodeInterpreterPage() {
</Hoverable.Root>
) : (
<SelectCard
variant="select-card"
state="empty"
sizeVariant="lg"
padding="sm"
rounding="lg"
onClick={() => handleToggle(true)}
>
<CardHeaderLayout

View File

@@ -248,9 +248,9 @@ export default function ImageGenerationContent() {
group="image-gen/ProviderCard"
>
<SelectCard
variant="select-card"
state={STATUS_TO_STATE[status]}
sizeVariant="lg"
padding="sm"
rounding="lg"
aria-label={`image-gen-provider-${provider.image_provider_id}`}
onClick={
isDisconnected

View File

@@ -212,9 +212,9 @@ function ExistingProviderCard({
<Hoverable.Root group="ExistingProviderCard">
<SelectCard
variant="select-card"
state="filled"
sizeVariant="lg"
padding="sm"
rounding="lg"
onClick={() => setIsOpen(true)}
>
<CardHeaderLayout
@@ -287,9 +287,9 @@ function NewProviderCard({
return (
<SelectCard
variant="select-card"
state="empty"
sizeVariant="lg"
padding="sm"
rounding="lg"
onClick={() => setIsOpen(true)}
>
<CardHeaderLayout
@@ -331,9 +331,9 @@ function NewCustomProviderCard({
return (
<SelectCard
variant="select-card"
state="empty"
sizeVariant="lg"
padding="sm"
rounding="lg"
onClick={() => setIsOpen(true)}
>
<CardHeaderLayout

View File

@@ -264,9 +264,9 @@ function ProviderCard({
return (
<Hoverable.Root group="web-search/ProviderCard">
<SelectCard
variant="select-card"
state={STATUS_TO_STATE[status]}
sizeVariant="lg"
padding="sm"
rounding="lg"
onClick={
isDisconnected && onConnect
? onConnect

View File

@@ -67,7 +67,7 @@ export default function AdminListHeader({
if (!hasItems) {
return (
<Card paddingVariant="md" roundingVariant="lg" borderVariant="solid">
<Card rounding="lg" border="solid">
<div className="flex flex-row items-center justify-between gap-3">
<Content
title={emptyStateText}

View File

@@ -86,9 +86,9 @@ export default function ProviderCard({
return (
<SelectCard
variant="select-card"
state={STATUS_TO_STATE[status]}
sizeVariant="lg"
padding="sm"
rounding="lg"
aria-label={ariaLabel}
onClick={isDisconnected && onConnect ? onConnect : undefined}
>

View File

@@ -5,7 +5,11 @@ import { useSWRConfig } from "swr";
import { Formik } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
import {
LLMProviderFormProps,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
@@ -91,15 +95,27 @@ export default function AzureModal({
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(AZURE_PROVIDER_NAME);
const [addedModels, setAddedModels] = useState<ModelConfiguration[]>([]);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const onClose = () => {
setAddedModels([]);
onOpenChange?.(false);
};
const modelConfigurations = buildAvailableModelConfigurations(
const baseModelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
// Merge base models with any user-added models (dedup by name)
const existingNames = new Set(baseModelConfigurations.map((m) => m.name));
const modelConfigurations = [
...baseModelConfigurations,
...addedModels.filter((m) => !existingNames.has(m.name)),
];
const initialValues: AzureModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
@@ -224,6 +240,25 @@ export default function AzureModal({
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onAddModel={(modelName) => {
const newModel: ModelConfiguration = {
name: modelName,
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
};
setAddedModels((prev) => [...prev, newModel]);
const currentSelected =
formikProps.values.selected_model_names ?? [];
formikProps.setFieldValue("selected_model_names", [
...currentSelected,
modelName,
]);
if (!formikProps.values.default_model_name) {
formikProps.setFieldValue("default_model_name", modelName);
}
}}
/>
)}

View File

@@ -225,11 +225,7 @@ function BedrockModalInternals({
</FieldWrapper>
{authMethod === AUTH_METHOD_ACCESS_KEY && (
<Card
backgroundVariant="light"
borderVariant="none"
paddingVariant="sm"
>
<Card background="light" border="none" padding="sm">
<Section gap={1}>
<InputLayouts.Vertical
name={FIELD_AWS_ACCESS_KEY_ID}
@@ -255,7 +251,7 @@ function BedrockModalInternals({
{authMethod === AUTH_METHOD_IAM && (
<FieldWrapper>
<Card backgroundVariant="none" borderVariant="solid">
<Card background="none" border="solid" padding="sm">
<Content
icon={SvgAlertCircle}
title="Onyx will use the IAM role attached to the environment its running in to authenticate."
@@ -267,11 +263,7 @@ function BedrockModalInternals({
)}
{authMethod === AUTH_METHOD_LONG_TERM_API_KEY && (
<Card
backgroundVariant="light"
borderVariant="none"
paddingVariant="sm"
>
<Card background="light" border="none" padding="sm">
<Section gap={0.5}>
<InputLayouts.Vertical
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}

View File

@@ -166,7 +166,7 @@ function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
))}
</div>
) : (
<EmptyMessageCard title="No models added yet." />
<EmptyMessageCard title="No models added yet." padding="sm" />
)}
<Button
@@ -393,7 +393,7 @@ export default function CustomModal({
/>
</FieldWrapper>
<Card>
<Card padding="sm">
<ModelConfigurationList formikProps={formikProps as any} />
</Card>
</Section>

View File

@@ -140,7 +140,7 @@ function OllamaModalInternals({
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<Card backgroundVariant="light" borderVariant="none" paddingVariant="sm">
<Card background="light" border="none" padding="sm">
<Tabs defaultValue={defaultTab}>
<Tabs.List>
<Tabs.Trigger value={TAB_SELF_HOSTED}>

View File

@@ -1,6 +1,6 @@
"use client";
import { ReactNode } from "react";
import { ReactNode, useState } from "react";
import { Form, FormikProps } from "formik";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useAgents } from "@/hooks/useAgents";
@@ -9,6 +9,7 @@ import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
import * as InputLayouts from "@/layouts/input-layouts";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
@@ -25,6 +26,7 @@ import {
SvgArrowExchange,
SvgOnyxOctagon,
SvgOrganization,
SvgPlusCircle,
SvgRefreshCw,
SvgSparkle,
SvgUserManage,
@@ -250,11 +252,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
</FieldWrapper>
{!isPublic && (
<Card
backgroundVariant="light"
borderVariant="none"
paddingVariant="sm"
>
<Card background="light" border="none" padding="sm">
<Section gap={0.5}>
<InputComboBox
placeholder="Add groups and agents"
@@ -266,7 +264,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
leftSearchIcon
/>
<Card backgroundVariant="heavy" borderVariant="none">
<Card background="heavy" border="none" padding="sm">
<ContentAction
icon={SvgUserManage}
title="Admin"
@@ -290,7 +288,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
const memberCount = group?.users.length ?? 0;
return (
<div key={`group-${id}`} className="min-w-0">
<Card backgroundVariant="heavy" borderVariant="none">
<Card background="heavy" border="none" padding="sm">
<ContentAction
icon={SvgUsers}
title={group?.name ?? `Group ${id}`}
@@ -325,7 +323,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
const agent = agentMap.get(id);
return (
<div key={`agent-${id}`} className="min-w-0">
<Card backgroundVariant="heavy" borderVariant="none">
<Card background="heavy" border="none" padding="sm">
<ContentAction
icon={
agent
@@ -379,6 +377,8 @@ export interface ModelsFieldProps<T> {
shouldShowAutoUpdateToggle: boolean;
/** Called when the user clicks the refresh button to re-fetch models. */
onRefetch?: () => Promise<void> | void;
/** Called when the user adds a custom model by name. Enables the "Add Model" input. */
onAddModel?: (modelName: string) => void;
}
export function ModelsField<T extends BaseLLMFormValues>({
@@ -387,7 +387,9 @@ export function ModelsField<T extends BaseLLMFormValues>({
recommendedDefaultModel,
shouldShowAutoUpdateToggle,
onRefetch,
onAddModel,
}: ModelsFieldProps<T>) {
const [newModelName, setNewModelName] = useState("");
const isAutoMode = formikProps.values.is_auto_mode;
const selectedModels = formikProps.values.selected_model_names ?? [];
const defaultModel = formikProps.values.default_model_name;
@@ -452,7 +454,7 @@ export function ModelsField<T extends BaseLLMFormValues>({
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
return (
<Card backgroundVariant="light" borderVariant="none" paddingVariant="sm">
<Card background="light" border="none" padding="sm">
<Section gap={0.5}>
<InputLayouts.Horizontal
title="Models"
@@ -491,7 +493,7 @@ export function ModelsField<T extends BaseLLMFormValues>({
</InputLayouts.Horizontal>
{modelConfigurations.length === 0 ? (
<EmptyMessageCard title="No models available." />
<EmptyMessageCard title="No models available." padding="sm" />
) : (
<Section gap={0.25}>
{isAutoMode
@@ -578,6 +580,50 @@ export function ModelsField<T extends BaseLLMFormValues>({
</Section>
)}
{onAddModel && !isAutoMode && (
<Section flexDirection="row" gap={0.5}>
<div className="flex-1">
<InputTypeIn
placeholder="Enter model name"
value={newModelName}
onChange={(e) => setNewModelName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && newModelName.trim()) {
e.preventDefault();
const trimmed = newModelName.trim();
if (!modelConfigurations.some((m) => m.name === trimmed)) {
onAddModel(trimmed);
setNewModelName("");
}
}
}}
showClearButton={false}
/>
</div>
<Button
prominence="secondary"
icon={SvgPlusCircle}
type="button"
disabled={
!newModelName.trim() ||
modelConfigurations.some((m) => m.name === newModelName.trim())
}
onClick={() => {
const trimmed = newModelName.trim();
if (
trimmed &&
!modelConfigurations.some((m) => m.name === trimmed)
) {
onAddModel(trimmed);
setNewModelName("");
}
}}
>
Add Model
</Button>
</Section>
)}
{shouldShowAutoUpdateToggle && (
<InputLayouts.Horizontal
title="Auto Update"