Compare commits

...

4 Commits

Author SHA1 Message Date
dependabot[bot]
36196373a8 chore(deps): bump hono from 4.12.5 to 4.12.7 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9263)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-10 18:54:17 -07:00
Jamison Lahman
533aa8eff8 chore(release): upgrade release-tag (#9257) 2026-03-11 00:50:55 +00:00
Raunak Bhagat
ecbb267f80 fix: Consolidate search state-machine (#9234) 2026-03-11 00:42:39 +00:00
Danelegend
66023dbb6d feat(llm-provider): fetch litellm models (#8418) 2026-03-10 23:48:56 +00:00
23 changed files with 747 additions and 376 deletions

View File

@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
@@ -369,9 +370,9 @@ def upsert_llm_provider(
def sync_model_configurations(
db_session: Session,
provider_name: str,
models: list[dict],
models: list[SyncModelEntry],
) -> int:
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
This inserts NEW models from the source API without overwriting existing ones.
User preferences (is_visible, max_input_tokens) are preserved for existing models.
@@ -379,7 +380,7 @@ def sync_model_configurations(
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
models: List of SyncModelEntry objects describing the fetched models
Returns:
Number of new models added
@@ -393,21 +394,20 @@ def sync_model_configurations(
new_count = 0
for model in models:
model_name = model["name"]
if model_name not in existing_names:
if model.name not in existing_names:
# Insert new model with is_visible=False (user must explicitly enable)
supported_flows = [LLMModelFlowType.CHAT]
if model.get("supports_image_input", False):
if model.supports_image_input:
supported_flows.append(LLMModelFlowType.VISION)
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
model_name=model_name,
model_name=model.name,
supported_flows=supported_flows,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
display_name=model.get("display_name"),
max_input_tokens=model.max_input_tokens,
display_name=model.display_name,
)
new_count += 1

View File

@@ -7424,9 +7424,9 @@
}
},
"node_modules/hono": {
"version": "4.12.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
"version": "4.12.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"

View File

@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
@@ -72,6 +75,7 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.server.manage.llm.utils import generate_bedrock_display_name
@@ -98,6 +102,34 @@ def _mask_string(value: str) -> str:
return value[:4] + "****" + value[-4:]
def _sync_fetched_models(
    db_session: Session,
    provider_name: str,
    models: list[SyncModelEntry],
    source_label: str,
) -> None:
    """Persist freshly fetched models to the DB for the given provider.

    Args:
        db_session: Database session
        provider_name: Name of the LLM provider
        models: List of SyncModelEntry objects describing the fetched models
        source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
    """
    try:
        added = sync_model_configurations(
            db_session=db_session,
            provider_name=provider_name,
            models=models,
        )
        if added:
            logger.info(
                f"Added {added} new {source_label} models to provider '{provider_name}'"
            )
    except ValueError as e:
        # A missing/unknown provider is logged, not surfaced to the caller
        logger.warning(f"Failed to sync {source_label} models to DB: {e}")
# Keys in custom_config that contain sensitive credentials
_SENSITIVE_CONFIG_KEYS = {
"vertex_credentials",
@@ -963,27 +995,20 @@ def get_bedrock_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
for r in results
],
source_label="Bedrock",
)
return results
@@ -1101,27 +1126,20 @@ def get_ollama_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync Ollama models to DB: {e}")
for r in sorted_results
],
source_label="Ollama",
)
return sorted_results
@@ -1210,27 +1228,20 @@ def get_openrouter_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
for r in sorted_results
],
source_label="OpenRouter",
)
return sorted_results
@@ -1324,26 +1335,119 @@ def get_lm_studio_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
for r in sorted_results
],
source_label="LM Studio",
)
return sorted_results
@admin_router.post("/litellm/available-models")
def get_litellm_available_models(
    request: LitellmModelsRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[LitellmFinalModelResponse]:
    """Fetch available models from Litellm proxy /v1/models endpoint."""
    payload = _get_litellm_models_response(
        api_key=request.api_key, api_base=request.api_base
    )

    raw_entries = payload.get("data", [])
    if not isinstance(raw_entries, list) or not raw_entries:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No models found from your Litellm endpoint",
        )

    parsed: list[LitellmFinalModelResponse] = []
    for entry in raw_entries:
        try:
            details = LitellmModelDetails.model_validate(entry)
            parsed.append(
                LitellmFinalModelResponse(
                    provider_name=details.owned_by,
                    model_name=details.id,
                )
            )
        except Exception as e:
            # Malformed entries are skipped rather than failing the whole request
            logger.warning(
                "Failed to parse Litellm model entry",
                extra={"error": str(e), "item": str(entry)[:1000]},
            )

    if not parsed:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No compatible models found from Litellm",
        )

    sorted_results = sorted(parsed, key=lambda m: m.model_name.lower())

    # Sync new models to DB if provider_name is specified
    if request.provider_name:
        _sync_fetched_models(
            db_session=db_session,
            provider_name=request.provider_name,
            models=[
                SyncModelEntry(name=r.model_name, display_name=r.model_name)
                for r in sorted_results
            ],
            source_label="LiteLLM",
        )

    return sorted_results
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
    """Perform GET to the LiteLLM proxy ``/v1/models`` endpoint and return parsed JSON.

    Note: an earlier docstring said ``/api/v1/models``, but the code requests
    ``{api_base}/v1/models``.

    Args:
        api_key: Bearer token sent in the Authorization header
        api_base: Base URL of the LiteLLM proxy; surrounding whitespace and
            trailing slashes are stripped before building the URL

    Raises:
        OnyxError: VALIDATION_ERROR for 401 (bad credentials) or 404 (wrong
            base URL); BAD_GATEWAY for any other HTTP error, connection
            failure, or unparseable response body.
    """
    cleaned_api_base = api_base.strip().rstrip("/")
    url = f"{cleaned_api_base}/v1/models"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "HTTP-Referer": "https://onyx.app",
        "X-Title": "Onyx",
    }
    try:
        response = httpx.get(url, headers=headers, timeout=10.0)
        response.raise_for_status()
        # May raise on a non-JSON body; handled by the generic except below
        return response.json()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                "Authentication failed: invalid or missing API key for LiteLLM proxy.",
            ) from e
        elif e.response.status_code == 404:
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                f"LiteLLM models endpoint not found at {url}. "
                "Please verify the API base URL.",
            ) from e
        else:
            raise OnyxError(
                OnyxErrorCode.BAD_GATEWAY,
                f"Failed to fetch LiteLLM models: {e}",
            ) from e
    except Exception as e:
        raise OnyxError(
            OnyxErrorCode.BAD_GATEWAY,
            f"Failed to fetch LiteLLM models: {e}",
        ) from e

