mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-11 02:32:43 +00:00
Compare commits
4 Commits
v3.0.0-clo
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36196373a8 | ||
|
|
533aa8eff8 | ||
|
|
ecbb267f80 | ||
|
|
66023dbb6d |
@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
@@ -369,9 +370,9 @@ def upsert_llm_provider(
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[dict],
|
||||
models: list[SyncModelEntry],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
@@ -379,7 +380,7 @@ def sync_model_configurations(
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
@@ -393,21 +394,20 @@ def sync_model_configurations(
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
if model.name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model.get("supports_image_input", False):
|
||||
if model.supports_image_input:
|
||||
supported_flows.append(LLMModelFlowType.VISION)
|
||||
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_name,
|
||||
model_name=model.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
display_name=model.get("display_name"),
|
||||
max_input_tokens=model.max_input_tokens,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
|
||||
@@ -7424,9 +7424,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
|
||||
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
|
||||
"version": "4.12.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
|
||||
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
|
||||
@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
@@ -72,6 +75,7 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
@@ -98,6 +102,34 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
source_label: str,
|
||||
) -> None:
|
||||
"""Sync fetched models to DB for the given provider.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
|
||||
"""
|
||||
try:
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=provider_name,
|
||||
models=models,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
|
||||
|
||||
|
||||
# Keys in custom_config that contain sensitive credentials
|
||||
_SENSITIVE_CONFIG_KEYS = {
|
||||
"vertex_credentials",
|
||||
@@ -963,27 +995,20 @@ def get_bedrock_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
|
||||
for r in results
|
||||
],
|
||||
source_label="Bedrock",
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -1101,27 +1126,20 @@ def get_ollama_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Ollama",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1210,27 +1228,20 @@ def get_openrouter_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenRouter",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1324,26 +1335,119 @@ def get_lm_studio_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LM Studio",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/litellm/available-models")
|
||||
def get_litellm_available_models(
|
||||
request: LitellmModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Litellm endpoint",
|
||||
)
|
||||
|
||||
results: list[LitellmFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
model_name=model_details.id,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Litellm model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Litellm",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.model_name,
|
||||
display_name=r.model_name,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LiteLLM",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. "
|
||||
"Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
|
||||
@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
class SyncModelEntry(BaseModel):
|
||||
"""Typed model for syncing fetched models to the DB."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool = False
|
||||
|
||||
|
||||
class LitellmModelsRequest(BaseModel):
|
||||
api_key: str
|
||||
api_base: str
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LitellmModelDetails(BaseModel):
|
||||
"""Response model for Litellm proxy /api/v1/models endpoint"""
|
||||
|
||||
id: str # Model ID (e.g. "gpt-4o")
|
||||
object: str # "model"
|
||||
created: int # Unix timestamp in seconds
|
||||
owned_by: str # Provider name (e.g. "openai")
|
||||
|
||||
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
@@ -406,7 +406,7 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.4.3
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
|
||||
|
||||
class TestSyncModelConfigurations:
|
||||
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4",
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Existing - should be skipped
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o", # New - should be inserted
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Existing - should be skipped
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o", # New - should be inserted
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Already exists
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Already exists
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
|
||||
sync_model_configurations(
|
||||
db_session=mock_session,
|
||||
provider_name="nonexistent",
|
||||
models=[{"name": "model", "display_name": "Model"}],
|
||||
models=[SyncModelEntry(name="model", display_name="Model")],
|
||||
)
|
||||
|
||||
def test_handles_missing_optional_fields(self) -> None:
|
||||
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
|
||||
with patch(
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
# Model with only required fields
|
||||
# Model with only required fields (max_input_tokens and supports_image_input default)
|
||||
models = [
|
||||
{
|
||||
"name": "model-1",
|
||||
# No display_name, max_input_tokens, or supports_image_input
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="model-1",
|
||||
display_name="Model 1",
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Tests for LLM model fetch endpoints.
|
||||
|
||||
These tests verify the full request/response flow for fetching models
|
||||
from dynamic providers (Ollama, OpenRouter), including the
|
||||
from dynamic providers (Ollama, OpenRouter, Litellm), including the
|
||||
sync-to-DB behavior when provider_name is specified.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
@@ -614,3 +618,283 @@ class TestGetLMStudioAvailableModels:
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetLitellmAvailableModels:
|
||||
"""Tests for the Litellm proxy model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_response(self) -> dict:
|
||||
"""Mock response from Litellm /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1700000001,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "gemini-pro",
|
||||
"object": "model",
|
||||
"created": 1700000002,
|
||||
"owned_by": "google",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted model list."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
|
||||
|
||||
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that provider_name and model_name are correctly extracted."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
gpt = next(r for r in results if r.model_name == "gpt-4o")
|
||||
assert gpt.provider_name == "openai"
|
||||
|
||||
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
|
||||
assert claude.provider_name == "anthropic"
|
||||
|
||||
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that results are alphabetically sorted by model_name."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
model_names = [r.model_name for r in results]
|
||||
assert model_names == sorted(model_names, key=str.lower)
|
||||
|
||||
def test_empty_data_raises_onyx_error(self) -> None:
|
||||
"""Test that empty model list raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No models found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_missing_data_key_raises_onyx_error(self) -> None:
|
||||
"""Test that response without 'data' key raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_skips_unparseable_entries(self) -> None:
|
||||
"""Test that malformed model entries are skipped without failing."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_with_bad_entry = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
# Missing required fields
|
||||
{"bad_field": "bad_value"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_with_bad_entry
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].model_name == "gpt-4o"
|
||||
|
||||
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
|
||||
"""Test that OnyxError is raised when all entries fail to parse."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_all_bad = {
|
||||
"data": [
|
||||
{"bad_field": "bad_value"},
|
||||
{"another_bad": 123},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_all_bad
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No compatible models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_api_base_trailing_slash_handled(self) -> None:
|
||||
"""Test that trailing slashes in api_base are handled correctly."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_litellm_response = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000/",
|
||||
api_key="test-key",
|
||||
)
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Should call /v1/models without double slashes
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[0][0] == "http://localhost:4000/v1/models"
|
||||
|
||||
def test_connection_failure_raises_onyx_error(self) -> None:
|
||||
"""Test that connection failures are wrapped in OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
"""Test that a 401 response raises OnyxError with authentication message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="bad-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Authentication failed"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_404_raises_not_found_error(self) -> None:
|
||||
"""Test that a 404 response raises OnyxError with endpoint not found message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
@@ -153,7 +153,7 @@ dev = [
|
||||
"pytest-repeat==0.9.4",
|
||||
"pytest-xdist==3.8.0",
|
||||
"pytest==8.3.5",
|
||||
"release-tag==0.4.3",
|
||||
"release-tag==0.5.2",
|
||||
"reorder-python-imports-black==3.14.0",
|
||||
"ruff==0.12.0",
|
||||
"types-beautifulsoup4==4.12.0.3",
|
||||
|
||||
18
uv.lock
generated
18
uv.lock
generated
@@ -4485,7 +4485,7 @@ requires-dist = [
|
||||
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
@@ -6338,16 +6338,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "release-tag"
|
||||
version = "0.4.3"
|
||||
version = "0.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -11,7 +11,7 @@ import { Button } from "@opal/components";
|
||||
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
|
||||
import type { AppMode } from "@/providers/QueryControllerProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
@@ -58,15 +58,15 @@ const footerMarkdownComponents = {
|
||||
*/
|
||||
export default function NRFChrome() {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const { state, setAppMode } = useQueryController();
|
||||
const settings = useSettingsContext();
|
||||
const { isMobile } = useScreenSize();
|
||||
const { setFolded } = useAppSidebarContext();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
const [modePopoverOpen, setModePopoverOpen] = useState(false);
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
const effectiveMode: AppMode =
|
||||
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
|
||||
|
||||
const customFooterContent =
|
||||
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
|
||||
@@ -78,7 +78,7 @@ export default function NRFChrome() {
|
||||
isPaidEnterpriseFeaturesEnabled &&
|
||||
settings.isSearchModeAvailable &&
|
||||
appFocus.isNewSession() &&
|
||||
!classification;
|
||||
state.phase === "idle";
|
||||
|
||||
const showHeader = isMobile || showModeToggle;
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
const isStreaming = currentChatState === "streaming";
|
||||
|
||||
// Query controller for search/chat classification (EE feature)
|
||||
const { submit: submitQuery, classification } = useQueryController();
|
||||
const { submit: submitQuery, state } = useQueryController();
|
||||
|
||||
// Determine if retrieval (search) is enabled based on the agent
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
@@ -186,7 +186,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
}, [liveAgent]);
|
||||
|
||||
// Check if we're in search mode
|
||||
const isSearch = classification === "search";
|
||||
const isSearch =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
|
||||
// Anchor for scroll positioning (matches ChatPage pattern)
|
||||
const anchorMessage = messageHistory.at(-2) ?? messageHistory[0];
|
||||
@@ -317,7 +318,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
};
|
||||
|
||||
// Use submitQuery which will classify the query and either:
|
||||
// - Route to search (sets classification to "search" and shows SearchUI)
|
||||
// - Route to search (sets phase to "searching"/"search-results" and shows SearchUI)
|
||||
// - Route to chat (calls onChat callback)
|
||||
await submitQuery(submittedMessage, onChat);
|
||||
},
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useCallback, useEffect } from "react";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { AppModeContext, AppMode } from "@/providers/AppModeProvider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
|
||||
export interface AppModeProviderProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider for application mode (Search/Chat).
|
||||
*
|
||||
* This controls how user queries are handled:
|
||||
* - **search**: Forces search mode - quick document lookup
|
||||
* - **chat**: Forces chat mode - conversation with follow-up questions
|
||||
*
|
||||
* The initial mode is read from the user's persisted `default_app_mode` preference.
|
||||
* When search mode is unavailable (admin setting or no connectors), the mode is locked to "chat".
|
||||
*/
|
||||
export function AppModeProvider({ children }: AppModeProviderProps) {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { user } = useUser();
|
||||
const { isSearchModeAvailable } = useSettingsContext();
|
||||
|
||||
const persistedMode = user?.preferences?.default_app_mode;
|
||||
const [appMode, setAppModeState] = useState<AppMode>("chat");
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) {
|
||||
setAppModeState("chat");
|
||||
return;
|
||||
}
|
||||
|
||||
if (persistedMode) {
|
||||
setAppModeState(persistedMode.toLowerCase() as AppMode);
|
||||
}
|
||||
}, [isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable, persistedMode]);
|
||||
|
||||
const setAppMode = useCallback(
|
||||
(mode: AppMode) => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) return;
|
||||
setAppModeState(mode);
|
||||
},
|
||||
[isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable]
|
||||
);
|
||||
|
||||
return (
|
||||
<AppModeContext.Provider value={{ appMode, setAppMode }}>
|
||||
{children}
|
||||
</AppModeContext.Provider>
|
||||
);
|
||||
}
|
||||
@@ -8,14 +8,15 @@ import {
|
||||
SearchFullResponse,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { classifyQuery, searchDocuments } from "@/ee/lib/search/svc";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import {
|
||||
QueryControllerContext,
|
||||
QueryClassification,
|
||||
QueryControllerValue,
|
||||
QueryState,
|
||||
AppMode,
|
||||
} from "@/providers/QueryControllerProvider";
|
||||
|
||||
interface QueryControllerProviderProps {
|
||||
@@ -25,19 +26,53 @@ interface QueryControllerProviderProps {
|
||||
export function QueryControllerProvider({
|
||||
children,
|
||||
}: QueryControllerProviderProps) {
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const appFocus = useAppFocus();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const settings = useSettingsContext();
|
||||
const { isSearchModeAvailable: searchUiEnabled } = settings;
|
||||
const { user } = useUser();
|
||||
|
||||
// Query state
|
||||
// ── Merged query state (discriminated union) ──────────────────────────
|
||||
const [state, setState] = useState<QueryState>({
|
||||
phase: "idle",
|
||||
appMode: "chat",
|
||||
});
|
||||
|
||||
// Persistent app-mode preference — survives phase transitions and is
|
||||
// used to restore the correct mode when resetting back to idle.
|
||||
const appModeRef = useRef<AppMode>("chat");
|
||||
|
||||
// ── App mode sync from user preferences ───────────────────────────────
|
||||
const persistedMode = user?.preferences?.default_app_mode;
|
||||
|
||||
useEffect(() => {
|
||||
let mode: AppMode = "chat";
|
||||
if (isPaidEnterpriseFeaturesEnabled && searchUiEnabled && persistedMode) {
|
||||
const lower = persistedMode.toLowerCase();
|
||||
mode = (["auto", "search", "chat"] as const).includes(lower as AppMode)
|
||||
? (lower as AppMode)
|
||||
: "chat";
|
||||
}
|
||||
appModeRef.current = mode;
|
||||
setState((prev) =>
|
||||
prev.phase === "idle" ? { phase: "idle", appMode: mode } : prev
|
||||
);
|
||||
}, [isPaidEnterpriseFeaturesEnabled, searchUiEnabled, persistedMode]);
|
||||
|
||||
const setAppMode = useCallback(
|
||||
(mode: AppMode) => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !searchUiEnabled) return;
|
||||
setState((prev) => {
|
||||
if (prev.phase !== "idle") return prev;
|
||||
appModeRef.current = mode;
|
||||
return { phase: "idle", appMode: mode };
|
||||
});
|
||||
},
|
||||
[isPaidEnterpriseFeaturesEnabled, searchUiEnabled]
|
||||
);
|
||||
|
||||
// ── Ancillary state ───────────────────────────────────────────────────
|
||||
const [query, setQuery] = useState<string | null>(null);
|
||||
const [classification, setClassification] =
|
||||
useState<QueryClassification>(null);
|
||||
const [isClassifying, setIsClassifying] = useState(false);
|
||||
|
||||
// Search state
|
||||
const [searchResults, setSearchResults] = useState<SearchDocWithContent[]>(
|
||||
[]
|
||||
);
|
||||
@@ -51,7 +86,7 @@ export function QueryControllerProvider({
|
||||
const searchAbortRef = useRef<AbortController | null>(null);
|
||||
|
||||
/**
|
||||
* Perform document search
|
||||
* Perform document search (pure data-fetching, no phase side effects)
|
||||
*/
|
||||
const performSearch = useCallback(
|
||||
async (searchQuery: string, filters?: BaseFilters): Promise<void> => {
|
||||
@@ -85,19 +120,15 @@ export function QueryControllerProvider({
|
||||
setLlmSelectedDocIds(response.llm_selected_doc_ids ?? null);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
return;
|
||||
throw err;
|
||||
}
|
||||
|
||||
setError("Document search failed. Please try again.");
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
} finally {
|
||||
// After we've performed a search, we automatically switch to "search" mode.
|
||||
// This is a "sticky" implementation; on purpose.
|
||||
setAppMode("search");
|
||||
}
|
||||
},
|
||||
[setAppMode]
|
||||
[]
|
||||
);
|
||||
|
||||
/**
|
||||
@@ -112,8 +143,6 @@ export function QueryControllerProvider({
|
||||
const controller = new AbortController();
|
||||
classifyAbortRef.current = controller;
|
||||
|
||||
setIsClassifying(true);
|
||||
|
||||
try {
|
||||
const response: SearchFlowClassificationResponse = await classifyQuery(
|
||||
classifyQueryText,
|
||||
@@ -129,8 +158,6 @@ export function QueryControllerProvider({
|
||||
|
||||
setError("Query classification failed. Falling back to chat.");
|
||||
return "chat";
|
||||
} finally {
|
||||
setIsClassifying(false);
|
||||
}
|
||||
},
|
||||
[]
|
||||
@@ -148,62 +175,51 @@ export function QueryControllerProvider({
|
||||
setQuery(submitQuery);
|
||||
setError(null);
|
||||
|
||||
// 1.
|
||||
// We always route through chat if we're not Enterprise Enabled.
|
||||
//
|
||||
// 2.
|
||||
// We always route through chat if the admin has disabled the Search UI.
|
||||
//
|
||||
// 3.
|
||||
// We only go down the classification route if we're in the "New Session" tab.
|
||||
// Everywhere else, we always use the chat-flow.
|
||||
//
|
||||
// 4.
|
||||
// If we're in the "New Session" tab and the app-mode is "Chat", we continue with the chat-flow anyways.
|
||||
const currentAppMode = appModeRef.current;
|
||||
|
||||
// Always route through chat if:
|
||||
// 1. Not Enterprise Enabled
|
||||
// 2. Admin has disabled the Search UI
|
||||
// 3. Not in the "New Session" tab
|
||||
// 4. In "New Session" tab but app-mode is "Chat"
|
||||
if (
|
||||
!isPaidEnterpriseFeaturesEnabled ||
|
||||
!searchUiEnabled ||
|
||||
!appFocus.isNewSession() ||
|
||||
appMode === "chat"
|
||||
currentAppMode === "chat"
|
||||
) {
|
||||
setClassification("chat");
|
||||
setState({ phase: "chat" });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
onChat(submitQuery);
|
||||
return;
|
||||
}
|
||||
|
||||
if (appMode === "search") {
|
||||
await performSearch(submitQuery, filters);
|
||||
setClassification("search");
|
||||
// Search mode: immediately show SearchUI with loading state
|
||||
if (currentAppMode === "search") {
|
||||
setState({ phase: "searching" });
|
||||
try {
|
||||
await performSearch(submitQuery, filters);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
throw err;
|
||||
}
|
||||
setState({ phase: "search-results" });
|
||||
return;
|
||||
}
|
||||
|
||||
// # Note (@raunakab)
|
||||
//
|
||||
// Interestingly enough, for search, we do:
|
||||
// 1. setClassification("search")
|
||||
// 2. performSearch
|
||||
//
|
||||
// But for chat, we do:
|
||||
// 1. performChat
|
||||
// 2. setClassification("chat")
|
||||
//
|
||||
// The ChatUI has a nice loading UI, so it's fine for us to prematurely set the
|
||||
// classification-state before the chat has finished loading.
|
||||
//
|
||||
// However, the SearchUI does not. Prematurely setting the classification-state
|
||||
// will lead to a slightly ugly UI.
|
||||
|
||||
// Auto mode: classify first, then route
|
||||
setState({ phase: "classifying" });
|
||||
try {
|
||||
const result = await performClassification(submitQuery);
|
||||
|
||||
if (result === "search") {
|
||||
setState({ phase: "searching" });
|
||||
await performSearch(submitQuery, filters);
|
||||
setClassification("search");
|
||||
setState({ phase: "search-results" });
|
||||
appModeRef.current = "search";
|
||||
} else {
|
||||
setClassification("chat");
|
||||
setState({ phase: "chat" });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
onChat(submitQuery);
|
||||
@@ -213,14 +229,13 @@ export function QueryControllerProvider({
|
||||
return;
|
||||
}
|
||||
|
||||
setClassification("chat");
|
||||
setState({ phase: "chat" });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
onChat(submitQuery);
|
||||
}
|
||||
},
|
||||
[
|
||||
appMode,
|
||||
appFocus,
|
||||
performClassification,
|
||||
performSearch,
|
||||
@@ -235,7 +250,14 @@ export function QueryControllerProvider({
|
||||
const refineSearch = useCallback(
|
||||
async (filters: BaseFilters): Promise<void> => {
|
||||
if (!query) return;
|
||||
await performSearch(query, filters);
|
||||
setState({ phase: "searching" });
|
||||
try {
|
||||
await performSearch(query, filters);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
throw err;
|
||||
}
|
||||
setState({ phase: "search-results" });
|
||||
},
|
||||
[query, performSearch]
|
||||
);
|
||||
@@ -254,7 +276,7 @@ export function QueryControllerProvider({
|
||||
}
|
||||
|
||||
setQuery(null);
|
||||
setClassification(null);
|
||||
setState({ phase: "idle", appMode: appModeRef.current });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
setError(null);
|
||||
@@ -262,8 +284,8 @@ export function QueryControllerProvider({
|
||||
|
||||
const value: QueryControllerValue = useMemo(
|
||||
() => ({
|
||||
classification,
|
||||
isClassifying,
|
||||
state,
|
||||
setAppMode,
|
||||
searchResults,
|
||||
llmSelectedDocIds,
|
||||
error,
|
||||
@@ -272,8 +294,8 @@ export function QueryControllerProvider({
|
||||
reset,
|
||||
}),
|
||||
[
|
||||
classification,
|
||||
isClassifying,
|
||||
state,
|
||||
setAppMode,
|
||||
searchResults,
|
||||
llmSelectedDocIds,
|
||||
error,
|
||||
@@ -283,7 +305,7 @@ export function QueryControllerProvider({
|
||||
]
|
||||
);
|
||||
|
||||
// Sync classification state with navigation context
|
||||
// Sync state with navigation context
|
||||
useEffect(reset, [appFocus, reset]);
|
||||
|
||||
return (
|
||||
|
||||
@@ -56,7 +56,7 @@ export default function SearchCard({
|
||||
|
||||
return (
|
||||
<Interactive.Stateless onClick={handleClick} prominence="secondary">
|
||||
<Interactive.Container heightVariant="fit">
|
||||
<Interactive.Container heightVariant="fit" widthVariant="full">
|
||||
<Section alignItems="start" gap={0} padding={0.25}>
|
||||
{/* Title Row */}
|
||||
<Section
|
||||
|
||||
@@ -18,16 +18,17 @@ import { getTimeFilterDate, TimeFilter } from "@/lib/time";
|
||||
import useTags from "@/hooks/useTags";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { SvgCheck, SvgClock, SvgTag } from "@opal/icons";
|
||||
import FilterButton from "@/refresh-components/buttons/FilterButton";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import { LineItemButton } from "@opal/components";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
@@ -51,22 +52,17 @@ const TIME_FILTER_OPTIONS: { value: TimeFilter; label: string }[] = [
|
||||
{ value: "year", label: "Past year" },
|
||||
];
|
||||
|
||||
// ============================================================================
|
||||
// SearchResults Component (default export)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Component for displaying search results with source filter sidebar.
|
||||
*/
|
||||
export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
// Available tags from backend
|
||||
const { tags: availableTags } = useTags();
|
||||
const {
|
||||
state,
|
||||
searchResults: results,
|
||||
llmSelectedDocIds,
|
||||
error,
|
||||
refineSearch: onRefineSearch,
|
||||
} = useQueryController();
|
||||
|
||||
const prevErrorRef = useRef<string | null>(null);
|
||||
|
||||
// Show a toast notification when a new error occurs
|
||||
@@ -197,6 +193,15 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
|
||||
const showEmpty = !error && results.length === 0;
|
||||
|
||||
// Show a centered spinner while search is in-flight (after all hooks)
|
||||
if (state.phase === "searching") {
|
||||
return (
|
||||
<div className="flex-1 min-h-0 w-full flex items-center justify-center">
|
||||
<SimpleLoader />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex-1 min-h-0 w-full flex flex-col gap-3">
|
||||
{/* ── Top row: Filters + Result count ── */}
|
||||
@@ -226,18 +231,19 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
<Popover.Content align="start" width="md">
|
||||
<PopoverMenu>
|
||||
{TIME_FILTER_OPTIONS.map((opt) => (
|
||||
<LineItem
|
||||
<LineItemButton
|
||||
key={opt.value}
|
||||
onClick={() => {
|
||||
setTimeFilter(opt.value);
|
||||
setTimeFilterOpen(false);
|
||||
onRefineSearch(buildFilters({ time: opt.value }));
|
||||
}}
|
||||
selected={timeFilter === opt.value}
|
||||
state={timeFilter === opt.value ? "selected" : "empty"}
|
||||
icon={timeFilter === opt.value ? SvgCheck : SvgClock}
|
||||
>
|
||||
{opt.label}
|
||||
</LineItem>
|
||||
title={opt.label}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
))}
|
||||
</PopoverMenu>
|
||||
</Popover.Content>
|
||||
@@ -278,7 +284,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
t.tag_value === tag.tag_value
|
||||
);
|
||||
return (
|
||||
<LineItem
|
||||
<LineItemButton
|
||||
key={`${tag.tag_key}=${tag.tag_value}`}
|
||||
onClick={() => {
|
||||
const next = isSelected
|
||||
@@ -291,11 +297,12 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
setSelectedTags(next);
|
||||
onRefineSearch(buildFilters({ tags: next }));
|
||||
}}
|
||||
selected={isSelected}
|
||||
state={isSelected ? "selected" : "empty"}
|
||||
icon={isSelected ? SvgCheck : SvgTag}
|
||||
>
|
||||
{tag.tag_value}
|
||||
</LineItem>
|
||||
title={tag.tag_value}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</PopoverMenu>
|
||||
@@ -357,7 +364,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
<div className="flex-1 min-h-0 overflow-y-auto flex flex-col gap-4 px-1">
|
||||
<Section gap={0.25} height="fit">
|
||||
{sourcesWithMeta.map(({ source, meta, count }) => (
|
||||
<LineItem
|
||||
<LineItemButton
|
||||
key={source}
|
||||
icon={(props) => (
|
||||
<SourceIcon
|
||||
@@ -367,12 +374,15 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
/>
|
||||
)}
|
||||
onClick={() => handleSourceToggle(source)}
|
||||
selected={selectedSources.includes(source)}
|
||||
emphasized
|
||||
state={
|
||||
selectedSources.includes(source) ? "selected" : "empty"
|
||||
}
|
||||
title={meta.displayName}
|
||||
selectVariant="select-heavy"
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={<Text text03>{count}</Text>}
|
||||
>
|
||||
{meta.displayName}
|
||||
</LineItem>
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
</div>
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
//
|
||||
// This is useful in determining what `SidebarTab` should be active, for example.
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
|
||||
@@ -66,31 +67,25 @@ export default function useAppFocus(): AppFocus {
|
||||
const pathname = usePathname();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
// Check if we're viewing a shared chat
|
||||
if (pathname.startsWith("/app/shared/")) {
|
||||
return new AppFocus("shared-chat");
|
||||
}
|
||||
|
||||
// Check if we're on the user settings page
|
||||
if (pathname.startsWith("/app/settings")) {
|
||||
return new AppFocus("user-settings");
|
||||
}
|
||||
|
||||
// Check if we're on the agents page
|
||||
if (pathname.startsWith("/app/agents")) {
|
||||
return new AppFocus("more-agents");
|
||||
}
|
||||
|
||||
// Check search params for chat, agent, or project
|
||||
const chatId = searchParams.get(SEARCH_PARAM_NAMES.CHAT_ID);
|
||||
if (chatId) return new AppFocus({ type: "chat", id: chatId });
|
||||
|
||||
const agentId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
|
||||
if (agentId) return new AppFocus({ type: "agent", id: agentId });
|
||||
|
||||
const projectId = searchParams.get(SEARCH_PARAM_NAMES.PROJECT_ID);
|
||||
if (projectId) return new AppFocus({ type: "project", id: projectId });
|
||||
|
||||
// No search params means we're on a new session
|
||||
return new AppFocus("new-session");
|
||||
// Memoize on the values that determine which AppFocus is constructed.
|
||||
// AppFocus is immutable, so same inputs → same instance.
|
||||
return useMemo(() => {
|
||||
if (pathname.startsWith("/app/shared/")) {
|
||||
return new AppFocus("shared-chat");
|
||||
}
|
||||
if (pathname.startsWith("/app/settings")) {
|
||||
return new AppFocus("user-settings");
|
||||
}
|
||||
if (pathname.startsWith("/app/agents")) {
|
||||
return new AppFocus("more-agents");
|
||||
}
|
||||
if (chatId) return new AppFocus({ type: "chat", id: chatId });
|
||||
if (agentId) return new AppFocus({ type: "agent", id: agentId });
|
||||
if (projectId) return new AppFocus({ type: "project", id: projectId });
|
||||
return new AppFocus("new-session");
|
||||
}, [pathname, chatId, agentId, projectId]);
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ import {
|
||||
} from "@opal/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
|
||||
import type { AppMode } from "@/providers/QueryControllerProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
@@ -82,7 +82,7 @@ import useBrowserInfo from "@/hooks/useBrowserInfo";
|
||||
*/
|
||||
function Header() {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const { state, setAppMode } = useQueryController();
|
||||
const settings = useSettingsContext();
|
||||
const { isMobile } = useScreenSize();
|
||||
const { setFolded } = useAppSidebarContext();
|
||||
@@ -108,7 +108,6 @@ function Header() {
|
||||
useChatSessions();
|
||||
const router = useRouter();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
|
||||
const customHeaderContent =
|
||||
settings?.enterpriseSettings?.custom_header_content;
|
||||
@@ -117,7 +116,8 @@ function Header() {
|
||||
// without this content still use.
|
||||
const pageWithHeaderContent = appFocus.isChat() || appFocus.isNewSession();
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
const effectiveMode: AppMode =
|
||||
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
|
||||
|
||||
const availableProjects = useMemo(() => {
|
||||
if (!projects) return [];
|
||||
@@ -323,7 +323,7 @@ function Header() {
|
||||
{isPaidEnterpriseFeaturesEnabled &&
|
||||
settings.isSearchModeAvailable &&
|
||||
appFocus.isNewSession() &&
|
||||
!classification && (
|
||||
state.phase === "idle" && (
|
||||
<Popover open={modePopoverOpen} onOpenChange={setModePopoverOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<OpenButton
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { createContext, useContext } from "react";
|
||||
import { eeGated } from "@/ce";
|
||||
import { AppModeProvider as EEAppModeProvider } from "@/ee/providers/AppModeProvider";
|
||||
|
||||
export type AppMode = "auto" | "search" | "chat";
|
||||
|
||||
interface AppModeContextValue {
|
||||
appMode: AppMode;
|
||||
setAppMode: (mode: AppMode) => void;
|
||||
}
|
||||
|
||||
export const AppModeContext = createContext<AppModeContextValue>({
|
||||
appMode: "chat",
|
||||
setAppMode: () => undefined,
|
||||
});
|
||||
|
||||
export function useAppMode(): AppModeContextValue {
|
||||
return useContext(AppModeContext);
|
||||
}
|
||||
|
||||
export const AppModeProvider = eeGated(EEAppModeProvider);
|
||||
@@ -24,7 +24,7 @@
|
||||
* 4. **ProviderContextProvider** - LLM provider configuration
|
||||
* 5. **ModalProvider** - Global modal state management
|
||||
* 6. **AppSidebarProvider** - Sidebar open/closed state
|
||||
* 7. **AppModeProvider** - Search/Chat mode selection
|
||||
* 7. **QueryControllerProvider** - Search/Chat mode + query lifecycle
|
||||
*
|
||||
* ## Usage
|
||||
*
|
||||
@@ -40,7 +40,7 @@
|
||||
* - `useSettingsContext()` - from SettingsProvider
|
||||
* - `useUser()` - from UserProvider
|
||||
* - `useAppBackground()` - from AppBackgroundProvider
|
||||
* - `useAppMode()` - from AppModeProvider
|
||||
* - `useQueryController()` - from QueryControllerProvider (includes appMode)
|
||||
* - etc.
|
||||
*
|
||||
* @TODO(@raunakab): The providers wrapped by this component are currently
|
||||
@@ -65,7 +65,6 @@ import { User } from "@/lib/types";
|
||||
import { ModalProvider } from "@/components/context/ModalContext";
|
||||
import { AuthTypeMetadata } from "@/lib/userSS";
|
||||
import { AppSidebarProvider } from "@/providers/AppSidebarProvider";
|
||||
import { AppModeProvider } from "@/providers/AppModeProvider";
|
||||
import { AppBackgroundProvider } from "@/providers/AppBackgroundProvider";
|
||||
import { QueryControllerProvider } from "@/providers/QueryControllerProvider";
|
||||
import ToastProvider from "@/providers/ToastProvider";
|
||||
@@ -96,11 +95,9 @@ export default function AppProvider({
|
||||
<ProviderContextProvider>
|
||||
<ModalProvider user={user}>
|
||||
<AppSidebarProvider folded={!!folded}>
|
||||
<AppModeProvider>
|
||||
<QueryControllerProvider>
|
||||
<ToastProvider>{children}</ToastProvider>
|
||||
</QueryControllerProvider>
|
||||
</AppModeProvider>
|
||||
<QueryControllerProvider>
|
||||
<ToastProvider>{children}</ToastProvider>
|
||||
</QueryControllerProvider>
|
||||
</AppSidebarProvider>
|
||||
</ModalProvider>
|
||||
</ProviderContextProvider>
|
||||
|
||||
@@ -5,13 +5,20 @@ import { eeGated } from "@/ce";
|
||||
import { QueryControllerProvider as EEQueryControllerProvider } from "@/ee/providers/QueryControllerProvider";
|
||||
import { SearchDocWithContent, BaseFilters } from "@/lib/search/interfaces";
|
||||
|
||||
export type QueryClassification = "search" | "chat" | null;
|
||||
export type AppMode = "auto" | "search" | "chat";
|
||||
|
||||
export type QueryState =
|
||||
| { phase: "idle"; appMode: AppMode }
|
||||
| { phase: "classifying" }
|
||||
| { phase: "searching" }
|
||||
| { phase: "search-results" }
|
||||
| { phase: "chat" };
|
||||
|
||||
export interface QueryControllerValue {
|
||||
/** Classification state: null (idle), "search", or "chat" */
|
||||
classification: QueryClassification;
|
||||
/** Whether or not the currently submitted query is being actively classified by the backend */
|
||||
isClassifying: boolean;
|
||||
/** Single state variable encoding both the query lifecycle phase and (when idle) the user's mode selection. */
|
||||
state: QueryState;
|
||||
/** Update the app mode. Only takes effect when idle. No-op in CE or when search is unavailable. */
|
||||
setAppMode: (mode: AppMode) => void;
|
||||
/** Search results (empty if chat or not yet searched) */
|
||||
searchResults: SearchDocWithContent[];
|
||||
/** Document IDs selected by the LLM as most relevant */
|
||||
@@ -31,8 +38,8 @@ export interface QueryControllerValue {
|
||||
}
|
||||
|
||||
export const QueryControllerContext = createContext<QueryControllerValue>({
|
||||
classification: null,
|
||||
isClassifying: false,
|
||||
state: { phase: "idle", appMode: "chat" },
|
||||
setAppMode: () => undefined,
|
||||
searchResults: [],
|
||||
llmSelectedDocIds: null,
|
||||
error: null,
|
||||
|
||||
@@ -72,7 +72,6 @@ import { eeGated } from "@/ce";
|
||||
import EESearchUI from "@/ee/sections/SearchUI";
|
||||
const SearchUI = eeGated(EESearchUI);
|
||||
import { motion, AnimatePresence } from "motion/react";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
|
||||
interface FadeProps {
|
||||
show: boolean;
|
||||
@@ -129,7 +128,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
type: "success",
|
||||
},
|
||||
});
|
||||
const { setAppMode } = useAppMode();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
// Use SWR hooks for data fetching
|
||||
@@ -485,7 +483,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
finishOnboarding,
|
||||
]
|
||||
);
|
||||
const { submit: submitQuery, classification } = useQueryController();
|
||||
const { submit: submitQuery, state, setAppMode } = useQueryController();
|
||||
|
||||
const defaultAppMode =
|
||||
(user?.preferences?.default_app_mode?.toLowerCase() as "chat" | "search") ??
|
||||
@@ -493,12 +491,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const isNewSession = appFocus.isNewSession();
|
||||
|
||||
const isSearch =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
|
||||
// 1. Reset the app-mode back to the user's default when navigating back to the "New Sessions" tab.
|
||||
// 2. If we're navigating away from the "New Session" tab after performing a search, we reset the app-input-bar.
|
||||
useEffect(() => {
|
||||
if (isNewSession) setAppMode(defaultAppMode);
|
||||
if (!isNewSession && classification === "search") resetInputBar();
|
||||
}, [isNewSession, defaultAppMode, classification, resetInputBar, setAppMode]);
|
||||
if (!isNewSession && isSearch) resetInputBar();
|
||||
}, [isNewSession, defaultAppMode, isSearch, resetInputBar, setAppMode]);
|
||||
|
||||
const handleSearchDocumentClick = useCallback(
|
||||
(doc: MinimalOnyxDocument) => setPresentingDocument(doc),
|
||||
@@ -607,7 +608,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const hasStarterMessages = (liveAgent?.starter_messages?.length ?? 0) > 0;
|
||||
|
||||
const isSearch = classification === "search";
|
||||
const gridStyle = {
|
||||
gridTemplateColumns: "1fr",
|
||||
gridTemplateRows: isSearch
|
||||
@@ -735,7 +735,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
<Fade
|
||||
show={
|
||||
(appFocus.isNewSession() || appFocus.isAgent()) &&
|
||||
!classification
|
||||
(state.phase === "idle" || state.phase === "classifying")
|
||||
}
|
||||
className="w-full flex-1 flex flex-col items-center justify-end"
|
||||
>
|
||||
@@ -764,7 +764,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
{/* OnboardingUI */}
|
||||
{(appFocus.isNewSession() || appFocus.isAgent()) &&
|
||||
!classification &&
|
||||
(state.phase === "idle" ||
|
||||
state.phase === "classifying") &&
|
||||
(showOnboarding || !user?.personalization?.name) &&
|
||||
!onboardingDismissed && (
|
||||
<OnboardingFlow
|
||||
@@ -799,7 +800,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
<div
|
||||
className={cn(
|
||||
"transition-all duration-150 ease-in-out overflow-hidden",
|
||||
classification === "search" ? "h-[14px]" : "h-0"
|
||||
isSearch ? "h-[14px]" : "h-0"
|
||||
)}
|
||||
/>
|
||||
<AppInputBar
|
||||
|
||||
@@ -19,7 +19,6 @@ import useCCPairs from "@/hooks/useCCPairs";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { cn, isImageFile } from "@/lib/utils";
|
||||
import { Disabled } from "@opal/core";
|
||||
@@ -120,7 +119,10 @@ const AppInputBar = React.memo(
|
||||
const filesContentRef = useRef<HTMLDivElement>(null);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const { user } = useUser();
|
||||
const { isClassifying, classification } = useQueryController();
|
||||
const { state } = useQueryController();
|
||||
const isClassifying = state.phase === "classifying";
|
||||
const isSearchActive =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
|
||||
// Expose reset and focus methods to parent via ref
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
@@ -140,12 +142,10 @@ const AppInputBar = React.memo(
|
||||
setMessage(initialMessage);
|
||||
}
|
||||
}, [initialMessage]);
|
||||
|
||||
const { appMode } = useAppMode();
|
||||
const appFocus = useAppFocus();
|
||||
const appMode = state.phase === "idle" ? state.appMode : undefined;
|
||||
const isSearchMode =
|
||||
(appFocus.isNewSession() && appMode === "search") ||
|
||||
classification === "search";
|
||||
(appFocus.isNewSession() && appMode === "search") || isSearchActive;
|
||||
|
||||
const { forcedToolIds, setForcedToolIds } = useForcedTools();
|
||||
const { currentMessageFiles, setCurrentMessageFiles, currentProjectId } =
|
||||
|
||||
@@ -77,7 +77,6 @@ import { Notification, NotificationType } from "@/interfaces/settings";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
|
||||
import ChatSearchCommandMenu from "@/sections/sidebar/ChatSearchCommandMenu";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
|
||||
// Visible-agents = pinned-agents + current-agent (if current-agent not in pinned-agents)
|
||||
@@ -206,8 +205,7 @@ const MemoizedAppSidebarInner = memo(
|
||||
const combinedSettings = useSettingsContext();
|
||||
const posthog = usePostHog();
|
||||
const { newTenantInfo, invitationInfo } = useModalContext();
|
||||
const { setAppMode } = useAppMode();
|
||||
const { reset } = useQueryController();
|
||||
const { setAppMode, reset } = useQueryController();
|
||||
|
||||
// Use SWR hooks for data fetching
|
||||
const {
|
||||
|
||||
Reference in New Issue
Block a user