mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-11 02:32:43 +00:00
Compare commits
15 Commits
v3.0.0-clo
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36196373a8 | ||
|
|
533aa8eff8 | ||
|
|
ecbb267f80 | ||
|
|
66023dbb6d | ||
|
|
f97466e4de | ||
|
|
2cc8303e5f | ||
|
|
a92ff61f64 | ||
|
|
17551a907e | ||
|
|
9e42951fa4 | ||
|
|
dcb18c2411 | ||
|
|
2f628e39d3 | ||
|
|
fd200d46f8 | ||
|
|
ec7482619b | ||
|
|
9d1a357533 | ||
|
|
fbe823b551 |
2
.github/workflows/storybook-deploy.yml
vendored
2
.github/workflows/storybook-deploy.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
|
||||
- name: Deploy to Vercel (Production)
|
||||
working-directory: web
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN"
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: Deploy-Storybook
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
|
||||
_GROUP_MEMBER_PAGE_SIZE = 50
|
||||
|
||||
def _get_jira_group_members_email(
|
||||
# The GET /group/member endpoint was introduced in Jira 6.0.
|
||||
# Jira versions older than 6.0 do not have group management REST APIs at all.
|
||||
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
|
||||
|
||||
|
||||
def _fetch_group_member_page(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
start_at: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
|
||||
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
|
||||
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
|
||||
directly via the library's internal _get_json helper, following the same pattern
|
||||
as enhanced_search_ids / bulk_fetch_issues in connector.py.
|
||||
|
||||
There is an open PR to the library to switch to this endpoint since last year:
|
||||
https://github.com/pycontribs/jira/pull/2356
|
||||
so once it is merged and released, we can switch to using the library function.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
return jira_client._get_json(
|
||||
"group/member",
|
||||
params={
|
||||
"groupname": group_name,
|
||||
"includeInactiveUsers": "false",
|
||||
"startAt": start_at,
|
||||
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
|
||||
},
|
||||
)
|
||||
except JIRAError as e:
|
||||
if e.status_code == 404:
|
||||
raise RuntimeError(
|
||||
f"GET /group/member returned 404 for group '{group_name}'. "
|
||||
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
|
||||
f"If you are running a self-hosted Jira instance, please upgrade "
|
||||
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
def _get_group_member_emails(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get all member emails for a single Jira group.
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
Uses the non-deprecated GET /group/member endpoint which returns full user
|
||||
objects including accountType, so we can filter out app/customer accounts
|
||||
without making separate user() calls.
|
||||
"""
|
||||
emails: set[str] = set()
|
||||
start_at = 0
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
while True:
|
||||
try:
|
||||
page = _fetch_group_member_page(jira_client, group_name, start_at)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
members: list[dict[str, Any]] = page.get("values", [])
|
||||
for member in members:
|
||||
account_type = member.get("accountType")
|
||||
# On Jira DC < 9.0, accountType is absent; include those users.
|
||||
# On Cloud / DC 9.0+, filter to real user accounts only.
|
||||
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
email = member.get("emailAddress")
|
||||
if email:
|
||||
emails.add(email)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
logger.warning(
|
||||
f"Atlassian user {member.get('accountId', 'unknown')} "
|
||||
f"in group {group_name} has no visible email address"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
if page.get("isLast", True) or not members:
|
||||
break
|
||||
start_at += len(members)
|
||||
|
||||
return group_member_emails
|
||||
return emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str, # noqa: ARG001
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
"""Sync Jira groups and their members, yielding one group at a time.
|
||||
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
Streams group-by-group rather than accumulating all groups in memory.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
@@ -130,12 +127,26 @@ def jira_group_sync(
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
group_names = jira_client.groups()
|
||||
if not group_names:
|
||||
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_group_member_emails(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
if not member_emails:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
|
||||
yield ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
|
||||
@@ -314,6 +314,9 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
|
||||
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
|
||||
)
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
|
||||
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -258,6 +258,10 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# Track yielded hierarchy nodes by their raw_node_id (URLs) to avoid duplicates
|
||||
seen_hierarchy_node_raw_ids: set[str] = Field(default_factory=set)
|
||||
|
||||
# Track yielded document IDs to avoid processing the same document twice.
|
||||
# The Microsoft Graph delta API can return the same item on multiple pages.
|
||||
seen_document_ids: set[str] = Field(default_factory=set)
|
||||
|
||||
|
||||
class SharepointAuthMethod(Enum):
|
||||
CLIENT_SECRET = "client_secret"
|
||||
@@ -1557,6 +1561,7 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_id = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
@@ -2137,6 +2142,14 @@ class SharepointConnector(
|
||||
item_count = 0
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
|
||||
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
|
||||
logger.debug(
|
||||
f"Skipping duplicate document {driveitem.id} "
|
||||
f"({driveitem.name})"
|
||||
)
|
||||
continue
|
||||
|
||||
driveitem_extension = get_file_ext(driveitem.name)
|
||||
if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
logger.warning(
|
||||
@@ -2189,11 +2202,13 @@ class SharepointConnector(
|
||||
|
||||
if isinstance(doc_or_failure, Document):
|
||||
if doc_or_failure.sections:
|
||||
checkpoint.seen_document_ids.add(doc_or_failure.id)
|
||||
yield doc_or_failure
|
||||
elif should_yield_if_empty:
|
||||
doc_or_failure.sections = [
|
||||
TextSection(link=driveitem.web_url, text="")
|
||||
]
|
||||
checkpoint.seen_document_ids.add(doc_or_failure.id)
|
||||
yield doc_or_failure
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
@@ -369,9 +370,9 @@ def upsert_llm_provider(
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[dict],
|
||||
models: list[SyncModelEntry],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
@@ -379,7 +380,7 @@ def sync_model_configurations(
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
@@ -393,21 +394,20 @@ def sync_model_configurations(
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
if model.name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model.get("supports_image_input", False):
|
||||
if model.supports_image_input:
|
||||
supported_flows.append(LLMModelFlowType.VISION)
|
||||
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_name,
|
||||
model_name=model.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
display_name=model.get("display_name"),
|
||||
max_input_tokens=model.max_input_tokens,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
|
||||
@@ -163,6 +163,8 @@ class _EncryptedBase(TypeDecorator):
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
# Must redeclare cache_ok in this child class since we explicitly redeclare _is_json
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
@@ -189,6 +191,7 @@ class EncryptedString(_EncryptedBase):
|
||||
|
||||
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
cache_ok = True
|
||||
_is_json: bool = True
|
||||
|
||||
def process_bind_param(
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
# Default value for the maximum number of tokens a chunk can hold, if none is
|
||||
# specified when creating an index.
|
||||
from onyx.configs.app_configs import (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_MAX_CHUNK_SIZE = 512
|
||||
|
||||
# Size of the dynamic list used to consider elements during kNN graph creation.
|
||||
@@ -10,27 +15,43 @@ EF_CONSTRUCTION = 256
|
||||
# quality but increase memory footprint. Values typically range between 12 - 48.
|
||||
M = 32 # Set relatively high for better accuracy.
|
||||
|
||||
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
|
||||
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
|
||||
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
|
||||
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
|
||||
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
|
||||
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
|
||||
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
|
||||
# a miss at the reranking step).
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
|
||||
# When performing hybrid search, we need to consider more candidates than the
|
||||
# number of results to be returned. This is because the scoring is hybrid and
|
||||
# the results are reordered due to the hybrid scoring. Higher = more candidates
|
||||
# for hybrid fusion = better retrieval accuracy, but results in more computation
|
||||
# per query. Imagine a simple case with a single keyword query and a single
|
||||
# vector query and we want 10 final docs. If we only fetch 10 candidates from
|
||||
# each of keyword and vector, they would have to have perfect overlap to get a
|
||||
# good hybrid ranking for the 10 results. If we fetch 1000 candidates from each,
|
||||
# we have a much higher chance of all 10 of the final desired docs showing up
|
||||
# and getting scored. In worse situations, the final 10 docs don't even show up
|
||||
# as the final 10 (worse than just a miss at the reranking step).
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
|
||||
else 750
|
||||
)
|
||||
|
||||
# Number of vectors to examine for top k neighbors for the HNSW method.
|
||||
# Number of vectors to examine to decide the top k neighbors for the HNSW
|
||||
# method.
|
||||
# NOTE: "When creating a search query, you must specify k. If you provide both k
|
||||
# and ef_search, then the larger value is passed to the engine. If ef_search is
|
||||
# larger than k, you can provide the size parameter to limit the final number of
|
||||
# results to k." from
|
||||
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
|
||||
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
|
||||
# rather than an independent scoring component.
|
||||
# Since the titles are included in the contents, the embedding matches are
|
||||
# heavily downweighted as they act as a boost rather than an independent scoring
|
||||
# component.
|
||||
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
|
||||
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
|
||||
# Single keyword weight for both title and content (merged from former title
|
||||
# keyword + content keyword).
|
||||
SEARCH_KEYWORD_WEIGHT = 0.45
|
||||
|
||||
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
|
||||
# NOTE: It is critical that the order of these weights matches the order of the
|
||||
# sub-queries in the hybrid search.
|
||||
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
|
||||
SEARCH_TITLE_VECTOR_WEIGHT,
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT,
|
||||
|
||||
@@ -433,12 +433,16 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
hidden=fields.hidden if fields else None,
|
||||
project_ids=(
|
||||
set(user_fields.user_projects)
|
||||
if user_fields and user_fields.user_projects
|
||||
# NOTE: Empty user_projects is semantically different from None
|
||||
# user_projects.
|
||||
if user_fields and user_fields.user_projects is not None
|
||||
else None
|
||||
),
|
||||
persona_ids=(
|
||||
set(user_fields.personas)
|
||||
if user_fields and user_fields.personas
|
||||
# NOTE: Empty personas is semantically different from None
|
||||
# personas.
|
||||
if user_fields and user_fields.personas is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -255,8 +255,12 @@ class DocumentQuery:
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
# TODO(andrei, yuhong): We can tune this more dynamically based on
|
||||
# num_hits.
|
||||
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector
|
||||
query_text, query_vector, vector_candidates=max_results_per_subquery
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
@@ -285,13 +289,16 @@ class DocumentQuery:
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Max results per subquery per shard before aggregation. Ensures keyword and vector
|
||||
# subqueries contribute equally to the candidate pool for hybrid fusion.
|
||||
# Max results per subquery per shard before aggregation. Ensures
|
||||
# keyword and vector subqueries contribute equally to the
|
||||
# candidate pool for hybrid fusion.
|
||||
# Sources:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
|
||||
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
|
||||
"pagination_depth": max_results_per_subquery,
|
||||
# Applied to all the sub-queries independently (this avoids
|
||||
# subqueries having a lot of results thrown out during
|
||||
# aggregation).
|
||||
# Sources:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
|
||||
@@ -374,9 +381,10 @@ class DocumentQuery:
|
||||
def _get_hybrid_search_subqueries(
|
||||
query_text: str,
|
||||
query_vector: list[float],
|
||||
# The default number of neighbors to consider for knn vector similarity search.
|
||||
# This is higher than the number of results because the scoring is hybrid.
|
||||
# for a detailed breakdown, see where the default value is set.
|
||||
# The default number of neighbors to consider for knn vector similarity
|
||||
# search. This is higher than the number of results because the scoring
|
||||
# is hybrid. For a detailed breakdown, see where the default value is
|
||||
# set.
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns subqueries for hybrid search.
|
||||
@@ -400,20 +408,27 @@ class DocumentQuery:
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
NOTE: Each query is independent during the search phase, there is no backfilling of scores for missing query components.
|
||||
What this means is that if a document was a good vector match but did not show up for keyword, it gets a score of 0 for
|
||||
the keyword component of the hybrid scoring. This is not as bad as just disregarding a score though as there is
|
||||
normalization applied after. So really it is "increasing" the missing score compared to if it was included and the range
|
||||
was renormalized. This does however mean that between docs that have high scores for say the vector field, the keyword
|
||||
scores between them are completely ignored unless they also showed up in the keyword query as a reasonably high match.
|
||||
TLDR, this is a bit of unique funky behavior but it seems ok.
|
||||
NOTE: Each query is independent during the search phase, there is no
|
||||
backfilling of scores for missing query components. What this means is
|
||||
that if a document was a good vector match but did not show up for
|
||||
keyword, it gets a score of 0 for the keyword component of the hybrid
|
||||
scoring. This is not as bad as just disregarding a score though as there
|
||||
is normalization applied after. So really it is "increasing" the missing
|
||||
score compared to if it was included and the range was renormalized.
|
||||
This does however mean that between docs that have high scores for say
|
||||
the vector field, the keyword scores between them are completely ignored
|
||||
unless they also showed up in the keyword query as a reasonably high
|
||||
match. TLDR, this is a bit of unique funky behavior but it seems ok.
|
||||
|
||||
NOTE: Options considered and rejected:
|
||||
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
|
||||
and very low number of meaningful keywords (and a low ratio of keywords).
|
||||
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
|
||||
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
|
||||
less performant so not really any reason to do it.
|
||||
- minimum_should_match: Since it's hybrid search and users often provide
|
||||
semantic queries, there is often a lot of terms, and very low number
|
||||
of meaningful keywords (and a low ratio of keywords).
|
||||
- fuzziness AUTO: Typo tolerance (0/1/2 edit distance by term length).
|
||||
It's mostly for typos as the analyzer ("english" by default) already
|
||||
does some stemming and tokenization. In testing datasets, this makes
|
||||
recall slightly worse. It also is less performant so not really any
|
||||
reason to do it.
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -723,14 +738,13 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
# Knowledge scope: explicit knowledge attachments restrict what an
|
||||
# assistant can see. When none are set the assistant searches
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
# project_id / persona_id are additive: they make overflowing user files
|
||||
# findable but must NOT trigger the restriction on their own (an agent
|
||||
# with no explicit knowledge should search everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
@@ -758,9 +772,8 @@ class DocumentQuery:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
# Additive: widen scope to also cover overflowing user files, but
|
||||
# only when an explicit restriction is already in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
|
||||
@@ -690,9 +690,12 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
project_ids: set[int] | None = None
|
||||
# NOTE: Empty user_projects is semantically different from None
|
||||
# user_projects.
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
persona_ids: set[int] | None = None
|
||||
# NOTE: Empty personas is semantically different from None personas.
|
||||
if user_fields is not None and user_fields.personas is not None:
|
||||
persona_ids = set(user_fields.personas)
|
||||
update_request = MetadataUpdateRequest(
|
||||
|
||||
@@ -7424,9 +7424,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
|
||||
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
|
||||
"version": "4.12.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
|
||||
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
|
||||
@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
@@ -72,6 +75,7 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
@@ -98,6 +102,34 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
source_label: str,
|
||||
) -> None:
|
||||
"""Sync fetched models to DB for the given provider.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
|
||||
"""
|
||||
try:
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=provider_name,
|
||||
models=models,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
|
||||
|
||||
|
||||
# Keys in custom_config that contain sensitive credentials
|
||||
_SENSITIVE_CONFIG_KEYS = {
|
||||
"vertex_credentials",
|
||||
@@ -963,27 +995,20 @@ def get_bedrock_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
|
||||
for r in results
|
||||
],
|
||||
source_label="Bedrock",
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -1101,27 +1126,20 @@ def get_ollama_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Ollama",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1210,27 +1228,20 @@ def get_openrouter_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenRouter",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1324,26 +1335,119 @@ def get_lm_studio_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LM Studio",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/litellm/available-models")
|
||||
def get_litellm_available_models(
|
||||
request: LitellmModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Litellm endpoint",
|
||||
)
|
||||
|
||||
results: list[LitellmFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
model_name=model_details.id,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Litellm model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Litellm",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.model_name,
|
||||
display_name=r.model_name,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LiteLLM",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. "
|
||||
"Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
|
||||
@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
class SyncModelEntry(BaseModel):
|
||||
"""Typed model for syncing fetched models to the DB."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool = False
|
||||
|
||||
|
||||
class LitellmModelsRequest(BaseModel):
|
||||
api_key: str
|
||||
api_base: str
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LitellmModelDetails(BaseModel):
|
||||
"""Response model for Litellm proxy /api/v1/models endpoint"""
|
||||
|
||||
id: str # Model ID (e.g. "gpt-4o")
|
||||
object: str # "model"
|
||||
created: int # Unix timestamp in seconds
|
||||
owned_by: str # Provider name (e.g. "openai")
|
||||
|
||||
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
@@ -406,7 +406,7 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.4.3
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
|
||||
@@ -0,0 +1,398 @@
|
||||
"""External dependency tests for the old DocumentIndex interface.
|
||||
|
||||
These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_available() -> Generator[None, None, None]:
|
||||
"""Verifies OpenSearch is running, fails the test if not."""
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
pytest.fail("OpenSearch is not available.")
|
||||
yield # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def test_index_name() -> Generator[str, None, None]:
|
||||
yield f"test_index_{uuid.uuid4().hex[:8]}" # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tenant_context() -> Generator[None, None, None]:
|
||||
"""Sets up tenant context for testing."""
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield # Test runs here.
|
||||
finally:
|
||||
# Reset the tenant context after the test
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def httpx_client() -> Generator[httpx.Client, None, None]:
|
||||
client = get_vespa_http_client()
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vespa_document_index(
|
||||
httpx_client: httpx.Client,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[VespaIndex, None, None]:
|
||||
vespa_index = VespaIndex(
|
||||
index_name=test_index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
backend_dir = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..")
|
||||
)
|
||||
with patch("os.getcwd", return_value=backend_dir):
|
||||
vespa_index.ensure_indices_exist(
|
||||
primary_embedding_dim=128,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
# Verify Vespa is running, fails the test if not. Try 90 seconds for testing
|
||||
# in CI. We have to do this here because this endpoint only becomes live
|
||||
# once we create an index.
|
||||
if not wait_for_vespa_with_timeout(wait_limit=90):
|
||||
pytest.fail("Vespa is not available.")
|
||||
|
||||
# Wait until the schema is actually ready for writes on content nodes. We
|
||||
# probe by attempting a PUT; 200 means the schema is live, 400 means not
|
||||
# yet. This is so scuffed but running the test is really flakey otherwise;
|
||||
# this is only temporary until we entirely move off of Vespa.
|
||||
probe_doc = {
|
||||
"fields": {
|
||||
"document_id": "__probe__",
|
||||
"chunk_id": 0,
|
||||
"blurb": "",
|
||||
"title": "",
|
||||
"skip_title": True,
|
||||
"content": "",
|
||||
"content_summary": "",
|
||||
"source_type": "file",
|
||||
"source_links": "null",
|
||||
"semantic_identifier": "",
|
||||
"section_continuation": False,
|
||||
"large_chunk_reference_ids": [],
|
||||
"metadata": "{}",
|
||||
"metadata_list": [],
|
||||
"metadata_suffix": "",
|
||||
"chunk_context": "",
|
||||
"doc_summary": "",
|
||||
"embeddings": {"full_chunk": [1.0] + [0.0] * 127},
|
||||
"access_control_list": {},
|
||||
"document_sets": {},
|
||||
"image_file_name": None,
|
||||
"user_project": [],
|
||||
"personas": [],
|
||||
"boost": 0.0,
|
||||
"aggregated_chunk_boost_factor": 0.0,
|
||||
"primary_owners": [],
|
||||
"secondary_owners": [],
|
||||
}
|
||||
}
|
||||
schema_ready = False
|
||||
probe_url = (
|
||||
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
|
||||
)
|
||||
for _ in range(60):
|
||||
resp = httpx_client.post(probe_url, json=probe_doc)
|
||||
if resp.status_code == 200:
|
||||
schema_ready = True
|
||||
# Clean up the probe document.
|
||||
httpx_client.delete(probe_url)
|
||||
break
|
||||
time.sleep(1)
|
||||
if not schema_ready:
|
||||
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
|
||||
|
||||
yield vespa_index # Test runs here.
|
||||
|
||||
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
|
||||
# pressing; in CI we should be using fresh instances of dependencies each
|
||||
# time anyway.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_document_index(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
|
||||
opensearch_index = OpenSearchOldDocumentIndex(
|
||||
index_name=test_index_name,
|
||||
embedding_dim=128,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_name=None,
|
||||
secondary_embedding_dim=None,
|
||||
secondary_embedding_precision=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
)
|
||||
opensearch_index.ensure_indices_exist(
|
||||
primary_embedding_dim=128,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
yield opensearch_index # Test runs here.
|
||||
|
||||
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
|
||||
# pressing; in CI we should be using fresh instances of dependencies each
|
||||
# time anyway.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def document_indices(
|
||||
vespa_document_index: VespaIndex,
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex,
|
||||
) -> Generator[list[DocumentIndex], None, None]:
|
||||
# Ideally these are parametrized; doing so with pytest fixtures is tricky.
|
||||
yield [opensearch_document_index, vespa_document_index] # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def chunks(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[list[DocMetadataAwareIndexChunk], None, None]:
|
||||
result = []
|
||||
chunk_count = 5
|
||||
doc_id = "test_doc"
|
||||
tenant_id = get_current_tenant_id()
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
document_sets: set[str] = set()
|
||||
user_project: list[int] = list()
|
||||
personas: list[int] = list()
|
||||
boost = 0
|
||||
blurb = "blurb"
|
||||
content = "content"
|
||||
title_prefix = ""
|
||||
doc_summary = ""
|
||||
chunk_context = ""
|
||||
title_embedding = [1.0] + [0] * 127
|
||||
# Full 0 vectors are not supported for cos similarity.
|
||||
embeddings = ChunkEmbedding(
|
||||
full_embedding=[1.0] + [0] * 127, mini_chunk_embeddings=[]
|
||||
)
|
||||
source_document = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="semantic identifier",
|
||||
source=DocumentSource.FILE,
|
||||
sections=[],
|
||||
metadata={},
|
||||
title="title",
|
||||
)
|
||||
metadata_suffix_keyword = ""
|
||||
image_file_id = None
|
||||
source_links: dict[int, str] = {0: ""}
|
||||
ancestor_hierarchy_node_ids: list[int] = []
|
||||
for i in range(chunk_count):
|
||||
result.append(
|
||||
DocMetadataAwareIndexChunk(
|
||||
tenant_id=tenant_id,
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=0,
|
||||
ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids,
|
||||
embeddings=embeddings,
|
||||
title_embedding=title_embedding,
|
||||
source_document=source_document,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
metadata_suffix_semantic="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
chunk_id=i,
|
||||
blurb=blurb,
|
||||
content=content,
|
||||
source_links=source_links,
|
||||
image_file_id=image_file_id,
|
||||
section_continuation=False,
|
||||
)
|
||||
)
|
||||
yield result # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_batch_params(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[IndexBatchParams, None, None]:
|
||||
# WARNING: doc_id_to_previous_chunk_cnt={"test_doc": 0} is hardcoded to 0,
|
||||
# which is only correct on the very first index call. The document_indices
|
||||
# fixture is scope="module", meaning the same OpenSearch and Vespa backends
|
||||
# persist across all test functions in this module. When a second test
|
||||
# function uses this fixture and calls document_index.index(...), the
|
||||
# backend already has 5 chunks for "test_doc" from the previous test run,
|
||||
# but the batch params still claim 0 prior chunks exist. This can lead to
|
||||
# orphaned/duplicate chunks that make subsequent assertions incorrect.
|
||||
# TODO: Whenever adding a second test, either change this or cleanup the
|
||||
# index between test cases.
|
||||
yield IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc": 5},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentIndexOld:
|
||||
"""Tests the old DocumentIndex interface."""
|
||||
|
||||
def test_update_single_can_clear_user_projects_and_personas(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
# This test case assumes all these chunks correspond to one document.
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> None:
|
||||
"""
|
||||
Tests that update_single can clear user_projects and personas.
|
||||
"""
|
||||
for document_index in document_indices:
|
||||
# Precondition.
|
||||
# Ensure there is some non-empty value for user project and
|
||||
# personas.
|
||||
for chunk in chunks:
|
||||
chunk.user_project = [1]
|
||||
chunk.personas = [2]
|
||||
document_index.index(chunks, index_batch_params)
|
||||
|
||||
# Ensure that we can get chunks as expected with filters.
|
||||
doc_id = chunks[0].source_document.id
|
||||
chunk_count = len(chunks)
|
||||
tenant_id = get_current_tenant_id()
|
||||
# We need to specify the chunk index range and specify
|
||||
# batch_retrieval=True below to trigger the codepath for Vespa's
|
||||
# search API, which uses the expected additive filtering for
|
||||
# project_id and persona_id. Otherwise we would use the codepath for
|
||||
# the visit API, which does not have this kind of filtering
|
||||
# implemented.
|
||||
chunk_request = VespaChunkRequest(
|
||||
document_id=doc_id, min_chunk_ind=0, max_chunk_ind=chunk_count - 1
|
||||
)
|
||||
project_persona_filters = IndexFilters(
|
||||
access_control_list=None,
|
||||
tenant_id=tenant_id,
|
||||
project_id=1,
|
||||
persona_id=2,
|
||||
# We need this even though none of the chunks belong to a
|
||||
# document set because project_id and persona_id are only
|
||||
# additive filters in the event the agent has knowledge scope;
|
||||
# if the agent does not, it is implied that it can see
|
||||
# everything it is allowed to.
|
||||
document_set=["1"],
|
||||
)
|
||||
# Not best practice here but the API for refreshing the index to
|
||||
# ensure that the latest data is present is not exposed in this
|
||||
# class and is not the same for Vespa and OpenSearch, so we just
|
||||
# tolerate a sleep for now. As a consequence the number of tests in
|
||||
# this suite should be small. We only need to tolerate this for as
|
||||
# long as we continue to use Vespa, we can consider exposing
|
||||
# something for OpenSearch later.
|
||||
time.sleep(1)
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[chunk_request],
|
||||
filters=project_persona_filters,
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == chunk_count
|
||||
# Sort by chunk id to easily test if we have all chunks.
|
||||
for i, inference_chunk in enumerate(
|
||||
sorted(inference_chunks, key=lambda x: x.chunk_id)
|
||||
):
|
||||
assert inference_chunk.chunk_id == i
|
||||
assert inference_chunk.document_id == doc_id
|
||||
|
||||
# Under test.
|
||||
# Explicitly set empty fields here.
|
||||
user_fields = VespaDocumentUserFields(user_projects=[], personas=[])
|
||||
document_index.update_single(
|
||||
doc_id=doc_id,
|
||||
chunk_count=chunk_count,
|
||||
tenant_id=tenant_id,
|
||||
fields=None,
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=tenant_id)
|
||||
# We should expect to get back all expected chunks with no filters.
|
||||
# Again, not best practice here.
|
||||
time.sleep(1)
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[chunk_request], filters=filters, batch_retrieval=True
|
||||
)
|
||||
assert len(inference_chunks) == chunk_count
|
||||
# Sort by chunk id to easily test if we have all chunks.
|
||||
for i, inference_chunk in enumerate(
|
||||
sorted(inference_chunks, key=lambda x: x.chunk_id)
|
||||
):
|
||||
assert inference_chunk.chunk_id == i
|
||||
assert inference_chunk.document_id == doc_id
|
||||
# Now, we should expect to not get any chunks if we specify the user
|
||||
# project and personas filters.
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[chunk_request],
|
||||
filters=project_persona_filters,
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
@@ -239,6 +239,8 @@ def full_deployment_setup() -> Generator[None, None, None]:
|
||||
NOTE: We deliberately duplicate this logic from
|
||||
backend/tests/external_dependency_unit/conftest.py because we need to set
|
||||
opensearch_available just for this module, not the entire test session.
|
||||
|
||||
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
|
||||
"""
|
||||
# Patch ENABLE_OPENSEARCH_INDEXING_FOR_ONYX just for this test because we
|
||||
# don't yet want that enabled for all tests.
|
||||
|
||||
@@ -6,6 +6,7 @@ Validates that:
|
||||
- Crash + resume skips already-processed pages
|
||||
- BFS (folder-scoped) drives process all items in one call
|
||||
- 410 Gone triggers a full-resync URL in the checkpoint
|
||||
- Duplicate document IDs across delta pages are deduplicated
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -457,3 +458,228 @@ class TestDeltaPageFetchFailure:
|
||||
assert final_cp.current_drive_name is None
|
||||
assert final_cp.current_drive_id is None
|
||||
assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
|
||||
class TestDeltaDuplicateDocumentDedup:
|
||||
"""The Microsoft Graph delta API can return the same item on multiple
|
||||
pages. Documents already yielded should be skipped via
|
||||
checkpoint.seen_document_ids."""
|
||||
|
||||
def test_duplicate_across_pages_is_skipped(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Item 'dup' appears on both page 1 and page 2. It should only be
|
||||
yielded once."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return [_make_item("a"), _make_item("dup")], "https://next2"
|
||||
return [_make_item("dup"), _make_item("b")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
|
||||
# Page 1: yields a, dup
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["a", "dup"]
|
||||
assert "dup" in checkpoint.seen_document_ids
|
||||
|
||||
# Page 2: dup should be skipped, only b yielded
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["b"]
|
||||
|
||||
def test_duplicate_within_same_page_is_skipped(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the same item appears twice on a single delta page, only the
|
||||
first occurrence should be yielded."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("x"), _make_item("x"), _make_item("y")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["x", "y"]
|
||||
|
||||
def test_seen_ids_survive_checkpoint_serialization(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""seen_document_ids must survive JSON serialization so that
|
||||
dedup works across crash + resume."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return [_make_item("a")], "https://next2"
|
||||
return [_make_item("a"), _make_item("b")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
|
||||
# Page 1
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
_, checkpoint = _consume_generator(gen)
|
||||
assert "a" in checkpoint.seen_document_ids
|
||||
|
||||
# Simulate crash: round-trip through JSON
|
||||
restored = SharepointConnectorCheckpoint.model_validate_json(
|
||||
checkpoint.model_dump_json()
|
||||
)
|
||||
assert "a" in restored.seen_document_ids
|
||||
|
||||
# Page 2 with restored checkpoint: 'a' should be skipped
|
||||
connector2 = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
gen = connector2._load_from_checkpoint(
|
||||
_START_TS, _END_TS, restored, include_permissions=False
|
||||
)
|
||||
yielded, final_cp = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["b"]
|
||||
|
||||
def test_no_dedup_across_separate_indexing_runs(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A fresh checkpoint (new indexing run) should have an empty
|
||||
seen_document_ids, so previously-indexed docs are re-processed."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("a")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
# First run
|
||||
cp1 = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, cp1, include_permissions=False
|
||||
)
|
||||
yielded, _ = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 1
|
||||
|
||||
# Second run with a fresh checkpoint — same doc should appear again
|
||||
cp2 = _build_ready_checkpoint()
|
||||
assert len(cp2.seen_document_ids) == 0
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, cp2, include_permissions=False
|
||||
)
|
||||
yielded, _ = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 1
|
||||
|
||||
def test_same_id_across_drives_not_skipped(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Graph item IDs are only unique within a drive. An item in drive B
|
||||
that happens to share an ID with an item already seen in drive A must
|
||||
NOT be skipped."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("shared-id")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint(drive_names=["DriveA", "DriveB"])
|
||||
|
||||
# Drive A: yields the item
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "shared-id"
|
||||
|
||||
# seen_document_ids should have been cleared when drive A finished
|
||||
assert len(checkpoint.seen_document_ids) == 0
|
||||
|
||||
# Drive B: same ID must be yielded again (different drive)
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "shared-id"
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
|
||||
|
||||
class TestSyncModelConfigurations:
|
||||
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4",
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Existing - should be skipped
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o", # New - should be inserted
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Existing - should be skipped
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o", # New - should be inserted
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Already exists
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Already exists
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
|
||||
sync_model_configurations(
|
||||
db_session=mock_session,
|
||||
provider_name="nonexistent",
|
||||
models=[{"name": "model", "display_name": "Model"}],
|
||||
models=[SyncModelEntry(name="model", display_name="Model")],
|
||||
)
|
||||
|
||||
def test_handles_missing_optional_fields(self) -> None:
|
||||
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
|
||||
with patch(
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
# Model with only required fields
|
||||
# Model with only required fields (max_input_tokens and supports_image_input default)
|
||||
models = [
|
||||
{
|
||||
"name": "model-1",
|
||||
# No display_name, max_input_tokens, or supports_image_input
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="model-1",
|
||||
display_name="Model 1",
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Tests for LLM model fetch endpoints.
|
||||
|
||||
These tests verify the full request/response flow for fetching models
|
||||
from dynamic providers (Ollama, OpenRouter), including the
|
||||
from dynamic providers (Ollama, OpenRouter, Litellm), including the
|
||||
sync-to-DB behavior when provider_name is specified.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
@@ -614,3 +618,283 @@ class TestGetLMStudioAvailableModels:
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetLitellmAvailableModels:
|
||||
"""Tests for the Litellm proxy model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_response(self) -> dict:
|
||||
"""Mock response from Litellm /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1700000001,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "gemini-pro",
|
||||
"object": "model",
|
||||
"created": 1700000002,
|
||||
"owned_by": "google",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted model list."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
|
||||
|
||||
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that provider_name and model_name are correctly extracted."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
gpt = next(r for r in results if r.model_name == "gpt-4o")
|
||||
assert gpt.provider_name == "openai"
|
||||
|
||||
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
|
||||
assert claude.provider_name == "anthropic"
|
||||
|
||||
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that results are alphabetically sorted by model_name."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
model_names = [r.model_name for r in results]
|
||||
assert model_names == sorted(model_names, key=str.lower)
|
||||
|
||||
def test_empty_data_raises_onyx_error(self) -> None:
|
||||
"""Test that empty model list raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No models found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_missing_data_key_raises_onyx_error(self) -> None:
|
||||
"""Test that response without 'data' key raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_skips_unparseable_entries(self) -> None:
|
||||
"""Test that malformed model entries are skipped without failing."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_with_bad_entry = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
# Missing required fields
|
||||
{"bad_field": "bad_value"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_with_bad_entry
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].model_name == "gpt-4o"
|
||||
|
||||
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
|
||||
"""Test that OnyxError is raised when all entries fail to parse."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_all_bad = {
|
||||
"data": [
|
||||
{"bad_field": "bad_value"},
|
||||
{"another_bad": 123},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_all_bad
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No compatible models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_api_base_trailing_slash_handled(self) -> None:
|
||||
"""Test that trailing slashes in api_base are handled correctly."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_litellm_response = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000/",
|
||||
api_key="test-key",
|
||||
)
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Should call /v1/models without double slashes
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[0][0] == "http://localhost:4000/v1/models"
|
||||
|
||||
def test_connection_failure_raises_onyx_error(self) -> None:
|
||||
"""Test that connection failures are wrapped in OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
"""Test that a 401 response raises OnyxError with authentication message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="bad-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Authentication failed"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_404_raises_not_found_error(self) -> None:
|
||||
"""Test that a 404 response raises OnyxError with endpoint not found message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
@@ -153,7 +153,7 @@ dev = [
|
||||
"pytest-repeat==0.9.4",
|
||||
"pytest-xdist==3.8.0",
|
||||
"pytest==8.3.5",
|
||||
"release-tag==0.4.3",
|
||||
"release-tag==0.5.2",
|
||||
"reorder-python-imports-black==3.14.0",
|
||||
"ruff==0.12.0",
|
||||
"types-beautifulsoup4==4.12.0.3",
|
||||
|
||||
18
uv.lock
generated
18
uv.lock
generated
@@ -4485,7 +4485,7 @@ requires-dist = [
|
||||
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
@@ -6338,16 +6338,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "release-tag"
|
||||
version = "0.4.3"
|
||||
version = "0.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -144,6 +144,7 @@ module.exports = {
|
||||
"**/src/app/**/hooks/*.test.ts", // Pure packet processor tests
|
||||
"**/src/refresh-components/**/*.test.ts",
|
||||
"**/src/sections/**/*.test.ts",
|
||||
"**/src/components/**/*.test.ts",
|
||||
// Add more patterns here as you add more unit tests
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { Interactive } from "@opal/core";
|
||||
import { Interactive, Disabled } from "@opal/core";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Variant / Prominence mappings for the matrix story
|
||||
@@ -9,8 +9,6 @@ const VARIANT_PROMINENCE_MAP: Record<string, string[]> = {
|
||||
default: ["primary", "secondary", "tertiary", "internal"],
|
||||
action: ["primary", "secondary", "tertiary", "internal"],
|
||||
danger: ["primary", "secondary", "tertiary", "internal"],
|
||||
select: ["light", "heavy"],
|
||||
sidebar: ["light"],
|
||||
none: [],
|
||||
};
|
||||
|
||||
@@ -35,39 +33,39 @@ export default meta;
|
||||
// Stories
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Basic Interactive.Base + Container with text content. */
|
||||
/** Basic Interactive.Stateless + Container with text content. */
|
||||
export const Default: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: "0.75rem", alignItems: "center" }}>
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Secondary</span>
|
||||
<span className="interactive-foreground">Secondary</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="primary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Primary</span>
|
||||
<span className="interactive-foreground">Primary</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="tertiary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Tertiary</span>
|
||||
<span className="interactive-foreground">Tertiary</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
@@ -91,11 +89,13 @@ export const VariantMatrix: StoryObj = {
|
||||
</div>
|
||||
|
||||
{prominences.length === 0 ? (
|
||||
<Interactive.Base variant="none" onClick={() => {}}>
|
||||
<Interactive.Stateless variant="none" onClick={() => {}}>
|
||||
<Interactive.Container border>
|
||||
<span>none (no prominence)</span>
|
||||
<span style={{ color: "var(--text-01)" }}>
|
||||
none (no prominence)
|
||||
</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
) : (
|
||||
<div style={{ display: "flex", gap: "0.5rem", flexWrap: "wrap" }}>
|
||||
{prominences.map((prominence) => (
|
||||
@@ -108,16 +108,18 @@ export const VariantMatrix: StoryObj = {
|
||||
gap: "0.25rem",
|
||||
}}
|
||||
>
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
// Cast required because the discriminated union can't be
|
||||
// resolved from dynamic strings at the type level.
|
||||
{...({ variant, prominence } as any)}
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>{prominence}</span>
|
||||
<span className="interactive-foreground">
|
||||
{prominence}
|
||||
</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
<span
|
||||
style={{
|
||||
fontSize: "0.625rem",
|
||||
@@ -141,16 +143,16 @@ export const Sizes: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", alignItems: "center", gap: "0.75rem" }}>
|
||||
{SIZE_VARIANTS.map((size) => (
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
key={size}
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border heightVariant={size}>
|
||||
<span>{size}</span>
|
||||
<span className="interactive-foreground">{size}</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
@@ -160,15 +162,15 @@ export const Sizes: StoryObj = {
|
||||
export const WidthFull: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ width: 400 }}>
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border widthVariant="full">
|
||||
<span>Full width container</span>
|
||||
<span className="interactive-foreground">Full width container</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
@@ -178,73 +180,86 @@ export const Rounding: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: "0.75rem" }}>
|
||||
{ROUNDING_VARIANTS.map((rounding) => (
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
key={rounding}
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border roundingVariant={rounding}>
|
||||
<span>{rounding}</span>
|
||||
<span className="interactive-foreground">{rounding}</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
/** Disabled state prevents clicks and shows disabled styling. */
|
||||
export const Disabled: StoryObj = {
|
||||
export const DisabledStory: StoryObj = {
|
||||
name: "Disabled",
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: "0.75rem" }}>
|
||||
<Interactive.Base
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
disabled
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Disabled</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
<Disabled disabled>
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span className="interactive-foreground">Disabled</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Stateless>
|
||||
</Disabled>
|
||||
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Enabled</span>
|
||||
<span className="interactive-foreground">Enabled</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
/** Transient prop forces the hover/active visual state. */
|
||||
export const Transient: StoryObj = {
|
||||
/** Interaction override forces the hover/active visual state. */
|
||||
export const Interaction: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: "0.75rem" }}>
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
interaction="hover"
|
||||
onClick={() => {}}
|
||||
transient
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Forced hover</span>
|
||||
<span className="interactive-foreground">Forced hover</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
interaction="active"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span className="interactive-foreground">Forced active</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Stateless>
|
||||
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Normal</span>
|
||||
<span className="interactive-foreground">Normal (rest)</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
@@ -253,25 +268,25 @@ export const Transient: StoryObj = {
|
||||
export const WithBorder: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: "0.75rem" }}>
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>With border</span>
|
||||
<span className="interactive-foreground">With border</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
|
||||
<Interactive.Base
|
||||
<Interactive.Stateless
|
||||
variant="default"
|
||||
prominence="secondary"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container>
|
||||
<span>Without border</span>
|
||||
<span className="interactive-foreground">Without border</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
@@ -279,51 +294,57 @@ export const WithBorder: StoryObj = {
|
||||
/** Using href to render as a link. */
|
||||
export const AsLink: StoryObj = {
|
||||
render: () => (
|
||||
<Interactive.Base variant="action" href="/settings">
|
||||
<Interactive.Stateless variant="action" href="/settings">
|
||||
<Interactive.Container border>
|
||||
<span>Go to Settings</span>
|
||||
<span className="interactive-foreground">Go to Settings</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateless>
|
||||
),
|
||||
};
|
||||
|
||||
/** Select variant with selected and unselected states. */
|
||||
/** Stateful select variant with selected and unselected states. */
|
||||
export const SelectVariant: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: "0.75rem" }}>
|
||||
<Interactive.Base
|
||||
variant="select"
|
||||
prominence="light"
|
||||
selected
|
||||
<Interactive.Stateful
|
||||
variant="select-light"
|
||||
state="selected"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Selected (light)</span>
|
||||
<span className="interactive-foreground">Selected (light)</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateful>
|
||||
|
||||
<Interactive.Base variant="select" prominence="light" onClick={() => {}}>
|
||||
<Interactive.Container border>
|
||||
<span>Unselected (light)</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
|
||||
<Interactive.Base
|
||||
variant="select"
|
||||
prominence="heavy"
|
||||
selected
|
||||
<Interactive.Stateful
|
||||
variant="select-light"
|
||||
state="empty"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Selected (heavy)</span>
|
||||
<span className="interactive-foreground">Unselected (light)</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateful>
|
||||
|
||||
<Interactive.Base variant="select" prominence="heavy" onClick={() => {}}>
|
||||
<Interactive.Stateful
|
||||
variant="select-heavy"
|
||||
state="selected"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span>Unselected (heavy)</span>
|
||||
<span className="interactive-foreground">Selected (heavy)</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
</Interactive.Stateful>
|
||||
|
||||
<Interactive.Stateful
|
||||
variant="select-heavy"
|
||||
state="empty"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<Interactive.Container border>
|
||||
<span className="interactive-foreground">Unselected (heavy)</span>
|
||||
</Interactive.Container>
|
||||
</Interactive.Stateful>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
@@ -89,7 +89,7 @@ export { default as SvgHistory } from "@opal/icons/history";
|
||||
export { default as SvgHourglass } from "@opal/icons/hourglass";
|
||||
export { default as SvgImage } from "@opal/icons/image";
|
||||
export { default as SvgImageSmall } from "@opal/icons/image-small";
|
||||
export { default as SvgImport } from "@opal/icons/import";
|
||||
export { default as SvgImport } from "@opal/icons/import-icon";
|
||||
export { default as SvgInfo } from "@opal/icons/info";
|
||||
export { default as SvgInfoSmall } from "@opal/icons/info-small";
|
||||
export { default as SvgKey } from "@opal/icons/key";
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { ReactNode, useState } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChatFileType, FileDescriptor } from "@/app/app/interfaces";
|
||||
import Attachment from "@/refresh-components/Attachment";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
@@ -9,10 +10,27 @@ import PreviewModal from "@/sections/modals/PreviewModal";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import ExpandableContentWrapper from "@/components/tools/ExpandableContentWrapper";
|
||||
|
||||
interface FileContainerProps {
|
||||
children: ReactNode;
|
||||
className?: string;
|
||||
id?: string;
|
||||
}
|
||||
|
||||
interface FileDisplayProps {
|
||||
files: FileDescriptor[];
|
||||
}
|
||||
|
||||
function FileContainer({ children, className, id }: FileContainerProps) {
|
||||
return (
|
||||
<div
|
||||
id={id}
|
||||
className={cn("flex w-full flex-col items-end gap-2 py-2", className)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function FileDisplay({ files }: FileDisplayProps) {
|
||||
const [close, setClose] = useState(true);
|
||||
const [previewingFile, setPreviewingFile] = useState<FileDescriptor | null>(
|
||||
@@ -41,7 +59,7 @@ export default function FileDisplay({ files }: FileDisplayProps) {
|
||||
)}
|
||||
|
||||
{textFiles.length > 0 && (
|
||||
<div id="onyx-file" className="flex flex-col items-end gap-2 py-2">
|
||||
<FileContainer id="onyx-file">
|
||||
{textFiles.map((file) => (
|
||||
<Attachment
|
||||
key={file.id}
|
||||
@@ -49,40 +67,36 @@ export default function FileDisplay({ files }: FileDisplayProps) {
|
||||
open={() => setPreviewingFile(file)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</FileContainer>
|
||||
)}
|
||||
|
||||
{imageFiles.length > 0 && (
|
||||
<div id="onyx-image" className="flex flex-col items-end gap-2 py-2">
|
||||
<FileContainer id="onyx-image">
|
||||
{imageFiles.map((file) => (
|
||||
<InMessageImage key={file.id} fileId={file.id} />
|
||||
))}
|
||||
</div>
|
||||
</FileContainer>
|
||||
)}
|
||||
|
||||
{csvFiles.length > 0 && (
|
||||
<div className="flex flex-col items-end gap-2 py-2">
|
||||
{csvFiles.map((file) => {
|
||||
return (
|
||||
<div key={file.id} className="w-fit">
|
||||
{close ? (
|
||||
<>
|
||||
<ExpandableContentWrapper
|
||||
fileDescriptor={file}
|
||||
close={() => setClose(false)}
|
||||
ContentComponent={CsvContent}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Attachment
|
||||
open={() => setClose(true)}
|
||||
fileName={file.name || file.id}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<FileContainer className="overflow-auto">
|
||||
{csvFiles.map((file) =>
|
||||
close ? (
|
||||
<ExpandableContentWrapper
|
||||
key={file.id}
|
||||
fileDescriptor={file}
|
||||
close={() => setClose(false)}
|
||||
ContentComponent={CsvContent}
|
||||
/>
|
||||
) : (
|
||||
<Attachment
|
||||
key={file.id}
|
||||
open={() => setClose(true)}
|
||||
fileName={file.name || file.id}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
</FileContainer>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -11,7 +11,7 @@ import { Button } from "@opal/components";
|
||||
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
|
||||
import type { AppMode } from "@/providers/QueryControllerProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
@@ -58,15 +58,15 @@ const footerMarkdownComponents = {
|
||||
*/
|
||||
export default function NRFChrome() {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const { state, setAppMode } = useQueryController();
|
||||
const settings = useSettingsContext();
|
||||
const { isMobile } = useScreenSize();
|
||||
const { setFolded } = useAppSidebarContext();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
const [modePopoverOpen, setModePopoverOpen] = useState(false);
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
const effectiveMode: AppMode =
|
||||
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
|
||||
|
||||
const customFooterContent =
|
||||
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
|
||||
@@ -78,7 +78,7 @@ export default function NRFChrome() {
|
||||
isPaidEnterpriseFeaturesEnabled &&
|
||||
settings.isSearchModeAvailable &&
|
||||
appFocus.isNewSession() &&
|
||||
!classification;
|
||||
state.phase === "idle";
|
||||
|
||||
const showHeader = isMobile || showModeToggle;
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
const isStreaming = currentChatState === "streaming";
|
||||
|
||||
// Query controller for search/chat classification (EE feature)
|
||||
const { submit: submitQuery, classification } = useQueryController();
|
||||
const { submit: submitQuery, state } = useQueryController();
|
||||
|
||||
// Determine if retrieval (search) is enabled based on the agent
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
@@ -186,7 +186,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
}, [liveAgent]);
|
||||
|
||||
// Check if we're in search mode
|
||||
const isSearch = classification === "search";
|
||||
const isSearch =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
|
||||
// Anchor for scroll positioning (matches ChatPage pattern)
|
||||
const anchorMessage = messageHistory.at(-2) ?? messageHistory[0];
|
||||
@@ -317,7 +318,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
};
|
||||
|
||||
// Use submitQuery which will classify the query and either:
|
||||
// - Route to search (sets classification to "search" and shows SearchUI)
|
||||
// - Route to search (sets phase to "searching"/"search-results" and shows SearchUI)
|
||||
// - Route to chat (calls onChat callback)
|
||||
await submitQuery(submittedMessage, onChat);
|
||||
},
|
||||
|
||||
@@ -60,27 +60,28 @@ const CsvContent: React.FC<ContentComponentProps> = ({
|
||||
}
|
||||
|
||||
const csvData = await response.text();
|
||||
const rows = csvData.trim().split("\n");
|
||||
const rows = parseCSV(csvData.trim());
|
||||
const firstRow = rows[0];
|
||||
if (!firstRow) {
|
||||
throw new Error("CSV file is empty");
|
||||
}
|
||||
const parsedHeaders = firstRow.split(",");
|
||||
const parsedHeaders = firstRow;
|
||||
setHeaders(parsedHeaders);
|
||||
|
||||
const parsedData: Record<string, string>[] = rows.slice(1).map((row) => {
|
||||
const values = row.split(",");
|
||||
return parsedHeaders.reduce<Record<string, string>>(
|
||||
(obj, header, index) => {
|
||||
const val = values[index];
|
||||
if (val !== undefined) {
|
||||
obj[header] = val;
|
||||
}
|
||||
return obj;
|
||||
},
|
||||
{}
|
||||
);
|
||||
});
|
||||
const parsedData: Record<string, string>[] = rows
|
||||
.slice(1)
|
||||
.map((fields) => {
|
||||
return parsedHeaders.reduce<Record<string, string>>(
|
||||
(obj, header, index) => {
|
||||
const val = fields[index];
|
||||
if (val !== undefined) {
|
||||
obj[header] = val;
|
||||
}
|
||||
return obj;
|
||||
},
|
||||
{}
|
||||
);
|
||||
});
|
||||
setData(parsedData);
|
||||
csvCache.set(id, { headers: parsedHeaders, data: parsedData });
|
||||
} catch (error) {
|
||||
@@ -173,3 +174,53 @@ const csvCache = new Map<
|
||||
string,
|
||||
{ headers: string[]; data: Record<string, string>[] }
|
||||
>();
|
||||
|
||||
export function parseCSV(text: string): string[][] {
|
||||
const rows: string[][] = [];
|
||||
let field = "";
|
||||
let fields: string[] = [];
|
||||
let inQuotes = false;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const char = text[i];
|
||||
|
||||
if (inQuotes) {
|
||||
if (char === '"') {
|
||||
if (i + 1 < text.length && text[i + 1] === '"') {
|
||||
field += '"';
|
||||
i++;
|
||||
} else {
|
||||
inQuotes = false;
|
||||
}
|
||||
} else {
|
||||
field += char;
|
||||
}
|
||||
} else if (char === '"') {
|
||||
inQuotes = true;
|
||||
} else if (char === ",") {
|
||||
fields.push(field);
|
||||
field = "";
|
||||
} else if (char === "\n" || char === "\r") {
|
||||
if (char === "\r" && i + 1 < text.length && text[i + 1] === "\n") {
|
||||
i++;
|
||||
}
|
||||
fields.push(field);
|
||||
field = "";
|
||||
rows.push(fields);
|
||||
fields = [];
|
||||
} else {
|
||||
field += char;
|
||||
}
|
||||
}
|
||||
|
||||
if (inQuotes) {
|
||||
throw new Error("Malformed CSV: unterminated quoted field");
|
||||
}
|
||||
|
||||
if (field.length > 0 || fields.length > 0) {
|
||||
fields.push(field);
|
||||
rows.push(fields);
|
||||
}
|
||||
|
||||
return rows;
|
||||
}
|
||||
|
||||
@@ -40,12 +40,7 @@ export default function ExpandableContentWrapper({
|
||||
};
|
||||
|
||||
const Content = (
|
||||
<div
|
||||
className={cn(
|
||||
!expanded ? "w-message-default" : "w-full",
|
||||
"!rounded !rounded-lg overflow-y-hidden h-full"
|
||||
)}
|
||||
>
|
||||
<div className="w-message-default max-w-full !rounded-lg overflow-y-hidden h-full">
|
||||
<CardHeader className="w-full bg-background-tint-02 top-0 p-3">
|
||||
<div className="flex justify-between items-center">
|
||||
<Text className="text-ellipsis line-clamp-1" text03 mainUiAction>
|
||||
@@ -83,12 +78,10 @@ export default function ExpandableContentWrapper({
|
||||
)}
|
||||
>
|
||||
<CardContent className="p-0">
|
||||
{!expanded && (
|
||||
<ContentComponent
|
||||
fileDescriptor={fileDescriptor}
|
||||
expanded={expanded}
|
||||
/>
|
||||
)}
|
||||
<ContentComponent
|
||||
fileDescriptor={fileDescriptor}
|
||||
expanded={expanded}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
84
web/src/components/tools/parseCSV.test.ts
Normal file
84
web/src/components/tools/parseCSV.test.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
import { parseCSV } from "./CSVContent";
|
||||
|
||||
describe("parseCSV", () => {
|
||||
it("parses simple comma-separated rows", () => {
|
||||
expect(parseCSV("a,b,c\n1,2,3")).toEqual([
|
||||
["a", "b", "c"],
|
||||
["1", "2", "3"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("preserves commas inside quoted fields", () => {
|
||||
expect(parseCSV('name,address\nAlice,"123 Main St, Apt 4"')).toEqual([
|
||||
["name", "address"],
|
||||
["Alice", "123 Main St, Apt 4"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("handles escaped double quotes inside quoted fields", () => {
|
||||
expect(parseCSV('a,b\n"say ""hello""",world')).toEqual([
|
||||
["a", "b"],
|
||||
['say "hello"', "world"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("handles newlines inside quoted fields", () => {
|
||||
expect(parseCSV('a,b\n"line1\nline2",val')).toEqual([
|
||||
["a", "b"],
|
||||
["line1\nline2", "val"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("handles CRLF line endings", () => {
|
||||
expect(parseCSV("a,b\r\n1,2\r\n3,4")).toEqual([
|
||||
["a", "b"],
|
||||
["1", "2"],
|
||||
["3", "4"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("handles empty fields", () => {
|
||||
expect(parseCSV("a,b,c\n1,,3")).toEqual([
|
||||
["a", "b", "c"],
|
||||
["1", "", "3"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("handles a single element", () => {
|
||||
expect(parseCSV("a")).toEqual([["a"]]);
|
||||
});
|
||||
|
||||
it("handles a single row with no newline", () => {
|
||||
expect(parseCSV("a,b,c")).toEqual([["a", "b", "c"]]);
|
||||
});
|
||||
|
||||
it("handles quoted fields that are entirely empty", () => {
|
||||
expect(parseCSV('a,b\n"",val')).toEqual([
|
||||
["a", "b"],
|
||||
["", "val"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("handles multiple quoted fields with commas", () => {
|
||||
expect(parseCSV('"foo, bar","baz, qux"\n"1, 2","3, 4"')).toEqual([
|
||||
["foo, bar", "baz, qux"],
|
||||
["1, 2", "3, 4"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("throws on unterminated quoted field", () => {
|
||||
expect(() => parseCSV('a,b\n"foo,bar')).toThrow(
|
||||
"Malformed CSV: unterminated quoted field"
|
||||
);
|
||||
});
|
||||
|
||||
it("throws on unterminated quote at end of input", () => {
|
||||
expect(() => parseCSV('"unterminated')).toThrow(
|
||||
"Malformed CSV: unterminated quoted field"
|
||||
);
|
||||
});
|
||||
|
||||
it("returns empty array for empty input", () => {
|
||||
expect(parseCSV("")).toEqual([]);
|
||||
});
|
||||
});
|
||||
@@ -1,55 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useCallback, useEffect } from "react";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { AppModeContext, AppMode } from "@/providers/AppModeProvider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
|
||||
export interface AppModeProviderProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider for application mode (Search/Chat).
|
||||
*
|
||||
* This controls how user queries are handled:
|
||||
* - **search**: Forces search mode - quick document lookup
|
||||
* - **chat**: Forces chat mode - conversation with follow-up questions
|
||||
*
|
||||
* The initial mode is read from the user's persisted `default_app_mode` preference.
|
||||
* When search mode is unavailable (admin setting or no connectors), the mode is locked to "chat".
|
||||
*/
|
||||
export function AppModeProvider({ children }: AppModeProviderProps) {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { user } = useUser();
|
||||
const { isSearchModeAvailable } = useSettingsContext();
|
||||
|
||||
const persistedMode = user?.preferences?.default_app_mode;
|
||||
const [appMode, setAppModeState] = useState<AppMode>("chat");
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) {
|
||||
setAppModeState("chat");
|
||||
return;
|
||||
}
|
||||
|
||||
if (persistedMode) {
|
||||
setAppModeState(persistedMode.toLowerCase() as AppMode);
|
||||
}
|
||||
}, [isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable, persistedMode]);
|
||||
|
||||
const setAppMode = useCallback(
|
||||
(mode: AppMode) => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) return;
|
||||
setAppModeState(mode);
|
||||
},
|
||||
[isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable]
|
||||
);
|
||||
|
||||
return (
|
||||
<AppModeContext.Provider value={{ appMode, setAppMode }}>
|
||||
{children}
|
||||
</AppModeContext.Provider>
|
||||
);
|
||||
}
|
||||
@@ -8,14 +8,15 @@ import {
|
||||
SearchFullResponse,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { classifyQuery, searchDocuments } from "@/ee/lib/search/svc";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import {
|
||||
QueryControllerContext,
|
||||
QueryClassification,
|
||||
QueryControllerValue,
|
||||
QueryState,
|
||||
AppMode,
|
||||
} from "@/providers/QueryControllerProvider";
|
||||
|
||||
interface QueryControllerProviderProps {
|
||||
@@ -25,19 +26,53 @@ interface QueryControllerProviderProps {
|
||||
export function QueryControllerProvider({
|
||||
children,
|
||||
}: QueryControllerProviderProps) {
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const appFocus = useAppFocus();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const settings = useSettingsContext();
|
||||
const { isSearchModeAvailable: searchUiEnabled } = settings;
|
||||
const { user } = useUser();
|
||||
|
||||
// Query state
|
||||
// ── Merged query state (discriminated union) ──────────────────────────
|
||||
const [state, setState] = useState<QueryState>({
|
||||
phase: "idle",
|
||||
appMode: "chat",
|
||||
});
|
||||
|
||||
// Persistent app-mode preference — survives phase transitions and is
|
||||
// used to restore the correct mode when resetting back to idle.
|
||||
const appModeRef = useRef<AppMode>("chat");
|
||||
|
||||
// ── App mode sync from user preferences ───────────────────────────────
|
||||
const persistedMode = user?.preferences?.default_app_mode;
|
||||
|
||||
useEffect(() => {
|
||||
let mode: AppMode = "chat";
|
||||
if (isPaidEnterpriseFeaturesEnabled && searchUiEnabled && persistedMode) {
|
||||
const lower = persistedMode.toLowerCase();
|
||||
mode = (["auto", "search", "chat"] as const).includes(lower as AppMode)
|
||||
? (lower as AppMode)
|
||||
: "chat";
|
||||
}
|
||||
appModeRef.current = mode;
|
||||
setState((prev) =>
|
||||
prev.phase === "idle" ? { phase: "idle", appMode: mode } : prev
|
||||
);
|
||||
}, [isPaidEnterpriseFeaturesEnabled, searchUiEnabled, persistedMode]);
|
||||
|
||||
const setAppMode = useCallback(
|
||||
(mode: AppMode) => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !searchUiEnabled) return;
|
||||
setState((prev) => {
|
||||
if (prev.phase !== "idle") return prev;
|
||||
appModeRef.current = mode;
|
||||
return { phase: "idle", appMode: mode };
|
||||
});
|
||||
},
|
||||
[isPaidEnterpriseFeaturesEnabled, searchUiEnabled]
|
||||
);
|
||||
|
||||
// ── Ancillary state ───────────────────────────────────────────────────
|
||||
const [query, setQuery] = useState<string | null>(null);
|
||||
const [classification, setClassification] =
|
||||
useState<QueryClassification>(null);
|
||||
const [isClassifying, setIsClassifying] = useState(false);
|
||||
|
||||
// Search state
|
||||
const [searchResults, setSearchResults] = useState<SearchDocWithContent[]>(
|
||||
[]
|
||||
);
|
||||
@@ -51,7 +86,7 @@ export function QueryControllerProvider({
|
||||
const searchAbortRef = useRef<AbortController | null>(null);
|
||||
|
||||
/**
|
||||
* Perform document search
|
||||
* Perform document search (pure data-fetching, no phase side effects)
|
||||
*/
|
||||
const performSearch = useCallback(
|
||||
async (searchQuery: string, filters?: BaseFilters): Promise<void> => {
|
||||
@@ -85,19 +120,15 @@ export function QueryControllerProvider({
|
||||
setLlmSelectedDocIds(response.llm_selected_doc_ids ?? null);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
return;
|
||||
throw err;
|
||||
}
|
||||
|
||||
setError("Document search failed. Please try again.");
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
} finally {
|
||||
// After we've performed a search, we automatically switch to "search" mode.
|
||||
// This is a "sticky" implementation; on purpose.
|
||||
setAppMode("search");
|
||||
}
|
||||
},
|
||||
[setAppMode]
|
||||
[]
|
||||
);
|
||||
|
||||
/**
|
||||
@@ -112,8 +143,6 @@ export function QueryControllerProvider({
|
||||
const controller = new AbortController();
|
||||
classifyAbortRef.current = controller;
|
||||
|
||||
setIsClassifying(true);
|
||||
|
||||
try {
|
||||
const response: SearchFlowClassificationResponse = await classifyQuery(
|
||||
classifyQueryText,
|
||||
@@ -129,8 +158,6 @@ export function QueryControllerProvider({
|
||||
|
||||
setError("Query classification failed. Falling back to chat.");
|
||||
return "chat";
|
||||
} finally {
|
||||
setIsClassifying(false);
|
||||
}
|
||||
},
|
||||
[]
|
||||
@@ -148,62 +175,51 @@ export function QueryControllerProvider({
|
||||
setQuery(submitQuery);
|
||||
setError(null);
|
||||
|
||||
// 1.
|
||||
// We always route through chat if we're not Enterprise Enabled.
|
||||
//
|
||||
// 2.
|
||||
// We always route through chat if the admin has disabled the Search UI.
|
||||
//
|
||||
// 3.
|
||||
// We only go down the classification route if we're in the "New Session" tab.
|
||||
// Everywhere else, we always use the chat-flow.
|
||||
//
|
||||
// 4.
|
||||
// If we're in the "New Session" tab and the app-mode is "Chat", we continue with the chat-flow anyways.
|
||||
const currentAppMode = appModeRef.current;
|
||||
|
||||
// Always route through chat if:
|
||||
// 1. Not Enterprise Enabled
|
||||
// 2. Admin has disabled the Search UI
|
||||
// 3. Not in the "New Session" tab
|
||||
// 4. In "New Session" tab but app-mode is "Chat"
|
||||
if (
|
||||
!isPaidEnterpriseFeaturesEnabled ||
|
||||
!searchUiEnabled ||
|
||||
!appFocus.isNewSession() ||
|
||||
appMode === "chat"
|
||||
currentAppMode === "chat"
|
||||
) {
|
||||
setClassification("chat");
|
||||
setState({ phase: "chat" });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
onChat(submitQuery);
|
||||
return;
|
||||
}
|
||||
|
||||
if (appMode === "search") {
|
||||
await performSearch(submitQuery, filters);
|
||||
setClassification("search");
|
||||
// Search mode: immediately show SearchUI with loading state
|
||||
if (currentAppMode === "search") {
|
||||
setState({ phase: "searching" });
|
||||
try {
|
||||
await performSearch(submitQuery, filters);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
throw err;
|
||||
}
|
||||
setState({ phase: "search-results" });
|
||||
return;
|
||||
}
|
||||
|
||||
// # Note (@raunakab)
|
||||
//
|
||||
// Interestingly enough, for search, we do:
|
||||
// 1. setClassification("search")
|
||||
// 2. performSearch
|
||||
//
|
||||
// But for chat, we do:
|
||||
// 1. performChat
|
||||
// 2. setClassification("chat")
|
||||
//
|
||||
// The ChatUI has a nice loading UI, so it's fine for us to prematurely set the
|
||||
// classification-state before the chat has finished loading.
|
||||
//
|
||||
// However, the SearchUI does not. Prematurely setting the classification-state
|
||||
// will lead to a slightly ugly UI.
|
||||
|
||||
// Auto mode: classify first, then route
|
||||
setState({ phase: "classifying" });
|
||||
try {
|
||||
const result = await performClassification(submitQuery);
|
||||
|
||||
if (result === "search") {
|
||||
setState({ phase: "searching" });
|
||||
await performSearch(submitQuery, filters);
|
||||
setClassification("search");
|
||||
setState({ phase: "search-results" });
|
||||
appModeRef.current = "search";
|
||||
} else {
|
||||
setClassification("chat");
|
||||
setState({ phase: "chat" });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
onChat(submitQuery);
|
||||
@@ -213,14 +229,13 @@ export function QueryControllerProvider({
|
||||
return;
|
||||
}
|
||||
|
||||
setClassification("chat");
|
||||
setState({ phase: "chat" });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
onChat(submitQuery);
|
||||
}
|
||||
},
|
||||
[
|
||||
appMode,
|
||||
appFocus,
|
||||
performClassification,
|
||||
performSearch,
|
||||
@@ -235,7 +250,14 @@ export function QueryControllerProvider({
|
||||
const refineSearch = useCallback(
|
||||
async (filters: BaseFilters): Promise<void> => {
|
||||
if (!query) return;
|
||||
await performSearch(query, filters);
|
||||
setState({ phase: "searching" });
|
||||
try {
|
||||
await performSearch(query, filters);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
throw err;
|
||||
}
|
||||
setState({ phase: "search-results" });
|
||||
},
|
||||
[query, performSearch]
|
||||
);
|
||||
@@ -254,7 +276,7 @@ export function QueryControllerProvider({
|
||||
}
|
||||
|
||||
setQuery(null);
|
||||
setClassification(null);
|
||||
setState({ phase: "idle", appMode: appModeRef.current });
|
||||
setSearchResults([]);
|
||||
setLlmSelectedDocIds(null);
|
||||
setError(null);
|
||||
@@ -262,8 +284,8 @@ export function QueryControllerProvider({
|
||||
|
||||
const value: QueryControllerValue = useMemo(
|
||||
() => ({
|
||||
classification,
|
||||
isClassifying,
|
||||
state,
|
||||
setAppMode,
|
||||
searchResults,
|
||||
llmSelectedDocIds,
|
||||
error,
|
||||
@@ -272,8 +294,8 @@ export function QueryControllerProvider({
|
||||
reset,
|
||||
}),
|
||||
[
|
||||
classification,
|
||||
isClassifying,
|
||||
state,
|
||||
setAppMode,
|
||||
searchResults,
|
||||
llmSelectedDocIds,
|
||||
error,
|
||||
@@ -283,7 +305,7 @@ export function QueryControllerProvider({
|
||||
]
|
||||
);
|
||||
|
||||
// Sync classification state with navigation context
|
||||
// Sync state with navigation context
|
||||
useEffect(reset, [appFocus, reset]);
|
||||
|
||||
return (
|
||||
|
||||
@@ -56,7 +56,7 @@ export default function SearchCard({
|
||||
|
||||
return (
|
||||
<Interactive.Stateless onClick={handleClick} prominence="secondary">
|
||||
<Interactive.Container heightVariant="fit">
|
||||
<Interactive.Container heightVariant="fit" widthVariant="full">
|
||||
<Section alignItems="start" gap={0} padding={0.25}>
|
||||
{/* Title Row */}
|
||||
<Section
|
||||
|
||||
@@ -18,16 +18,17 @@ import { getTimeFilterDate, TimeFilter } from "@/lib/time";
|
||||
import useTags from "@/hooks/useTags";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { SvgCheck, SvgClock, SvgTag } from "@opal/icons";
|
||||
import FilterButton from "@/refresh-components/buttons/FilterButton";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import { LineItemButton } from "@opal/components";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
@@ -51,22 +52,17 @@ const TIME_FILTER_OPTIONS: { value: TimeFilter; label: string }[] = [
|
||||
{ value: "year", label: "Past year" },
|
||||
];
|
||||
|
||||
// ============================================================================
|
||||
// SearchResults Component (default export)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Component for displaying search results with source filter sidebar.
|
||||
*/
|
||||
export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
// Available tags from backend
|
||||
const { tags: availableTags } = useTags();
|
||||
const {
|
||||
state,
|
||||
searchResults: results,
|
||||
llmSelectedDocIds,
|
||||
error,
|
||||
refineSearch: onRefineSearch,
|
||||
} = useQueryController();
|
||||
|
||||
const prevErrorRef = useRef<string | null>(null);
|
||||
|
||||
// Show a toast notification when a new error occurs
|
||||
@@ -197,6 +193,15 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
|
||||
const showEmpty = !error && results.length === 0;
|
||||
|
||||
// Show a centered spinner while search is in-flight (after all hooks)
|
||||
if (state.phase === "searching") {
|
||||
return (
|
||||
<div className="flex-1 min-h-0 w-full flex items-center justify-center">
|
||||
<SimpleLoader />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex-1 min-h-0 w-full flex flex-col gap-3">
|
||||
{/* ── Top row: Filters + Result count ── */}
|
||||
@@ -226,18 +231,19 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
<Popover.Content align="start" width="md">
|
||||
<PopoverMenu>
|
||||
{TIME_FILTER_OPTIONS.map((opt) => (
|
||||
<LineItem
|
||||
<LineItemButton
|
||||
key={opt.value}
|
||||
onClick={() => {
|
||||
setTimeFilter(opt.value);
|
||||
setTimeFilterOpen(false);
|
||||
onRefineSearch(buildFilters({ time: opt.value }));
|
||||
}}
|
||||
selected={timeFilter === opt.value}
|
||||
state={timeFilter === opt.value ? "selected" : "empty"}
|
||||
icon={timeFilter === opt.value ? SvgCheck : SvgClock}
|
||||
>
|
||||
{opt.label}
|
||||
</LineItem>
|
||||
title={opt.label}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
))}
|
||||
</PopoverMenu>
|
||||
</Popover.Content>
|
||||
@@ -278,7 +284,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
t.tag_value === tag.tag_value
|
||||
);
|
||||
return (
|
||||
<LineItem
|
||||
<LineItemButton
|
||||
key={`${tag.tag_key}=${tag.tag_value}`}
|
||||
onClick={() => {
|
||||
const next = isSelected
|
||||
@@ -291,11 +297,12 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
setSelectedTags(next);
|
||||
onRefineSearch(buildFilters({ tags: next }));
|
||||
}}
|
||||
selected={isSelected}
|
||||
state={isSelected ? "selected" : "empty"}
|
||||
icon={isSelected ? SvgCheck : SvgTag}
|
||||
>
|
||||
{tag.tag_value}
|
||||
</LineItem>
|
||||
title={tag.tag_value}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</PopoverMenu>
|
||||
@@ -357,7 +364,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
<div className="flex-1 min-h-0 overflow-y-auto flex flex-col gap-4 px-1">
|
||||
<Section gap={0.25} height="fit">
|
||||
{sourcesWithMeta.map(({ source, meta, count }) => (
|
||||
<LineItem
|
||||
<LineItemButton
|
||||
key={source}
|
||||
icon={(props) => (
|
||||
<SourceIcon
|
||||
@@ -367,12 +374,15 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
|
||||
/>
|
||||
)}
|
||||
onClick={() => handleSourceToggle(source)}
|
||||
selected={selectedSources.includes(source)}
|
||||
emphasized
|
||||
state={
|
||||
selectedSources.includes(source) ? "selected" : "empty"
|
||||
}
|
||||
title={meta.displayName}
|
||||
selectVariant="select-heavy"
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={<Text text03>{count}</Text>}
|
||||
>
|
||||
{meta.displayName}
|
||||
</LineItem>
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
</div>
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
//
|
||||
// This is useful in determining what `SidebarTab` should be active, for example.
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
|
||||
@@ -66,31 +67,25 @@ export default function useAppFocus(): AppFocus {
|
||||
const pathname = usePathname();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
// Check if we're viewing a shared chat
|
||||
if (pathname.startsWith("/app/shared/")) {
|
||||
return new AppFocus("shared-chat");
|
||||
}
|
||||
|
||||
// Check if we're on the user settings page
|
||||
if (pathname.startsWith("/app/settings")) {
|
||||
return new AppFocus("user-settings");
|
||||
}
|
||||
|
||||
// Check if we're on the agents page
|
||||
if (pathname.startsWith("/app/agents")) {
|
||||
return new AppFocus("more-agents");
|
||||
}
|
||||
|
||||
// Check search params for chat, agent, or project
|
||||
const chatId = searchParams.get(SEARCH_PARAM_NAMES.CHAT_ID);
|
||||
if (chatId) return new AppFocus({ type: "chat", id: chatId });
|
||||
|
||||
const agentId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
|
||||
if (agentId) return new AppFocus({ type: "agent", id: agentId });
|
||||
|
||||
const projectId = searchParams.get(SEARCH_PARAM_NAMES.PROJECT_ID);
|
||||
if (projectId) return new AppFocus({ type: "project", id: projectId });
|
||||
|
||||
// No search params means we're on a new session
|
||||
return new AppFocus("new-session");
|
||||
// Memoize on the values that determine which AppFocus is constructed.
|
||||
// AppFocus is immutable, so same inputs → same instance.
|
||||
return useMemo(() => {
|
||||
if (pathname.startsWith("/app/shared/")) {
|
||||
return new AppFocus("shared-chat");
|
||||
}
|
||||
if (pathname.startsWith("/app/settings")) {
|
||||
return new AppFocus("user-settings");
|
||||
}
|
||||
if (pathname.startsWith("/app/agents")) {
|
||||
return new AppFocus("more-agents");
|
||||
}
|
||||
if (chatId) return new AppFocus({ type: "chat", id: chatId });
|
||||
if (agentId) return new AppFocus({ type: "agent", id: agentId });
|
||||
if (projectId) return new AppFocus({ type: "project", id: projectId });
|
||||
return new AppFocus("new-session");
|
||||
}, [pathname, chatId, agentId, projectId]);
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ function measure(el: HTMLElement): { x: number; y: number } | null {
|
||||
*/
|
||||
export default function useContainerCenter(): ContainerCenter {
|
||||
const pathname = usePathname();
|
||||
const { isSmallScreen } = useScreenSize();
|
||||
const { isMediumScreen } = useScreenSize();
|
||||
const [center, setCenter] = useState<{ x: number | null; y: number | null }>(
|
||||
() => {
|
||||
if (typeof document === "undefined") return NULL_CENTER;
|
||||
@@ -68,9 +68,9 @@ export default function useContainerCenter(): ContainerCenter {
|
||||
}, [pathname]);
|
||||
|
||||
return {
|
||||
centerX: isSmallScreen ? null : center.x,
|
||||
centerY: isSmallScreen ? null : center.y,
|
||||
hasContainerCenter: isSmallScreen
|
||||
centerX: isMediumScreen ? null : center.x,
|
||||
centerY: isMediumScreen ? null : center.y,
|
||||
hasContainerCenter: isMediumScreen
|
||||
? false
|
||||
: center.x !== null && center.y !== null,
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import {
|
||||
DESKTOP_SMALL_BREAKPOINT_PX,
|
||||
DESKTOP_MEDIUM_BREAKPOINT_PX,
|
||||
MOBILE_SIDEBAR_BREAKPOINT_PX,
|
||||
} from "@/lib/constants";
|
||||
import { useState, useCallback } from "react";
|
||||
@@ -12,6 +13,7 @@ export interface ScreenSize {
|
||||
width: number;
|
||||
isMobile: boolean;
|
||||
isSmallScreen: boolean;
|
||||
isMediumScreen: boolean;
|
||||
}
|
||||
|
||||
export default function useScreenSize(): ScreenSize {
|
||||
@@ -34,11 +36,13 @@ export default function useScreenSize(): ScreenSize {
|
||||
|
||||
const isMobile = sizes.width <= MOBILE_SIDEBAR_BREAKPOINT_PX;
|
||||
const isSmall = sizes.width <= DESKTOP_SMALL_BREAKPOINT_PX;
|
||||
const isMedium = sizes.width <= DESKTOP_MEDIUM_BREAKPOINT_PX;
|
||||
|
||||
return {
|
||||
height: sizes.height,
|
||||
width: sizes.width,
|
||||
isMobile: isMounted && isMobile,
|
||||
isSmallScreen: isMounted && isSmall,
|
||||
isMediumScreen: isMounted && isMedium,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ import {
|
||||
} from "@opal/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
|
||||
import type { AppMode } from "@/providers/QueryControllerProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
@@ -82,7 +82,7 @@ import useBrowserInfo from "@/hooks/useBrowserInfo";
|
||||
*/
|
||||
function Header() {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const { state, setAppMode } = useQueryController();
|
||||
const settings = useSettingsContext();
|
||||
const { isMobile } = useScreenSize();
|
||||
const { setFolded } = useAppSidebarContext();
|
||||
@@ -108,7 +108,6 @@ function Header() {
|
||||
useChatSessions();
|
||||
const router = useRouter();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
|
||||
const customHeaderContent =
|
||||
settings?.enterpriseSettings?.custom_header_content;
|
||||
@@ -117,7 +116,8 @@ function Header() {
|
||||
// without this content still use.
|
||||
const pageWithHeaderContent = appFocus.isChat() || appFocus.isNewSession();
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
const effectiveMode: AppMode =
|
||||
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
|
||||
|
||||
const availableProjects = useMemo(() => {
|
||||
if (!projects) return [];
|
||||
@@ -323,7 +323,7 @@ function Header() {
|
||||
{isPaidEnterpriseFeaturesEnabled &&
|
||||
settings.isSearchModeAvailable &&
|
||||
appFocus.isNewSession() &&
|
||||
!classification && (
|
||||
state.phase === "idle" && (
|
||||
<Popover open={modePopoverOpen} onOpenChange={setModePopoverOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<OpenButton
|
||||
|
||||
@@ -123,6 +123,7 @@ export const MAX_FILES_TO_SHOW = 3;
|
||||
// SIZES
|
||||
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
|
||||
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
|
||||
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
|
||||
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
|
||||
export const HORIZON_DISTANCE_PX = 800;
|
||||
export const LOGO_FOLDED_SIZE_PX = 24;
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { createContext, useContext } from "react";
|
||||
import { eeGated } from "@/ce";
|
||||
import { AppModeProvider as EEAppModeProvider } from "@/ee/providers/AppModeProvider";
|
||||
|
||||
export type AppMode = "auto" | "search" | "chat";
|
||||
|
||||
interface AppModeContextValue {
|
||||
appMode: AppMode;
|
||||
setAppMode: (mode: AppMode) => void;
|
||||
}
|
||||
|
||||
export const AppModeContext = createContext<AppModeContextValue>({
|
||||
appMode: "chat",
|
||||
setAppMode: () => undefined,
|
||||
});
|
||||
|
||||
export function useAppMode(): AppModeContextValue {
|
||||
return useContext(AppModeContext);
|
||||
}
|
||||
|
||||
export const AppModeProvider = eeGated(EEAppModeProvider);
|
||||
@@ -24,7 +24,7 @@
|
||||
* 4. **ProviderContextProvider** - LLM provider configuration
|
||||
* 5. **ModalProvider** - Global modal state management
|
||||
* 6. **AppSidebarProvider** - Sidebar open/closed state
|
||||
* 7. **AppModeProvider** - Search/Chat mode selection
|
||||
* 7. **QueryControllerProvider** - Search/Chat mode + query lifecycle
|
||||
*
|
||||
* ## Usage
|
||||
*
|
||||
@@ -40,7 +40,7 @@
|
||||
* - `useSettingsContext()` - from SettingsProvider
|
||||
* - `useUser()` - from UserProvider
|
||||
* - `useAppBackground()` - from AppBackgroundProvider
|
||||
* - `useAppMode()` - from AppModeProvider
|
||||
* - `useQueryController()` - from QueryControllerProvider (includes appMode)
|
||||
* - etc.
|
||||
*
|
||||
* @TODO(@raunakab): The providers wrapped by this component are currently
|
||||
@@ -65,7 +65,6 @@ import { User } from "@/lib/types";
|
||||
import { ModalProvider } from "@/components/context/ModalContext";
|
||||
import { AuthTypeMetadata } from "@/lib/userSS";
|
||||
import { AppSidebarProvider } from "@/providers/AppSidebarProvider";
|
||||
import { AppModeProvider } from "@/providers/AppModeProvider";
|
||||
import { AppBackgroundProvider } from "@/providers/AppBackgroundProvider";
|
||||
import { QueryControllerProvider } from "@/providers/QueryControllerProvider";
|
||||
import ToastProvider from "@/providers/ToastProvider";
|
||||
@@ -96,11 +95,9 @@ export default function AppProvider({
|
||||
<ProviderContextProvider>
|
||||
<ModalProvider user={user}>
|
||||
<AppSidebarProvider folded={!!folded}>
|
||||
<AppModeProvider>
|
||||
<QueryControllerProvider>
|
||||
<ToastProvider>{children}</ToastProvider>
|
||||
</QueryControllerProvider>
|
||||
</AppModeProvider>
|
||||
<QueryControllerProvider>
|
||||
<ToastProvider>{children}</ToastProvider>
|
||||
</QueryControllerProvider>
|
||||
</AppSidebarProvider>
|
||||
</ModalProvider>
|
||||
</ProviderContextProvider>
|
||||
|
||||
@@ -5,13 +5,20 @@ import { eeGated } from "@/ce";
|
||||
import { QueryControllerProvider as EEQueryControllerProvider } from "@/ee/providers/QueryControllerProvider";
|
||||
import { SearchDocWithContent, BaseFilters } from "@/lib/search/interfaces";
|
||||
|
||||
export type QueryClassification = "search" | "chat" | null;
|
||||
export type AppMode = "auto" | "search" | "chat";
|
||||
|
||||
export type QueryState =
|
||||
| { phase: "idle"; appMode: AppMode }
|
||||
| { phase: "classifying" }
|
||||
| { phase: "searching" }
|
||||
| { phase: "search-results" }
|
||||
| { phase: "chat" };
|
||||
|
||||
export interface QueryControllerValue {
|
||||
/** Classification state: null (idle), "search", or "chat" */
|
||||
classification: QueryClassification;
|
||||
/** Whether or not the currently submitted query is being actively classified by the backend */
|
||||
isClassifying: boolean;
|
||||
/** Single state variable encoding both the query lifecycle phase and (when idle) the user's mode selection. */
|
||||
state: QueryState;
|
||||
/** Update the app mode. Only takes effect when idle. No-op in CE or when search is unavailable. */
|
||||
setAppMode: (mode: AppMode) => void;
|
||||
/** Search results (empty if chat or not yet searched) */
|
||||
searchResults: SearchDocWithContent[];
|
||||
/** Document IDs selected by the LLM as most relevant */
|
||||
@@ -31,8 +38,8 @@ export interface QueryControllerValue {
|
||||
}
|
||||
|
||||
export const QueryControllerContext = createContext<QueryControllerValue>({
|
||||
classification: null,
|
||||
isClassifying: false,
|
||||
state: { phase: "idle", appMode: "chat" },
|
||||
setAppMode: () => undefined,
|
||||
searchResults: [],
|
||||
llmSelectedDocIds: null,
|
||||
error: null,
|
||||
|
||||
@@ -2,6 +2,8 @@ import React from "react";
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import ButtonRenaming from "./ButtonRenaming";
|
||||
|
||||
const noop = () => {};
|
||||
|
||||
const meta: Meta<typeof ButtonRenaming> = {
|
||||
title: "refresh-components/buttons/ButtonRenaming",
|
||||
component: ButtonRenaming,
|
||||
@@ -28,35 +30,23 @@ type Story = StoryObj<typeof ButtonRenaming>;
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
initialName: "My Chat Session",
|
||||
onRename: async (name: string) => {
|
||||
console.log("Renamed to:", name);
|
||||
},
|
||||
onClose: () => {
|
||||
console.log("Closed");
|
||||
},
|
||||
onRename: async () => {},
|
||||
onClose: noop,
|
||||
},
|
||||
};
|
||||
|
||||
export const EmptyName: Story = {
|
||||
args: {
|
||||
initialName: null,
|
||||
onRename: async (name: string) => {
|
||||
console.log("Renamed to:", name);
|
||||
},
|
||||
onClose: () => {
|
||||
console.log("Closed");
|
||||
},
|
||||
onRename: async () => {},
|
||||
onClose: noop,
|
||||
},
|
||||
};
|
||||
|
||||
export const LongName: Story = {
|
||||
args: {
|
||||
initialName: "This is a very long chat session name that should overflow",
|
||||
onRename: async (name: string) => {
|
||||
console.log("Renamed to:", name);
|
||||
},
|
||||
onClose: () => {
|
||||
console.log("Closed");
|
||||
},
|
||||
onRename: async () => {},
|
||||
onClose: noop,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -72,7 +72,6 @@ import { eeGated } from "@/ce";
|
||||
import EESearchUI from "@/ee/sections/SearchUI";
|
||||
const SearchUI = eeGated(EESearchUI);
|
||||
import { motion, AnimatePresence } from "motion/react";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
|
||||
interface FadeProps {
|
||||
show: boolean;
|
||||
@@ -129,7 +128,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
type: "success",
|
||||
},
|
||||
});
|
||||
const { setAppMode } = useAppMode();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
// Use SWR hooks for data fetching
|
||||
@@ -485,7 +483,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
finishOnboarding,
|
||||
]
|
||||
);
|
||||
const { submit: submitQuery, classification } = useQueryController();
|
||||
const { submit: submitQuery, state, setAppMode } = useQueryController();
|
||||
|
||||
const defaultAppMode =
|
||||
(user?.preferences?.default_app_mode?.toLowerCase() as "chat" | "search") ??
|
||||
@@ -493,12 +491,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const isNewSession = appFocus.isNewSession();
|
||||
|
||||
const isSearch =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
|
||||
// 1. Reset the app-mode back to the user's default when navigating back to the "New Sessions" tab.
|
||||
// 2. If we're navigating away from the "New Session" tab after performing a search, we reset the app-input-bar.
|
||||
useEffect(() => {
|
||||
if (isNewSession) setAppMode(defaultAppMode);
|
||||
if (!isNewSession && classification === "search") resetInputBar();
|
||||
}, [isNewSession, defaultAppMode, classification, resetInputBar, setAppMode]);
|
||||
if (!isNewSession && isSearch) resetInputBar();
|
||||
}, [isNewSession, defaultAppMode, isSearch, resetInputBar, setAppMode]);
|
||||
|
||||
const handleSearchDocumentClick = useCallback(
|
||||
(doc: MinimalOnyxDocument) => setPresentingDocument(doc),
|
||||
@@ -607,7 +608,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const hasStarterMessages = (liveAgent?.starter_messages?.length ?? 0) > 0;
|
||||
|
||||
const isSearch = classification === "search";
|
||||
const gridStyle = {
|
||||
gridTemplateColumns: "1fr",
|
||||
gridTemplateRows: isSearch
|
||||
@@ -735,7 +735,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
<Fade
|
||||
show={
|
||||
(appFocus.isNewSession() || appFocus.isAgent()) &&
|
||||
!classification
|
||||
(state.phase === "idle" || state.phase === "classifying")
|
||||
}
|
||||
className="w-full flex-1 flex flex-col items-center justify-end"
|
||||
>
|
||||
@@ -764,7 +764,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
{/* OnboardingUI */}
|
||||
{(appFocus.isNewSession() || appFocus.isAgent()) &&
|
||||
!classification &&
|
||||
(state.phase === "idle" ||
|
||||
state.phase === "classifying") &&
|
||||
(showOnboarding || !user?.personalization?.name) &&
|
||||
!onboardingDismissed && (
|
||||
<OnboardingFlow
|
||||
@@ -799,7 +800,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
<div
|
||||
className={cn(
|
||||
"transition-all duration-150 ease-in-out overflow-hidden",
|
||||
classification === "search" ? "h-[14px]" : "h-0"
|
||||
isSearch ? "h-[14px]" : "h-0"
|
||||
)}
|
||||
/>
|
||||
<AppInputBar
|
||||
|
||||
@@ -19,7 +19,6 @@ import useCCPairs from "@/hooks/useCCPairs";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { cn, isImageFile } from "@/lib/utils";
|
||||
import { Disabled } from "@opal/core";
|
||||
@@ -120,7 +119,10 @@ const AppInputBar = React.memo(
|
||||
const filesContentRef = useRef<HTMLDivElement>(null);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const { user } = useUser();
|
||||
const { isClassifying, classification } = useQueryController();
|
||||
const { state } = useQueryController();
|
||||
const isClassifying = state.phase === "classifying";
|
||||
const isSearchActive =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
|
||||
// Expose reset and focus methods to parent via ref
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
@@ -140,12 +142,10 @@ const AppInputBar = React.memo(
|
||||
setMessage(initialMessage);
|
||||
}
|
||||
}, [initialMessage]);
|
||||
|
||||
const { appMode } = useAppMode();
|
||||
const appFocus = useAppFocus();
|
||||
const appMode = state.phase === "idle" ? state.appMode : undefined;
|
||||
const isSearchMode =
|
||||
(appFocus.isNewSession() && appMode === "search") ||
|
||||
classification === "search";
|
||||
(appFocus.isNewSession() && appMode === "search") || isSearchActive;
|
||||
|
||||
const { forcedToolIds, setForcedToolIds } = useForcedTools();
|
||||
const { currentMessageFiles, setCurrentMessageFiles, currentProjectId } =
|
||||
|
||||
@@ -77,7 +77,6 @@ import { Notification, NotificationType } from "@/interfaces/settings";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
|
||||
import ChatSearchCommandMenu from "@/sections/sidebar/ChatSearchCommandMenu";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
|
||||
// Visible-agents = pinned-agents + current-agent (if current-agent not in pinned-agents)
|
||||
@@ -206,8 +205,7 @@ const MemoizedAppSidebarInner = memo(
|
||||
const combinedSettings = useSettingsContext();
|
||||
const posthog = usePostHog();
|
||||
const { newTenantInfo, invitationInfo } = useModalContext();
|
||||
const { setAppMode } = useAppMode();
|
||||
const { reset } = useQueryController();
|
||||
const { setAppMode, reset } = useQueryController();
|
||||
|
||||
// Use SWR hooks for data fetching
|
||||
const {
|
||||
|
||||
Reference in New Issue
Block a user