View File

@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
default_text=default_text,
default_vision=default_vision,
)
class SyncModelEntry(BaseModel):
    """Typed model for syncing fetched models to the DB."""

    # Canonical model identifier; used to de-duplicate against existing models
    name: str
    # Name to store as the model's display name
    display_name: str
    # Maximum input tokens for the model; None if the source API didn't report it
    max_input_tokens: int | None = None
    # Whether the model accepts image input (adds the VISION flow on insert)
    supports_image_input: bool = False
class LitellmModelsRequest(BaseModel):
    """Request body for fetching available models from a LiteLLM proxy."""

    # Bearer token for the LiteLLM proxy
    api_key: str
    # Base URL of the LiteLLM proxy
    api_base: str
    provider_name: str | None = None  # Optional: to save models to existing provider
class LitellmModelDetails(BaseModel):
    """Single model entry from the LiteLLM proxy /v1/models endpoint."""

    id: str  # Model ID (e.g. "gpt-4o")
    object: str  # "model"
    created: int  # Unix timestamp in seconds
    owned_by: str  # Provider name (e.g. "openai")
class LitellmFinalModelResponse(BaseModel):
    """Response entry for the LiteLLM available-models endpoint."""

    provider_name: str  # Provider name (e.g. "openai")
    model_name: str  # Model ID (e.g. "gpt-4o")

View File

@@ -406,7 +406,7 @@ referencing==0.36.2
# jsonschema-specifications
regex==2025.11.3
# via tiktoken
release-tag==0.4.3
release-tag==0.5.2
# via onyx
reorder-python-imports-black==3.14.0
# via onyx

View File

@@ -7,6 +7,7 @@ import pytest
from onyx.db.llm import sync_model_configurations
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import SyncModelEntry
class TestSyncModelConfigurations:
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4",
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
{
"name": "gpt-4o",
"display_name": "GPT-4o",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4",
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
SyncModelEntry(
name="gpt-4o",
display_name="GPT-4o",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4", # Existing - should be skipped
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
{
"name": "gpt-4o", # New - should be inserted
"display_name": "GPT-4o",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4", # Existing - should be skipped
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
SyncModelEntry(
name="gpt-4o", # New - should be inserted
display_name="GPT-4o",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4", # Already exists
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4", # Already exists
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
sync_model_configurations(
db_session=mock_session,
provider_name="nonexistent",
models=[{"name": "model", "display_name": "Model"}],
models=[SyncModelEntry(name="model", display_name="Model")],
)
def test_handles_missing_optional_fields(self) -> None:
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
with patch(
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
# Model with only required fields
# Model with only required fields (max_input_tokens and supports_image_input default)
models = [
{
"name": "model-1",
# No display_name, max_input_tokens, or supports_image_input
},
SyncModelEntry(
name="model-1",
display_name="Model 1",
),
]
result = sync_model_configurations(

View File

@@ -1,15 +1,19 @@
"""Tests for LLM model fetch endpoints.
These tests verify the full request/response flow for fetching models
from dynamic providers (Ollama, OpenRouter), including the
from dynamic providers (Ollama, OpenRouter, Litellm), including the
sync-to-DB behavior when provider_name is specified.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
from onyx.server.manage.llm.models import LMStudioModelsRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
@@ -614,3 +618,283 @@ class TestGetLMStudioAvailableModels:
request = LMStudioModelsRequest(api_base="http://localhost:1234")
with pytest.raises(OnyxError):
get_lm_studio_available_models(request, MagicMock(), mock_session)
class TestGetLitellmAvailableModels:
    """Tests for the Litellm proxy model fetch endpoint."""

    @pytest.fixture
    def mock_litellm_response(self) -> dict:
        """Mock response from Litellm /v1/models endpoint."""
        return {
            "data": [
                {
                    "id": "gpt-4o",
                    "object": "model",
                    "created": 1700000000,
                    "owned_by": "openai",
                },
                {
                    "id": "claude-3-5-sonnet",
                    "object": "model",
                    "created": 1700000001,
                    "owned_by": "anthropic",
                },
                {
                    "id": "gemini-pro",
                    "object": "model",
                    "created": 1700000002,
                    "owned_by": "google",
                },
            ]
        }

    def test_returns_model_list(self, mock_litellm_response: dict) -> None:
        """Test that endpoint returns properly formatted model list."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        # httpx.get is patched in the api module, so no real network call is made
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = mock_litellm_response
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            results = get_litellm_available_models(request, MagicMock(), mock_session)

            assert len(results) == 3
            assert all(isinstance(r, LitellmFinalModelResponse) for r in results)

    def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
        """Test that provider_name and model_name are correctly extracted."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = mock_litellm_response
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            results = get_litellm_available_models(request, MagicMock(), mock_session)

            # provider_name comes from "owned_by", model_name from "id"
            gpt = next(r for r in results if r.model_name == "gpt-4o")
            assert gpt.provider_name == "openai"
            claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
            assert claude.provider_name == "anthropic"

    def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
        """Test that results are alphabetically sorted by model_name."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = mock_litellm_response
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            results = get_litellm_available_models(request, MagicMock(), mock_session)

            model_names = [r.model_name for r in results]
            assert model_names == sorted(model_names, key=str.lower)

    def test_empty_data_raises_onyx_error(self) -> None:
        """Test that empty model list raises OnyxError."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = {"data": []}
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            with pytest.raises(OnyxError, match="No models found"):
                get_litellm_available_models(request, MagicMock(), mock_session)

    def test_missing_data_key_raises_onyx_error(self) -> None:
        """Test that response without 'data' key raises OnyxError."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = {}
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            with pytest.raises(OnyxError):
                get_litellm_available_models(request, MagicMock(), mock_session)

    def test_skips_unparseable_entries(self) -> None:
        """Test that malformed model entries are skipped without failing."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        response_with_bad_entry = {
            "data": [
                {
                    "id": "gpt-4o",
                    "object": "model",
                    "created": 1700000000,
                    "owned_by": "openai",
                },
                # Missing required fields
                {"bad_field": "bad_value"},
            ]
        }
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = response_with_bad_entry
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            results = get_litellm_available_models(request, MagicMock(), mock_session)

            # Only the valid entry survives parsing
            assert len(results) == 1
            assert results[0].model_name == "gpt-4o"

    def test_all_entries_unparseable_raises_onyx_error(self) -> None:
        """Test that OnyxError is raised when all entries fail to parse."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        response_all_bad = {
            "data": [
                {"bad_field": "bad_value"},
                {"another_bad": 123},
            ]
        }
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = response_all_bad
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            with pytest.raises(OnyxError, match="No compatible models"):
                get_litellm_available_models(request, MagicMock(), mock_session)

    def test_api_base_trailing_slash_handled(self) -> None:
        """Test that trailing slashes in api_base are handled correctly."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        mock_litellm_response = {
            "data": [
                {
                    "id": "gpt-4o",
                    "object": "model",
                    "created": 1700000000,
                    "owned_by": "openai",
                },
            ]
        }
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.json.return_value = mock_litellm_response
            mock_response.raise_for_status = MagicMock()
            mock_get.return_value = mock_response

            request = LitellmModelsRequest(
                api_base="http://localhost:4000/",
                api_key="test-key",
            )
            get_litellm_available_models(request, MagicMock(), mock_session)

            # Should call /v1/models without double slashes
            call_args = mock_get.call_args
            assert call_args[0][0] == "http://localhost:4000/v1/models"

    def test_connection_failure_raises_onyx_error(self) -> None:
        """Test that connection failures are wrapped in OnyxError."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_get.side_effect = Exception("Connection refused")

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
                get_litellm_available_models(request, MagicMock(), mock_session)

    def test_401_raises_authentication_error(self) -> None:
        """Test that a 401 response raises OnyxError with authentication message."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 401
            mock_get.side_effect = httpx.HTTPStatusError(
                "Unauthorized", request=MagicMock(), response=mock_response
            )

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="bad-key",
            )
            with pytest.raises(OnyxError, match="Authentication failed"):
                get_litellm_available_models(request, MagicMock(), mock_session)

    def test_404_raises_not_found_error(self) -> None:
        """Test that a 404 response raises OnyxError with endpoint not found message."""
        from onyx.server.manage.llm.api import get_litellm_available_models

        mock_session = MagicMock()
        with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status_code = 404
            mock_get.side_effect = httpx.HTTPStatusError(
                "Not Found", request=MagicMock(), response=mock_response
            )

            request = LitellmModelsRequest(
                api_base="http://localhost:4000",
                api_key="test-key",
            )
            with pytest.raises(OnyxError, match="endpoint not found"):
                get_litellm_available_models(request, MagicMock(), mock_session)

View File

@@ -153,7 +153,7 @@ dev = [
"pytest-repeat==0.9.4",
"pytest-xdist==3.8.0",
"pytest==8.3.5",
"release-tag==0.4.3",
"release-tag==0.5.2",
"reorder-python-imports-black==3.14.0",
"ruff==0.12.0",
"types-beautifulsoup4==4.12.0.3",

18
uv.lock generated
View File

@@ -4485,7 +4485,7 @@ requires-dist = [
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
@@ -6338,16 +6338,16 @@ wheels = [
[[package]]
name = "release-tag"
version = "0.4.3"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
]
[[package]]

View File

@@ -11,7 +11,7 @@ import { Button } from "@opal/components";
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
import type { AppMode } from "@/providers/QueryControllerProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
@@ -58,15 +58,15 @@ const footerMarkdownComponents = {
*/
export default function NRFChrome() {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { appMode, setAppMode } = useAppMode();
const { state, setAppMode } = useQueryController();
const settings = useSettingsContext();
const { isMobile } = useScreenSize();
const { setFolded } = useAppSidebarContext();
const appFocus = useAppFocus();
const { classification } = useQueryController();
const [modePopoverOpen, setModePopoverOpen] = useState(false);
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
const effectiveMode: AppMode =
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
const customFooterContent =
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
@@ -78,7 +78,7 @@ export default function NRFChrome() {
isPaidEnterpriseFeaturesEnabled &&
settings.isSearchModeAvailable &&
appFocus.isNewSession() &&
!classification;
state.phase === "idle";
const showHeader = isMobile || showModeToggle;

View File

@@ -175,7 +175,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
const isStreaming = currentChatState === "streaming";
// Query controller for search/chat classification (EE feature)
const { submit: submitQuery, classification } = useQueryController();
const { submit: submitQuery, state } = useQueryController();
// Determine if retrieval (search) is enabled based on the agent
const retrievalEnabled = useMemo(() => {
@@ -186,7 +186,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
}, [liveAgent]);
// Check if we're in search mode
const isSearch = classification === "search";
const isSearch =
state.phase === "searching" || state.phase === "search-results";
// Anchor for scroll positioning (matches ChatPage pattern)
const anchorMessage = messageHistory.at(-2) ?? messageHistory[0];
@@ -317,7 +318,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
};
// Use submitQuery which will classify the query and either:
// - Route to search (sets classification to "search" and shows SearchUI)
// - Route to search (sets phase to "searching"/"search-results" and shows SearchUI)
// - Route to chat (calls onChat callback)
await submitQuery(submittedMessage, onChat);
},

View File

@@ -1,55 +0,0 @@
"use client";
import React, { useState, useCallback, useEffect } from "react";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { AppModeContext, AppMode } from "@/providers/AppModeProvider";
import { useUser } from "@/providers/UserProvider";
import { useSettingsContext } from "@/providers/SettingsProvider";
export interface AppModeProviderProps {
children: React.ReactNode;
}
/**
* Provider for application mode (Search/Chat).
*
* This controls how user queries are handled:
* - **search**: Forces search mode - quick document lookup
* - **chat**: Forces chat mode - conversation with follow-up questions
*
* The initial mode is read from the user's persisted `default_app_mode` preference.
* When search mode is unavailable (admin setting or no connectors), the mode is locked to "chat".
*/
export function AppModeProvider({ children }: AppModeProviderProps) {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { user } = useUser();
const { isSearchModeAvailable } = useSettingsContext();
const persistedMode = user?.preferences?.default_app_mode;
const [appMode, setAppModeState] = useState<AppMode>("chat");
useEffect(() => {
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) {
setAppModeState("chat");
return;
}
if (persistedMode) {
setAppModeState(persistedMode.toLowerCase() as AppMode);
}
}, [isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable, persistedMode]);
const setAppMode = useCallback(
(mode: AppMode) => {
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) return;
setAppModeState(mode);
},
[isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable]
);
return (
<AppModeContext.Provider value={{ appMode, setAppMode }}>
{children}
</AppModeContext.Provider>
);
}

View File

@@ -8,14 +8,15 @@ import {
SearchFullResponse,
} from "@/lib/search/interfaces";
import { classifyQuery, searchDocuments } from "@/ee/lib/search/svc";
import { useAppMode } from "@/providers/AppModeProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { useUser } from "@/providers/UserProvider";
import {
QueryControllerContext,
QueryClassification,
QueryControllerValue,
QueryState,
AppMode,
} from "@/providers/QueryControllerProvider";
interface QueryControllerProviderProps {
@@ -25,19 +26,53 @@ interface QueryControllerProviderProps {
export function QueryControllerProvider({
children,
}: QueryControllerProviderProps) {
const { appMode, setAppMode } = useAppMode();
const appFocus = useAppFocus();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const settings = useSettingsContext();
const { isSearchModeAvailable: searchUiEnabled } = settings;
const { user } = useUser();
// Query state
// ── Merged query state (discriminated union) ──────────────────────────
const [state, setState] = useState<QueryState>({
phase: "idle",
appMode: "chat",
});
// Persistent app-mode preference — survives phase transitions and is
// used to restore the correct mode when resetting back to idle.
const appModeRef = useRef<AppMode>("chat");
// ── App mode sync from user preferences ───────────────────────────────
const persistedMode = user?.preferences?.default_app_mode;
useEffect(() => {
let mode: AppMode = "chat";
if (isPaidEnterpriseFeaturesEnabled && searchUiEnabled && persistedMode) {
const lower = persistedMode.toLowerCase();
mode = (["auto", "search", "chat"] as const).includes(lower as AppMode)
? (lower as AppMode)
: "chat";
}
appModeRef.current = mode;
setState((prev) =>
prev.phase === "idle" ? { phase: "idle", appMode: mode } : prev
);
}, [isPaidEnterpriseFeaturesEnabled, searchUiEnabled, persistedMode]);
const setAppMode = useCallback(
(mode: AppMode) => {
if (!isPaidEnterpriseFeaturesEnabled || !searchUiEnabled) return;
setState((prev) => {
if (prev.phase !== "idle") return prev;
appModeRef.current = mode;
return { phase: "idle", appMode: mode };
});
},
[isPaidEnterpriseFeaturesEnabled, searchUiEnabled]
);
// ── Ancillary state ───────────────────────────────────────────────────
const [query, setQuery] = useState<string | null>(null);
const [classification, setClassification] =
useState<QueryClassification>(null);
const [isClassifying, setIsClassifying] = useState(false);
// Search state
const [searchResults, setSearchResults] = useState<SearchDocWithContent[]>(
[]
);
@@ -51,7 +86,7 @@ export function QueryControllerProvider({
const searchAbortRef = useRef<AbortController | null>(null);
/**
* Perform document search
* Perform document search (pure data-fetching, no phase side effects)
*/
const performSearch = useCallback(
async (searchQuery: string, filters?: BaseFilters): Promise<void> => {
@@ -85,19 +120,15 @@ export function QueryControllerProvider({
setLlmSelectedDocIds(response.llm_selected_doc_ids ?? null);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
return;
throw err;
}
setError("Document search failed. Please try again.");
setSearchResults([]);
setLlmSelectedDocIds(null);
} finally {
// After we've performed a search, we automatically switch to "search" mode.
// This is a "sticky" implementation; on purpose.
setAppMode("search");
}
},
[setAppMode]
[]
);
/**
@@ -112,8 +143,6 @@ export function QueryControllerProvider({
const controller = new AbortController();
classifyAbortRef.current = controller;
setIsClassifying(true);
try {
const response: SearchFlowClassificationResponse = await classifyQuery(
classifyQueryText,
@@ -129,8 +158,6 @@ export function QueryControllerProvider({
setError("Query classification failed. Falling back to chat.");
return "chat";
} finally {
setIsClassifying(false);
}
},
[]
@@ -148,62 +175,51 @@ export function QueryControllerProvider({
setQuery(submitQuery);
setError(null);
// 1.
// We always route through chat if we're not Enterprise Enabled.
//
// 2.
// We always route through chat if the admin has disabled the Search UI.
//
// 3.
// We only go down the classification route if we're in the "New Session" tab.
// Everywhere else, we always use the chat-flow.
//
// 4.
// If we're in the "New Session" tab and the app-mode is "Chat", we continue with the chat-flow anyways.
const currentAppMode = appModeRef.current;
// Always route through chat if:
// 1. Not Enterprise Enabled
// 2. Admin has disabled the Search UI
// 3. Not in the "New Session" tab
// 4. In "New Session" tab but app-mode is "Chat"
if (
!isPaidEnterpriseFeaturesEnabled ||
!searchUiEnabled ||
!appFocus.isNewSession() ||
appMode === "chat"
currentAppMode === "chat"
) {
setClassification("chat");
setState({ phase: "chat" });
setSearchResults([]);
setLlmSelectedDocIds(null);
onChat(submitQuery);
return;
}
if (appMode === "search") {
await performSearch(submitQuery, filters);
setClassification("search");
// Search mode: immediately show SearchUI with loading state
if (currentAppMode === "search") {
setState({ phase: "searching" });
try {
await performSearch(submitQuery, filters);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") return;
throw err;
}
setState({ phase: "search-results" });
return;
}
// # Note (@raunakab)
//
// Interestingly enough, for search, we do:
// 1. setClassification("search")
// 2. performSearch
//
// But for chat, we do:
// 1. performChat
// 2. setClassification("chat")
//
// The ChatUI has a nice loading UI, so it's fine for us to prematurely set the
// classification-state before the chat has finished loading.
//
// However, the SearchUI does not. Prematurely setting the classification-state
// will lead to a slightly ugly UI.
// Auto mode: classify first, then route
setState({ phase: "classifying" });
try {
const result = await performClassification(submitQuery);
if (result === "search") {
setState({ phase: "searching" });
await performSearch(submitQuery, filters);
setClassification("search");
setState({ phase: "search-results" });
appModeRef.current = "search";
} else {
setClassification("chat");
setState({ phase: "chat" });
setSearchResults([]);
setLlmSelectedDocIds(null);
onChat(submitQuery);
@@ -213,14 +229,13 @@ export function QueryControllerProvider({
return;
}
setClassification("chat");
setState({ phase: "chat" });
setSearchResults([]);
setLlmSelectedDocIds(null);
onChat(submitQuery);
}
},
[
appMode,
appFocus,
performClassification,
performSearch,
@@ -235,7 +250,14 @@ export function QueryControllerProvider({
const refineSearch = useCallback(
async (filters: BaseFilters): Promise<void> => {
if (!query) return;
await performSearch(query, filters);
setState({ phase: "searching" });
try {
await performSearch(query, filters);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") return;
throw err;
}
setState({ phase: "search-results" });
},
[query, performSearch]
);
@@ -254,7 +276,7 @@ export function QueryControllerProvider({
}
setQuery(null);
setClassification(null);
setState({ phase: "idle", appMode: appModeRef.current });
setSearchResults([]);
setLlmSelectedDocIds(null);
setError(null);
@@ -262,8 +284,8 @@ export function QueryControllerProvider({
const value: QueryControllerValue = useMemo(
() => ({
classification,
isClassifying,
state,
setAppMode,
searchResults,
llmSelectedDocIds,
error,
@@ -272,8 +294,8 @@ export function QueryControllerProvider({
reset,
}),
[
classification,
isClassifying,
state,
setAppMode,
searchResults,
llmSelectedDocIds,
error,
@@ -283,7 +305,7 @@ export function QueryControllerProvider({
]
);
// Sync classification state with navigation context
// Sync state with navigation context
useEffect(reset, [appFocus, reset]);
return (

View File

@@ -56,7 +56,7 @@ export default function SearchCard({
return (
<Interactive.Stateless onClick={handleClick} prominence="secondary">
<Interactive.Container heightVariant="fit">
<Interactive.Container heightVariant="fit" widthVariant="full">
<Section alignItems="start" gap={0} padding={0.25}>
{/* Title Row */}
<Section

View File

@@ -18,16 +18,17 @@ import { getTimeFilterDate, TimeFilter } from "@/lib/time";
import useTags from "@/hooks/useTags";
import { SourceIcon } from "@/components/SourceIcon";
import Text from "@/refresh-components/texts/Text";
import LineItem from "@/refresh-components/buttons/LineItem";
import { Section } from "@/layouts/general-layouts";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { SvgCheck, SvgClock, SvgTag } from "@opal/icons";
import FilterButton from "@/refresh-components/buttons/FilterButton";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useFilter from "@/hooks/useFilter";
import { LineItemButton } from "@opal/components";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { cn } from "@/lib/utils";
import { toast } from "@/hooks/useToast";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
// ============================================================================
// Types
@@ -51,22 +52,17 @@ const TIME_FILTER_OPTIONS: { value: TimeFilter; label: string }[] = [
{ value: "year", label: "Past year" },
];
// ============================================================================
// SearchResults Component (default export)
// ============================================================================
/**
* Component for displaying search results with source filter sidebar.
*/
export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
// Available tags from backend
const { tags: availableTags } = useTags();
const {
state,
searchResults: results,
llmSelectedDocIds,
error,
refineSearch: onRefineSearch,
} = useQueryController();
const prevErrorRef = useRef<string | null>(null);
// Show a toast notification when a new error occurs
@@ -197,6 +193,15 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
const showEmpty = !error && results.length === 0;
// Show a centered spinner while search is in-flight (after all hooks)
if (state.phase === "searching") {
return (
<div className="flex-1 min-h-0 w-full flex items-center justify-center">
<SimpleLoader />
</div>
);
}
return (
<div className="flex-1 min-h-0 w-full flex flex-col gap-3">
{/* ── Top row: Filters + Result count ── */}
@@ -226,18 +231,19 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
<Popover.Content align="start" width="md">
<PopoverMenu>
{TIME_FILTER_OPTIONS.map((opt) => (
<LineItem
<LineItemButton
key={opt.value}
onClick={() => {
setTimeFilter(opt.value);
setTimeFilterOpen(false);
onRefineSearch(buildFilters({ time: opt.value }));
}}
selected={timeFilter === opt.value}
state={timeFilter === opt.value ? "selected" : "empty"}
icon={timeFilter === opt.value ? SvgCheck : SvgClock}
>
{opt.label}
</LineItem>
title={opt.label}
sizePreset="main-ui"
variant="section"
/>
))}
</PopoverMenu>
</Popover.Content>
@@ -278,7 +284,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
t.tag_value === tag.tag_value
);
return (
<LineItem
<LineItemButton
key={`${tag.tag_key}=${tag.tag_value}`}
onClick={() => {
const next = isSelected
@@ -291,11 +297,12 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
setSelectedTags(next);
onRefineSearch(buildFilters({ tags: next }));
}}
selected={isSelected}
state={isSelected ? "selected" : "empty"}
icon={isSelected ? SvgCheck : SvgTag}
>
{tag.tag_value}
</LineItem>
title={tag.tag_value}
sizePreset="main-ui"
variant="section"
/>
);
})}
</PopoverMenu>
@@ -357,7 +364,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
<div className="flex-1 min-h-0 overflow-y-auto flex flex-col gap-4 px-1">
<Section gap={0.25} height="fit">
{sourcesWithMeta.map(({ source, meta, count }) => (
<LineItem
<LineItemButton
key={source}
icon={(props) => (
<SourceIcon
@@ -367,12 +374,15 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
/>
)}
onClick={() => handleSourceToggle(source)}
selected={selectedSources.includes(source)}
emphasized
state={
selectedSources.includes(source) ? "selected" : "empty"
}
title={meta.displayName}
selectVariant="select-heavy"
sizePreset="main-ui"
variant="section"
rightChildren={<Text text03>{count}</Text>}
>
{meta.displayName}
</LineItem>
/>
))}
</Section>
</div>

View File

@@ -5,6 +5,7 @@
//
// This is useful in determining what `SidebarTab` should be active, for example.
import { useMemo } from "react";
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
import { usePathname, useSearchParams } from "next/navigation";
@@ -66,31 +67,25 @@ export default function useAppFocus(): AppFocus {
const pathname = usePathname();
const searchParams = useSearchParams();
// Check if we're viewing a shared chat
if (pathname.startsWith("/app/shared/")) {
return new AppFocus("shared-chat");
}
// Check if we're on the user settings page
if (pathname.startsWith("/app/settings")) {
return new AppFocus("user-settings");
}
// Check if we're on the agents page
if (pathname.startsWith("/app/agents")) {
return new AppFocus("more-agents");
}
// Check search params for chat, agent, or project
const chatId = searchParams.get(SEARCH_PARAM_NAMES.CHAT_ID);
if (chatId) return new AppFocus({ type: "chat", id: chatId });
const agentId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
if (agentId) return new AppFocus({ type: "agent", id: agentId });
const projectId = searchParams.get(SEARCH_PARAM_NAMES.PROJECT_ID);
if (projectId) return new AppFocus({ type: "project", id: projectId });
// No search params means we're on a new session
return new AppFocus("new-session");
// Memoize on the values that determine which AppFocus is constructed.
// AppFocus is immutable, so same inputs → same instance.
return useMemo(() => {
if (pathname.startsWith("/app/shared/")) {
return new AppFocus("shared-chat");
}
if (pathname.startsWith("/app/settings")) {
return new AppFocus("user-settings");
}
if (pathname.startsWith("/app/agents")) {
return new AppFocus("more-agents");
}
if (chatId) return new AppFocus({ type: "chat", id: chatId });
if (agentId) return new AppFocus({ type: "agent", id: agentId });
if (projectId) return new AppFocus({ type: "project", id: projectId });
return new AppFocus("new-session");
}, [pathname, chatId, agentId, projectId]);
}

View File

@@ -60,7 +60,7 @@ import {
} from "@opal/icons";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
import type { AppMode } from "@/providers/QueryControllerProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
@@ -82,7 +82,7 @@ import useBrowserInfo from "@/hooks/useBrowserInfo";
*/
function Header() {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { appMode, setAppMode } = useAppMode();
const { state, setAppMode } = useQueryController();
const settings = useSettingsContext();
const { isMobile } = useScreenSize();
const { setFolded } = useAppSidebarContext();
@@ -108,7 +108,6 @@ function Header() {
useChatSessions();
const router = useRouter();
const appFocus = useAppFocus();
const { classification } = useQueryController();
const customHeaderContent =
settings?.enterpriseSettings?.custom_header_content;
@@ -117,7 +116,8 @@ function Header() {
// without this content still use.
const pageWithHeaderContent = appFocus.isChat() || appFocus.isNewSession();
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
const effectiveMode: AppMode =
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
const availableProjects = useMemo(() => {
if (!projects) return [];
@@ -323,7 +323,7 @@ function Header() {
{isPaidEnterpriseFeaturesEnabled &&
settings.isSearchModeAvailable &&
appFocus.isNewSession() &&
!classification && (
state.phase === "idle" && (
<Popover open={modePopoverOpen} onOpenChange={setModePopoverOpen}>
<Popover.Trigger asChild>
<OpenButton

View File

@@ -1,23 +0,0 @@
"use client";
import { createContext, useContext } from "react";
import { eeGated } from "@/ce";
import { AppModeProvider as EEAppModeProvider } from "@/ee/providers/AppModeProvider";
export type AppMode = "auto" | "search" | "chat";
interface AppModeContextValue {
appMode: AppMode;
setAppMode: (mode: AppMode) => void;
}
export const AppModeContext = createContext<AppModeContextValue>({
appMode: "chat",
setAppMode: () => undefined,
});
export function useAppMode(): AppModeContextValue {
return useContext(AppModeContext);
}
export const AppModeProvider = eeGated(EEAppModeProvider);

View File

@@ -24,7 +24,7 @@
* 4. **ProviderContextProvider** - LLM provider configuration
* 5. **ModalProvider** - Global modal state management
* 6. **AppSidebarProvider** - Sidebar open/closed state
* 7. **AppModeProvider** - Search/Chat mode selection
* 7. **QueryControllerProvider** - Search/Chat mode + query lifecycle
*
* ## Usage
*
@@ -40,7 +40,7 @@
* - `useSettingsContext()` - from SettingsProvider
* - `useUser()` - from UserProvider
* - `useAppBackground()` - from AppBackgroundProvider
* - `useAppMode()` - from AppModeProvider
* - `useQueryController()` - from QueryControllerProvider (includes appMode)
* - etc.
*
* @TODO(@raunakab): The providers wrapped by this component are currently
@@ -65,7 +65,6 @@ import { User } from "@/lib/types";
import { ModalProvider } from "@/components/context/ModalContext";
import { AuthTypeMetadata } from "@/lib/userSS";
import { AppSidebarProvider } from "@/providers/AppSidebarProvider";
import { AppModeProvider } from "@/providers/AppModeProvider";
import { AppBackgroundProvider } from "@/providers/AppBackgroundProvider";
import { QueryControllerProvider } from "@/providers/QueryControllerProvider";
import ToastProvider from "@/providers/ToastProvider";
@@ -96,11 +95,9 @@ export default function AppProvider({
<ProviderContextProvider>
<ModalProvider user={user}>
<AppSidebarProvider folded={!!folded}>
<AppModeProvider>
<QueryControllerProvider>
<ToastProvider>{children}</ToastProvider>
</QueryControllerProvider>
</AppModeProvider>
<QueryControllerProvider>
<ToastProvider>{children}</ToastProvider>
</QueryControllerProvider>
</AppSidebarProvider>
</ModalProvider>
</ProviderContextProvider>

View File

@@ -5,13 +5,20 @@ import { eeGated } from "@/ce";
import { QueryControllerProvider as EEQueryControllerProvider } from "@/ee/providers/QueryControllerProvider";
import { SearchDocWithContent, BaseFilters } from "@/lib/search/interfaces";
export type QueryClassification = "search" | "chat" | null;
export type AppMode = "auto" | "search" | "chat";
export type QueryState =
| { phase: "idle"; appMode: AppMode }
| { phase: "classifying" }
| { phase: "searching" }
| { phase: "search-results" }
| { phase: "chat" };
export interface QueryControllerValue {
/** Classification state: null (idle), "search", or "chat" */
classification: QueryClassification;
/** Whether or not the currently submitted query is being actively classified by the backend */
isClassifying: boolean;
/** Single state variable encoding both the query lifecycle phase and (when idle) the user's mode selection. */
state: QueryState;
/** Update the app mode. Only takes effect when idle. No-op in CE or when search is unavailable. */
setAppMode: (mode: AppMode) => void;
/** Search results (empty if chat or not yet searched) */
searchResults: SearchDocWithContent[];
/** Document IDs selected by the LLM as most relevant */
@@ -31,8 +38,8 @@ export interface QueryControllerValue {
}
export const QueryControllerContext = createContext<QueryControllerValue>({
classification: null,
isClassifying: false,
state: { phase: "idle", appMode: "chat" },
setAppMode: () => undefined,
searchResults: [],
llmSelectedDocIds: null,
error: null,

View File

@@ -72,7 +72,6 @@ import { eeGated } from "@/ce";
import EESearchUI from "@/ee/sections/SearchUI";
const SearchUI = eeGated(EESearchUI);
import { motion, AnimatePresence } from "motion/react";
import { useAppMode } from "@/providers/AppModeProvider";
interface FadeProps {
show: boolean;
@@ -129,7 +128,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
type: "success",
},
});
const { setAppMode } = useAppMode();
const searchParams = useSearchParams();
// Use SWR hooks for data fetching
@@ -485,7 +483,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
finishOnboarding,
]
);
const { submit: submitQuery, classification } = useQueryController();
const { submit: submitQuery, state, setAppMode } = useQueryController();
const defaultAppMode =
(user?.preferences?.default_app_mode?.toLowerCase() as "chat" | "search") ??
@@ -493,12 +491,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
const isNewSession = appFocus.isNewSession();
const isSearch =
state.phase === "searching" || state.phase === "search-results";
// 1. Reset the app-mode back to the user's default when navigating back to the "New Sessions" tab.
// 2. If we're navigating away from the "New Session" tab after performing a search, we reset the app-input-bar.
useEffect(() => {
if (isNewSession) setAppMode(defaultAppMode);
if (!isNewSession && classification === "search") resetInputBar();
}, [isNewSession, defaultAppMode, classification, resetInputBar, setAppMode]);
if (!isNewSession && isSearch) resetInputBar();
}, [isNewSession, defaultAppMode, isSearch, resetInputBar, setAppMode]);
const handleSearchDocumentClick = useCallback(
(doc: MinimalOnyxDocument) => setPresentingDocument(doc),
@@ -607,7 +608,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
const hasStarterMessages = (liveAgent?.starter_messages?.length ?? 0) > 0;
const isSearch = classification === "search";
const gridStyle = {
gridTemplateColumns: "1fr",
gridTemplateRows: isSearch
@@ -735,7 +735,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
<Fade
show={
(appFocus.isNewSession() || appFocus.isAgent()) &&
!classification
(state.phase === "idle" || state.phase === "classifying")
}
className="w-full flex-1 flex flex-col items-center justify-end"
>
@@ -764,7 +764,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
{/* OnboardingUI */}
{(appFocus.isNewSession() || appFocus.isAgent()) &&
!classification &&
(state.phase === "idle" ||
state.phase === "classifying") &&
(showOnboarding || !user?.personalization?.name) &&
!onboardingDismissed && (
<OnboardingFlow
@@ -799,7 +800,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
<div
className={cn(
"transition-all duration-150 ease-in-out overflow-hidden",
classification === "search" ? "h-[14px]" : "h-0"
isSearch ? "h-[14px]" : "h-0"
)}
/>
<AppInputBar

View File

@@ -19,7 +19,6 @@ import useCCPairs from "@/hooks/useCCPairs";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import { ChatState } from "@/app/app/interfaces";
import { useForcedTools } from "@/lib/hooks/useForcedTools";
import { useAppMode } from "@/providers/AppModeProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { cn, isImageFile } from "@/lib/utils";
import { Disabled } from "@opal/core";
@@ -120,7 +119,10 @@ const AppInputBar = React.memo(
const filesContentRef = useRef<HTMLDivElement>(null);
const containerRef = useRef<HTMLDivElement>(null);
const { user } = useUser();
const { isClassifying, classification } = useQueryController();
const { state } = useQueryController();
const isClassifying = state.phase === "classifying";
const isSearchActive =
state.phase === "searching" || state.phase === "search-results";
// Expose reset and focus methods to parent via ref
React.useImperativeHandle(ref, () => ({
@@ -140,12 +142,10 @@ const AppInputBar = React.memo(
setMessage(initialMessage);
}
}, [initialMessage]);
const { appMode } = useAppMode();
const appFocus = useAppFocus();
const appMode = state.phase === "idle" ? state.appMode : undefined;
const isSearchMode =
(appFocus.isNewSession() && appMode === "search") ||
classification === "search";
(appFocus.isNewSession() && appMode === "search") || isSearchActive;
const { forcedToolIds, setForcedToolIds } = useForcedTools();
const { currentMessageFiles, setCurrentMessageFiles, currentProjectId } =

View File

@@ -77,7 +77,6 @@ import { Notification, NotificationType } from "@/interfaces/settings";
import { errorHandlingFetcher } from "@/lib/fetcher";
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
import ChatSearchCommandMenu from "@/sections/sidebar/ChatSearchCommandMenu";
import { useAppMode } from "@/providers/AppModeProvider";
import { useQueryController } from "@/providers/QueryControllerProvider";
// Visible-agents = pinned-agents + current-agent (if current-agent not in pinned-agents)
@@ -206,8 +205,7 @@ const MemoizedAppSidebarInner = memo(
const combinedSettings = useSettingsContext();
const posthog = usePostHog();
const { newTenantInfo, invitationInfo } = useModalContext();
const { setAppMode } = useAppMode();
const { reset } = useQueryController();
const { setAppMode, reset } = useQueryController();
// Use SWR hooks for data fetching
const {