mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-28 21:25:44 +00:00
Compare commits
12 Commits
test-tests
...
v2.11.0-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
144030c5ca | ||
|
|
a557d76041 | ||
|
|
605e808158 | ||
|
|
8fec88c90d | ||
|
|
e54969a693 | ||
|
|
1da2b2f28f | ||
|
|
eb7b91e08e | ||
|
|
3339000968 | ||
|
|
d9db849e94 | ||
|
|
046408359c | ||
|
|
4b8cca190f | ||
|
|
52a312a63b |
39
.vscode/launch.json
vendored
39
.vscode/launch.json
vendored
@@ -149,6 +149,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
@@ -587,6 +605,27 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
|
||||
@@ -97,10 +97,14 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
NOTE: Whatever formatting strategy is used here to generate a key-value
|
||||
string must be replicated when constructing query filters.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
# NOTE: These strings must be formatted in the same way as the output of
|
||||
# DocumentAccess::to_acl.
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ of "minimum value clipping".
|
||||
## On time decay and boosting
|
||||
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
|
||||
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
|
||||
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
|
||||
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
|
||||
Same logic applies to additive boosting.
|
||||
|
||||
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
|
||||
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
|
||||
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
|
||||
being fetched and returned to the user. But there are other issues of including these:
|
||||
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
|
||||
contents. If there are lots of updates, this may miss
|
||||
contents. If there are lots of updates, this may miss.
|
||||
- There is not a good way to normalize this field, the best is to clip it on the bottom.
|
||||
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
|
||||
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@@ -68,6 +70,18 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def generate_opensearch_filtered_access_control_list(
|
||||
access: DocumentAccess,
|
||||
) -> list[str]:
|
||||
"""Generates an access control list with PUBLIC_DOC_PAT removed.
|
||||
|
||||
In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
|
||||
"""
|
||||
access_control_list = access.to_acl()
|
||||
access_control_list.discard(PUBLIC_DOC_PAT)
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -152,10 +166,9 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
metadata_suffix=chunk.metadata_suffix_keyword,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
chunk.access
|
||||
),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
image_file_id=chunk.image_file_id,
|
||||
@@ -578,8 +591,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here so we don't have to think about passing in the
|
||||
# appropriate types into this dict.
|
||||
if update_request.access is not None:
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
|
||||
update_request.access.to_acl()
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
|
||||
generate_opensearch_filtered_access_control_list(
|
||||
update_request.access
|
||||
)
|
||||
)
|
||||
if update_request.document_sets is not None:
|
||||
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
|
||||
@@ -625,13 +640,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
# TODO(andrei): Remove this from the new interface at some point; we
|
||||
# should not be exposing this.
|
||||
batch_retrieval: bool = False,
|
||||
# TODO(andrei): Add a param for whether to retrieve hidden docs.
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
TODO(andrei): Consider implementing this method to retrieve on document
|
||||
@@ -646,6 +659,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
max_chunk_size=chunk_request.max_chunk_size,
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
@@ -672,9 +687,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
@@ -688,6 +700,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_candidates=1000, # TODO(andrei): Magic number.
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
|
||||
@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
def serialize_datetime_fields_to_epoch_seconds(
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
Serializes datetime fields to seconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
return int(value.timestamp())
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses seconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
|
||||
@@ -354,11 +353,9 @@ class DocumentSchema:
|
||||
},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
"format": "epoch_second",
|
||||
# For some reason date defaults to False, even though it
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
@@ -366,14 +363,21 @@ class DocumentSchema:
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
# is its own field.
|
||||
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
# should have no effect on queries.
|
||||
PUBLIC_FIELD_NAME: {"type": "boolean"},
|
||||
# Access control list for the doc, excluding public access,
|
||||
# which is covered above.
|
||||
# If a user's access set contains at least one entry from this
|
||||
# set, the user should be able to retrieve this document. This
|
||||
# only applies if public is set to false; public non-hidden
|
||||
# documents are always visible to anyone in a given tenancy
|
||||
# regardless of this field.
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# Whether the doc is hidden from search results. Should clobber
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
# Whether the doc is hidden from search results.
|
||||
# Should clobber all other access search filters, namely
|
||||
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
|
||||
# search implementations to guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
@@ -447,7 +451,6 @@ class DocumentSchema:
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
@@ -91,6 +106,11 @@ assert (
|
||||
# given search. This value is configurable in the index settings.
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
|
||||
|
||||
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
|
||||
# that the document was last updated this many days ago for the purpose of time
|
||||
# cutoff filtering during retrieval.
|
||||
ASSUMED_DOCUMENT_AGE_DAYS = 90
|
||||
|
||||
|
||||
class DocumentQuery:
|
||||
"""
|
||||
@@ -103,6 +123,8 @@ class DocumentQuery:
|
||||
def get_from_document_id_query(
|
||||
document_id: str,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
max_chunk_size: int,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -120,6 +142,8 @@ class DocumentQuery:
|
||||
document_id: Onyx document ID. Notably not an OpenSearch document
|
||||
ID, which points to what Onyx would refer to as a chunk.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the document retrieval query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
max_chunk_size: Document chunks are categorized by the maximum
|
||||
number of tokens they can hold. This parameter specifies the
|
||||
maximum size category of document chunks to retrieve.
|
||||
@@ -136,28 +160,21 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final ID search query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
filter_clauses.append(range_clause)
|
||||
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
max_chunk_size=max_chunk_size,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
final_get_ids_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
# We include this to make sure OpenSearch does not revert to
|
||||
@@ -195,15 +212,22 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final delete query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
# Delete hidden docs too.
|
||||
include_hidden=True,
|
||||
access_control_list=None,
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=document_id,
|
||||
)
|
||||
final_delete_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
}
|
||||
@@ -217,19 +241,25 @@ class DocumentQuery:
|
||||
num_candidates: int,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final hybrid search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
NOTE: This query can be directly supplied to the OpenSearch client, but
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
num_candidates: The number of candidates to consider for vector
|
||||
num_candidates: The number of neighbors to consider for vector
|
||||
similarity search. Generally more candidates improves search
|
||||
quality at the cost of performance.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the hybrid search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final hybrid search query.
|
||||
@@ -243,31 +273,47 @@ class DocumentQuery:
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, num_candidates
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
}
|
||||
}
|
||||
],
|
||||
# TODO(andrei): When revisiting our hybrid query logic see if
|
||||
# this needs to be nested one level down.
|
||||
"filter": hybrid_search_filters,
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Applied to all the sub-queries. Source:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# Does AND for each filter in the list.
|
||||
"filter": {"bool": {"filter": hybrid_search_filters}},
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: By default, hybrid search retrieves "size"-many results from
|
||||
# each OpenSearch shard before aggregation. Source:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
}
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
@@ -294,7 +340,8 @@ class DocumentQuery:
|
||||
pipeline.
|
||||
|
||||
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
|
||||
in a single hybrid query.
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -305,6 +352,7 @@ class DocumentQuery:
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the title.
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -313,6 +361,7 @@ class DocumentQuery:
|
||||
},
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the content.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -322,36 +371,273 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
# Either fuzzy match on the analyzed title (boosted 2x), or
|
||||
# exact match on exact title keywords (no OpenSearch
|
||||
# analysis done on the title). See
|
||||
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
# Returns the score of the best match of the fields above.
|
||||
# See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
|
||||
"type": "best_fields",
|
||||
}
|
||||
},
|
||||
# Fuzzy match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
|
||||
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
|
||||
# Exact match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
|
||||
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
|
||||
]
|
||||
|
||||
return hybrid_search_queries
|
||||
|
||||
@staticmethod
|
||||
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
|
||||
"""Returns filters for hybrid search.
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
include_hidden: bool,
|
||||
access_control_list: list[str] | None,
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
max_chunk_size: int | None = None,
|
||||
document_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns filters to be passed into the "filter" key of a search query.
|
||||
|
||||
For now only fetches public and not hidden documents.
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
|
||||
The return of this function is not sufficient to be directly supplied to
|
||||
the OpenSearch client. See get_hybrid_search_query.
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
access_control_list: Access control list for the documents to
|
||||
retrieve. If None, there is no restriction on the documents that
|
||||
can be retrieved. If not None, only public documents can be
|
||||
retrieved, or non-public documents where at least one acl
|
||||
provided here is present in the document's acl list.
|
||||
source_types: If supplied, only documents of one of these source
|
||||
types will be retrieved.
|
||||
tags: If supplied, only documents with an entry in their metadata
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
updated time, we assume some default age of
|
||||
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
|
||||
updated.
|
||||
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
|
||||
None, no minimum chunk index will be applied.
|
||||
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
|
||||
None, no maximum chunk index will be applied.
|
||||
max_chunk_size: The type of chunk to retrieve, specified by the
|
||||
maximum number of tokens it can hold. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
|
||||
TODO(andrei): Add ACL filters and stuff.
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
"""
|
||||
hybrid_search_filters: list[dict[str, Any]] = [
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
)
|
||||
return tag_filter
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
|
||||
}
|
||||
}
|
||||
)
|
||||
if time_cutoff < datetime.now(timezone.utc) - timedelta(
|
||||
days=ASSUMED_DOCUMENT_AGE_DAYS
|
||||
):
|
||||
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
|
||||
# ago, we include documents which have no
|
||||
# LAST_UPDATED_FIELD_NAME value.
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
|
||||
}
|
||||
}
|
||||
)
|
||||
return time_cutoff_filter
|
||||
|
||||
def _get_chunk_index_filter(
|
||||
min_chunk_index: int | None, max_chunk_index: int | None
|
||||
) -> dict[str, Any]:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
return range_clause
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
|
||||
|
||||
if access_control_list is not None:
|
||||
# If an access control list is provided, the caller can only
|
||||
# retrieve public documents, and non-public documents where at least
|
||||
# one acl provided here is present in the document's acl list. If
|
||||
# there is explicitly no list provided, we make no restrictions on
|
||||
# the documents that can be retrieved.
|
||||
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
|
||||
|
||||
if source_types:
|
||||
# If at least one source type is provided, the caller will only
|
||||
# retrieve documents whose source type is present in this input
|
||||
# list.
|
||||
filter_clauses.append(_get_source_type_filter(source_types))
|
||||
|
||||
if tags:
|
||||
# If at least one tag is provided, the caller will only retrieve
|
||||
# documents where at least one tag provided here is present in the
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
if document_sets:
|
||||
# If at least one document set is provided, the caller will only
|
||||
# retrieve documents where at least one document set provided here
|
||||
# is present in the document's document sets list.
|
||||
filter_clauses.append(_get_document_set_filter(document_sets))
|
||||
|
||||
if user_file_ids:
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs. Note that these IDs correspond to Onyx documents whereas
|
||||
# the entries retrieved from the document index correspond to Onyx
|
||||
# document chunks.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
# cutoff. For documents which do not have a value for
|
||||
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
|
||||
# purposes of time cutoff.
|
||||
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
filter_clauses.append(
|
||||
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
if max_chunk_size is not None:
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
)
|
||||
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
return hybrid_search_filters
|
||||
|
||||
return filter_clauses
|
||||
|
||||
@staticmethod
|
||||
def _get_match_highlights_configuration() -> dict[str, Any]:
|
||||
@@ -378,4 +664,5 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return match_highlights_configuration
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# Discord Bot Multitenant Architecture
|
||||
|
||||
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
|
||||
|
||||
## Overview
|
||||
|
||||
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
|
||||
|
||||
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
|
||||
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ OnyxDiscordClient │
|
||||
│ │
|
||||
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
|
||||
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
|
||||
│ │ tenant_id → api_key │ │ message, │ │
|
||||
│ │ │ │ api_key=<per-tenant>, │ │
|
||||
│ └─────────────────────────┘ │ persona_id=... │ │
|
||||
│ │ ) │ │
|
||||
│ └─────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Component Details
|
||||
|
||||
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
|
||||
|
||||
The `DiscordCacheManager` maintains two critical in-memory mappings:
|
||||
|
||||
```python
|
||||
class DiscordCacheManager:
|
||||
_guild_tenants: dict[int, str] # guild_id → tenant_id
|
||||
_api_keys: dict[str, str] # tenant_id → api_key
|
||||
_lock: asyncio.Lock # Concurrency control
|
||||
```
|
||||
|
||||
#### Key Responsibilities
|
||||
|
||||
| Function | Purpose |
|
||||
|----------|---------|
|
||||
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
|
||||
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
|
||||
| `refresh_all()` | Full cache rebuild from database |
|
||||
| `refresh_guild()` | Incremental update for single guild |
|
||||
|
||||
#### API Key Provisioning Strategy
|
||||
|
||||
API keys are **lazily provisioned** - only created when first needed:
|
||||
|
||||
```python
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
needs_key = tenant_id not in self._api_keys
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# Load guild configs
|
||||
configs = get_discord_bot_configs(db)
|
||||
guild_ids = [c.guild_id for c in configs if c.enabled]
|
||||
|
||||
# Only provision API key if not already cached
|
||||
api_key = None
|
||||
if needs_key:
|
||||
api_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
|
||||
return guild_ids, api_key
|
||||
```
|
||||
|
||||
This optimization avoids repeated database calls for API key generation.
|
||||
|
||||
#### Concurrency Control
|
||||
|
||||
All write operations acquire an async lock to prevent race conditions:
|
||||
|
||||
```python
|
||||
async def refresh_all(self) -> None:
|
||||
async with self._lock:
|
||||
# Safe to modify _guild_tenants and _api_keys
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
# Update mappings...
|
||||
```
|
||||
|
||||
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
|
||||
|
||||
---
|
||||
|
||||
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
|
||||
|
||||
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
|
||||
|
||||
#### Key Design: Per-Request API Key Injection
|
||||
|
||||
```python
|
||||
class OnyxAPIClient:
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str, # Injected per-request
|
||||
persona_id: int | None,
|
||||
...
|
||||
) -> ChatFullResponse:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
|
||||
}
|
||||
# Make request...
|
||||
```
|
||||
|
||||
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
|
||||
|
||||
```python
|
||||
# Same client, different tenants
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coordination Flow
|
||||
|
||||
### Message Processing Pipeline
|
||||
|
||||
When a Discord message arrives, the client coordinates cache and API client:
|
||||
|
||||
```python
|
||||
async def on_message(self, message: Message) -> None:
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Step 1: Cache lookup - guild → tenant
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
if not tenant_id:
|
||||
return # Guild not registered
|
||||
|
||||
# Step 2: Cache lookup - tenant → API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Step 3: API call with tenant-specific credentials
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key, # Tenant-specific
|
||||
persona_id=persona_id, # Tenant-specific
|
||||
api_client=self.api_client,
|
||||
)
|
||||
```
|
||||
|
||||
### Startup Sequence
|
||||
|
||||
```python
|
||||
async def setup_hook(self) -> None:
|
||||
# 1. Initialize API client (create aiohttp session)
|
||||
await self.api_client.initialize()
|
||||
|
||||
# 2. Populate cache with all tenants
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# 3. Start background refresh task
|
||||
self._cache_refresh_task = self.loop.create_task(
|
||||
self._periodic_cache_refresh() # Every 60 seconds
|
||||
)
|
||||
```
|
||||
|
||||
### Shutdown Sequence
|
||||
|
||||
```python
|
||||
async def close(self) -> None:
|
||||
# 1. Cancel background refresh
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
|
||||
# 2. Close Discord connection
|
||||
await super().close()
|
||||
|
||||
# 3. Close API client session
|
||||
await self.api_client.close()
|
||||
|
||||
# 4. Clear cache
|
||||
self.cache.clear()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tenant Isolation Mechanisms
|
||||
|
||||
### 1. Per-Tenant API Keys
|
||||
|
||||
Each tenant has a dedicated service API key:
|
||||
|
||||
```python
|
||||
# backend/onyx/db/discord_bot.py
|
||||
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
return regenerate_key(existing)
|
||||
|
||||
# Create LIMITED role key (chat-only permissions)
|
||||
return insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Minimal permissions
|
||||
),
|
||||
user_id=None, # Service account (system-owned)
|
||||
).api_key
|
||||
```
|
||||
|
||||
### 2. Database Context Variables
|
||||
|
||||
The cache uses context variables for proper tenant-scoped DB sessions:
|
||||
|
||||
```python
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# All DB operations scoped to this tenant
|
||||
...
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
```
|
||||
|
||||
### 3. Enterprise Gating Support
|
||||
|
||||
Gated tenants are filtered during cache refresh:
|
||||
|
||||
```python
|
||||
gated_tenants = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
if tenant_id in gated_tenants:
|
||||
continue # Skip gated tenants
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Refresh Strategy
|
||||
|
||||
| Trigger | Method | Scope |
|
||||
|---------|--------|-------|
|
||||
| Startup | `refresh_all()` | All tenants |
|
||||
| Periodic (60s) | `refresh_all()` | All tenants |
|
||||
| Guild registration | `refresh_guild()` | Single tenant |
|
||||
|
||||
### Error Handling
|
||||
|
||||
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
|
||||
- **Missing API key**: Bot silently ignores messages from that guild
|
||||
- **Network errors**: Logged, cache continues with stale data until next refresh
|
||||
|
||||
---
|
||||
|
||||
## Key Design Insights
|
||||
|
||||
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
|
||||
|
||||
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
|
||||
|
||||
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
|
||||
|
||||
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
|
||||
|
||||
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
|
||||
|
||||
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
|
||||
|
||||
---
|
||||
|
||||
## File References
|
||||
|
||||
| Component | Path |
|
||||
|-----------|------|
|
||||
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
|
||||
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
|
||||
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
|
||||
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
|
||||
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
|
||||
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |
|
||||
@@ -580,7 +580,7 @@ def translate_assistant_message_to_packets(
|
||||
# Determine stop reason - check if message indicates user cancelled
|
||||
stop_reason: str | None = None
|
||||
if chat_message.message:
|
||||
if "Generation was stopped" in chat_message.message:
|
||||
if "generation was stopped" in chat_message.message.lower():
|
||||
stop_reason = "user_cancelled"
|
||||
|
||||
# Add overall stop packet at the end
|
||||
|
||||
@@ -191,6 +191,18 @@ autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Listens for Discord messages and responds with answers
|
||||
# for all guilds/channels that the OnyxBot has been added to.
|
||||
# If not configured, will continue to probe every 3 minutes for a Discord bot token.
|
||||
[program:discord_bot]
|
||||
command=python onyx/onyxbot/discord/client.py
|
||||
stdout_logfile=/var/log/discord_bot.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Pushes all logs from the above programs to stdout
|
||||
# No log rotation here, since it's stdout it's handled by the Docker container logging
|
||||
[program:log-redirect-handler]
|
||||
@@ -206,6 +218,7 @@ command=tail -qF
|
||||
/var/log/celery_worker_user_file_processing.log
|
||||
/var/log/celery_worker_docfetching.log
|
||||
/var/log/slack_bot.log
|
||||
/var/log/discord_bot.log
|
||||
/var/log/supervisord_watchdog_celery_beat.log
|
||||
/var/log/mcp_server.log
|
||||
/var/log/mcp_server.err.log
|
||||
|
||||
@@ -8,14 +8,22 @@ import re
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
generate_opensearch_filtered_access_control_list,
|
||||
)
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
@@ -42,14 +50,22 @@ def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) ->
|
||||
|
||||
def _create_test_document_chunk(
|
||||
document_id: str,
|
||||
chunk_index: int,
|
||||
content: str,
|
||||
tenant_state: TenantState,
|
||||
chunk_index: int = 0,
|
||||
content_vector: list[float] | None = None,
|
||||
title: str | None = None,
|
||||
title_vector: list[float] | None = None,
|
||||
public: bool = True,
|
||||
hidden: bool = False,
|
||||
document_access: DocumentAccess = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
),
|
||||
source_type: DocumentSource = DocumentSource.FILE,
|
||||
last_updated: datetime | None = None,
|
||||
) -> DocumentChunk:
|
||||
if content_vector is None:
|
||||
# Generate dummy vector - 128 dimensions for fast testing.
|
||||
@@ -59,11 +75,6 @@ def _create_test_document_chunk(
|
||||
if title is not None and title_vector is None:
|
||||
title_vector = [0.2] * 128
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# We only store millisecond precision, so to make sure asserts work in this
|
||||
# test file manually lose some precision from datetime.now().
|
||||
now = now.replace(microsecond=(now.microsecond // 1000) * 1000)
|
||||
|
||||
return DocumentChunk(
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
@@ -71,11 +82,13 @@ def _create_test_document_chunk(
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type="test_source",
|
||||
source_type=source_type.value,
|
||||
metadata_list=None,
|
||||
last_updated=now,
|
||||
public=public,
|
||||
access_control_list=[],
|
||||
last_updated=last_updated,
|
||||
public=document_access.is_public,
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
document_access
|
||||
),
|
||||
hidden=hidden,
|
||||
global_boost=0,
|
||||
semantic_identifier="Test semantic identifier",
|
||||
@@ -331,6 +344,9 @@ class TestOpenSearchClient:
|
||||
chunk_index=0,
|
||||
content="Content to retrieve",
|
||||
tenant_state=tenant_state,
|
||||
# We only store second precision, so to make sure asserts work in
|
||||
# this test we'll deliberately lose some precision.
|
||||
last_updated=datetime.now(timezone.utc).replace(microsecond=0),
|
||||
)
|
||||
test_client.index_document(document=original_doc)
|
||||
|
||||
@@ -471,6 +487,8 @@ class TestOpenSearchClient:
|
||||
search_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="delete-me",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -483,6 +501,8 @@ class TestOpenSearchClient:
|
||||
keep_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="keep-me",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -510,7 +530,6 @@ class TestOpenSearchClient:
|
||||
chunk_index=0,
|
||||
content="Original content",
|
||||
tenant_state=tenant_state,
|
||||
public=True,
|
||||
hidden=False,
|
||||
)
|
||||
test_client.index_document(document=doc)
|
||||
@@ -561,10 +580,13 @@ class TestOpenSearchClient:
|
||||
properties_to_update={"hidden": True},
|
||||
)
|
||||
|
||||
def test_search_basic(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests basic search functionality."""
|
||||
"""Tests hybrid search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
@@ -574,24 +596,24 @@ class TestOpenSearchClient:
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index multiple documents with different content and vectors.
|
||||
# Index documents.
|
||||
docs = {
|
||||
"search-doc-1": _create_test_document_chunk(
|
||||
document_id="search-doc-1",
|
||||
"doc-1": _create_test_document_chunk(
|
||||
document_id="doc-1",
|
||||
chunk_index=0,
|
||||
content="Python programming language tutorial",
|
||||
content_vector=_generate_test_vector(0.1),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"search-doc-2": _create_test_document_chunk(
|
||||
document_id="search-doc-2",
|
||||
"doc-2": _create_test_document_chunk(
|
||||
document_id="doc-2",
|
||||
chunk_index=0,
|
||||
content="How to make cheese",
|
||||
content_vector=_generate_test_vector(0.2),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"search-doc-3": _create_test_document_chunk(
|
||||
document_id="search-doc-3",
|
||||
"doc-3": _create_test_document_chunk(
|
||||
document_id="doc-3",
|
||||
chunk_index=0,
|
||||
content="C++ for newborns",
|
||||
content_vector=_generate_test_vector(0.15),
|
||||
@@ -613,78 +635,10 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 3
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_chunk.document_id
|
||||
in ["search-doc-1", "search-doc-2", "search-doc-3"]
|
||||
for chunk in results
|
||||
)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert chunk.score
|
||||
|
||||
# Make sure there is some kind of match highlight for the first hit. We
|
||||
# don't expect highlights for any other hit.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents.
|
||||
docs = {
|
||||
"pipeline-doc-1": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-1",
|
||||
chunk_index=0,
|
||||
content="Machine learning algorithms for single-celled organisms",
|
||||
content_vector=_generate_test_vector(0.3),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"pipeline-doc-2": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-2",
|
||||
chunk_index=0,
|
||||
content="Deep learning shallow neural networks",
|
||||
content_vector=_generate_test_vector(0.35),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search query.
|
||||
query_text = "machine learning"
|
||||
query_vector = _generate_test_vector(0.32)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text=query_text,
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
# We're not worried about filtering here. tenant_id in this object
|
||||
# is not relevant.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -693,23 +647,26 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 2
|
||||
assert len(results) == len(docs)
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_chunk.document_id in ["pipeline-doc-1", "pipeline-doc-2"]
|
||||
for chunk in results
|
||||
)
|
||||
assert all(chunk.document_chunk.document_id in docs.keys() for chunk in results)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
for i, chunk in enumerate(results):
|
||||
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert chunk.score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Make sure there is some kind of match highlight only for the first
|
||||
# result. The other results are so bad they're not expected to have
|
||||
# match highlights.
|
||||
if i == 0:
|
||||
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_empty_index(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search on an empty index returns an empty list."""
|
||||
# Precondition.
|
||||
@@ -731,19 +688,28 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
# We're not worried about filtering here. tenant_id in this object
|
||||
# is not relevant.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
results = test_client.search(
|
||||
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 0
|
||||
|
||||
def test_search_filters(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests search filters for public/hidden documents and tenant isolation.
|
||||
Tests search filters for ACL, hidden documents, and tenant isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
@@ -757,29 +723,47 @@ class TestOpenSearchClient:
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc-1": _create_test_document_chunk(
|
||||
document_id="public-doc-1",
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"hidden-doc-1": _create_test_document_chunk(
|
||||
document_id="hidden-doc-1",
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-1": _create_test_document_chunk(
|
||||
document_id="private-doc-1",
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
public=False,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
@@ -787,7 +771,6 @@ class TestOpenSearchClient:
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
@@ -798,9 +781,6 @@ class TestOpenSearchClient:
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search with default filters (public=True, hidden=False).
|
||||
# The DocumentQuery.get_hybrid_search_query uses filters that should
|
||||
# only return public, non-hidden documents.
|
||||
query_text = "document content"
|
||||
query_vector = _generate_test_vector(0.6)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
@@ -809,24 +789,41 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
results = test_client.search(
|
||||
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document.
|
||||
assert len(results) == 1
|
||||
assert results[0].document_chunk.document_id == "public-doc-1"
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# NOTE: This test is not explicitly testing for how well results are
|
||||
# ordered; we're just assuming which doc will be the first result here.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == docs["public-doc-1"]
|
||||
assert results[0].document_chunk == docs["public-doc"]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == docs["private-doc-user-a"]
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
@@ -849,52 +846,54 @@ class TestOpenSearchClient:
|
||||
# Vectors closer to query_vector (0.1) should rank higher.
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
document_id="highly-relevant-1",
|
||||
document_id="highly-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence and machine learning transform technology",
|
||||
content_vector=_generate_test_vector(
|
||||
0.1
|
||||
), # Very close to query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="somewhat-relevant-1",
|
||||
document_id="somewhat-relevant",
|
||||
chunk_index=0,
|
||||
content="Computer programming with various languages",
|
||||
content_vector=_generate_test_vector(0.5), # Far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="not-very-relevant-1",
|
||||
document_id="not-very-relevant",
|
||||
chunk_index=0,
|
||||
content="Cooking recipes for delicious meals",
|
||||
content_vector=_generate_test_vector(
|
||||
0.9
|
||||
), # Very far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
# These should be filtered out by public/hidden filters.
|
||||
_create_test_document_chunk(
|
||||
document_id="hidden-but-relevant-1",
|
||||
document_id="hidden-but-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence research papers",
|
||||
content_vector=_generate_test_vector(0.05), # Very close but hidden.
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="private-but-relevant-1",
|
||||
document_id="private-but-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence industry analysis",
|
||||
content_vector=_generate_test_vector(0.08), # Very close but private.
|
||||
public=False,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
@@ -905,7 +904,7 @@ class TestOpenSearchClient:
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search query matching "highly-relevant-1" most closely.
|
||||
# Search query matching "highly-relevant" most closely.
|
||||
query_text = "artificial intelligence"
|
||||
query_vector = _generate_test_vector(0.1)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
@@ -914,6 +913,9 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# Explicitly pass in an empty list to enforce private doc filtering.
|
||||
index_filters=IndexFilters(access_control_list=[], tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -925,15 +927,15 @@ class TestOpenSearchClient:
|
||||
# Should only get public, non-hidden documents (3 out of 5).
|
||||
assert len(results) == 3
|
||||
result_ids = [chunk.document_chunk.document_id for chunk in results]
|
||||
assert "highly-relevant-1" in result_ids
|
||||
assert "somewhat-relevant-1" in result_ids
|
||||
assert "not-very-relevant-1" in result_ids
|
||||
assert "highly-relevant" in result_ids
|
||||
assert "somewhat-relevant" in result_ids
|
||||
assert "not-very-relevant" in result_ids
|
||||
# Filtered out by public/hidden constraints.
|
||||
assert "hidden-but-relevant-1" not in result_ids
|
||||
assert "private-but-relevant-1" not in result_ids
|
||||
assert "hidden-but-relevant" not in result_ids
|
||||
assert "private-but-relevant" not in result_ids
|
||||
|
||||
# Most relevant document should be first due to normalization pipeline.
|
||||
assert results[0].document_chunk.document_id == "highly-relevant-1"
|
||||
# Most relevant document should be first.
|
||||
assert results[0].document_chunk.document_id == "highly-relevant"
|
||||
|
||||
# Make sure there is some kind of match highlight for the most relevant
|
||||
# result.
|
||||
@@ -1014,6 +1016,8 @@ class TestOpenSearchClient:
|
||||
verify_query_x = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1026,6 +1030,8 @@ class TestOpenSearchClient:
|
||||
verify_query_y = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-y",
|
||||
tenant_state=tenant_y,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1113,6 +1119,8 @@ class TestOpenSearchClient:
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-1",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1133,3 +1141,176 @@ class TestOpenSearchClient:
|
||||
for chunk in doc1_chunks
|
||||
}
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
private ones.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_state,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for all documents.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="private-doc-user-a",
|
||||
tenant_state=tenant_state,
|
||||
# This is the input under test, notice None for acl.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
chunk_ids = test_client.search_for_document_ids(body=query_body)
|
||||
|
||||
# Postcondition.
|
||||
# Even though this doc is private, because we supplied None for acl we
|
||||
# were able to retrieve it.
|
||||
assert len(chunk_ids) == 1
|
||||
# Since this is a chunk ID, it will have the doc ID in it plus other
|
||||
# stuff we don't care about in this test.
|
||||
assert chunk_ids[0].startswith("private-doc-user-a")
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests the time cutoff filter works."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index docs with various ages.
|
||||
one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
one_week_ago = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
six_months_ago = datetime.now(timezone.utc) - timedelta(days=180)
|
||||
one_year_ago = datetime.now(timezone.utc) - timedelta(days=365)
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
document_id="one-day-ago",
|
||||
content="Good match",
|
||||
last_updated=one_day_ago,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="one-year-ago",
|
||||
content="Good match",
|
||||
last_updated=one_year_ago,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="no-last-updated",
|
||||
# Since we test for result ordering in the postconditions, let's
|
||||
# just make this content slightly less of a match with the query
|
||||
# so this test is not flaky from the ordering of the results.
|
||||
content="Still an ok match",
|
||||
last_updated=None,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for documents updated in the last week.
|
||||
last_week_search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text="Good match",
|
||||
query_vector=_generate_test_vector(0.1),
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=None, tenant_id=None, time_cutoff=one_week_ago
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
last_six_months_search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text="Good match",
|
||||
query_vector=_generate_test_vector(0.1),
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=None, tenant_id=None, time_cutoff=six_months_ago
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
last_week_results = test_client.search(
|
||||
body=last_week_search_body,
|
||||
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
last_six_months_results = test_client.search(
|
||||
body=last_six_months_search_body,
|
||||
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# We expect to only get one-day-ago.
|
||||
assert len(last_week_results) == 1
|
||||
assert last_week_results[0].document_chunk.document_id == "one-day-ago"
|
||||
# We expect to get one-day-ago and no-last-updated since six months >
|
||||
# ASSUMED_DOCUMENT_AGE_DAYS.
|
||||
assert len(last_six_months_results) == 2
|
||||
assert last_six_months_results[0].document_chunk.document_id == "one-day-ago"
|
||||
assert (
|
||||
last_six_months_results[1].document_chunk.document_id == "no-last-updated"
|
||||
)
|
||||
|
||||
@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
|
||||
provider_id = _activate_exa_provider(admin_user)
|
||||
assert isinstance(provider_id, int)
|
||||
|
||||
search_request = {"queries": ["latest ai research news"], "max_results": 3}
|
||||
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
|
||||
|
||||
lite_response = requests.post(
|
||||
f"{API_SERVER_URL}/web-search/search-lite",
|
||||
|
||||
@@ -221,6 +221,13 @@ services:
|
||||
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
||||
- ONYX_BOT_MAX_QPM=${ONYX_BOT_MAX_QPM:-}
|
||||
- ONYX_BOT_MAX_WAIT_TIME=${ONYX_BOT_MAX_WAIT_TIME:-}
|
||||
# Discord Bot Configuration (runs via supervisord, requires DISCORD_BOT_TOKEN to be set)
|
||||
# IMPORTANT: Only one Discord bot instance can run per token - do not scale background workers
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# Logging
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
|
||||
|
||||
@@ -63,6 +63,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -82,6 +82,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -129,6 +129,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# PRODUCTION: Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
|
||||
@@ -77,6 +77,13 @@ MINIO_ROOT_PASSWORD=minioadmin
|
||||
## CORS origins for MCP clients (comma-separated list)
|
||||
# MCP_SERVER_CORS_ORIGINS=
|
||||
|
||||
## Discord Bot Configuration
|
||||
## The Discord bot allows users to interact with Onyx from Discord servers
|
||||
## Bot token from Discord Developer Portal (required to enable the bot)
|
||||
# DISCORD_BOT_TOKEN=
|
||||
## Command prefix for bot commands (default: "!")
|
||||
# DISCORD_BOT_INVOKE_CHAR=!
|
||||
|
||||
## Celery Configuration
|
||||
# CELERY_BROKER_POOL_LIMIT=
|
||||
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=
|
||||
|
||||
@@ -582,29 +582,33 @@ else
|
||||
fi
|
||||
|
||||
# Ask for authentication schema
|
||||
echo ""
|
||||
print_info "Which authentication schema would you like to set up?"
|
||||
echo ""
|
||||
echo "1) Basic - Username/password authentication"
|
||||
echo "2) No Auth - Open access (development/testing)"
|
||||
echo ""
|
||||
read -p "Choose an option (1-2) [default 1]: " -r AUTH_CHOICE
|
||||
echo ""
|
||||
# echo ""
|
||||
# print_info "Which authentication schema would you like to set up?"
|
||||
# echo ""
|
||||
# echo "1) Basic - Username/password authentication"
|
||||
# echo "2) No Auth - Open access (development/testing)"
|
||||
# echo ""
|
||||
# read -p "Choose an option (1) [default 1]: " -r AUTH_CHOICE
|
||||
# echo ""
|
||||
|
||||
case "${AUTH_CHOICE:-1}" in
|
||||
1)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Selected: Basic authentication"
|
||||
;;
|
||||
2)
|
||||
AUTH_SCHEMA="disabled"
|
||||
print_info "Selected: No authentication"
|
||||
;;
|
||||
*)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Invalid choice, using basic authentication"
|
||||
;;
|
||||
esac
|
||||
# case "${AUTH_CHOICE:-1}" in
|
||||
# 1)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Selected: Basic authentication"
|
||||
# ;;
|
||||
# # 2)
|
||||
# # AUTH_SCHEMA="disabled"
|
||||
# # print_info "Selected: No authentication"
|
||||
# # ;;
|
||||
# *)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Invalid choice, using basic authentication"
|
||||
# ;;
|
||||
# esac
|
||||
|
||||
# TODO (jessica): Uncomment this once no auth users still have an account
|
||||
# Use basic auth by default
|
||||
AUTH_SCHEMA="basic"
|
||||
|
||||
# Create .env file from template
|
||||
print_info "Creating .env file with your selections..."
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.19
|
||||
version: 0.4.20
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
@@ -0,0 +1,98 @@
|
||||
{{- if .Values.discordbot.enabled }}
|
||||
# Discord bot MUST run as a single replica - Discord only allows one client connection per bot token.
|
||||
# Do NOT enable HPA or increase replicas. Message processing is offloaded to scalable API pods via HTTP.
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-discordbot
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
# CRITICAL: Discord bots cannot be horizontally scaled - only one WebSocket connection per token is allowed
|
||||
replicas: 1
|
||||
strategy:
|
||||
type: Recreate # Ensure old pod is terminated before new one starts to avoid duplicate connections
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "onyx.selectorLabels" . | nindent 6 }}
|
||||
{{- if .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml .Values.discordbot.deploymentLabels | nindent 6 }}
|
||||
{{- end }}
|
||||
template:
|
||||
metadata:
|
||||
annotations:
|
||||
checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
|
||||
{{- with .Values.discordbot.podAnnotations }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 8 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.podLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "onyx.serviceAccountName" . }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.podSecurityContext | nindent 8 }}
|
||||
{{- with .Values.discordbot.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: discordbot
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.discordbot.image.repository }}:{{ .Values.discordbot.image.tag | default .Values.global.version }}"
|
||||
imagePullPolicy: {{ .Values.global.pullPolicy }}
|
||||
command: ["python", "onyx/onyxbot/discord/client.py"]
|
||||
resources:
|
||||
{{- toYaml .Values.discordbot.resources | nindent 12 }}
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
env:
|
||||
{{- include "onyx.envSecrets" . | nindent 12}}
|
||||
# Discord bot token - required for bot to connect
|
||||
{{- if .Values.discordbot.botToken }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
value: {{ .Values.discordbot.botToken | quote }}
|
||||
{{- end }}
|
||||
{{- if .Values.discordbot.botTokenSecretName }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ .Values.discordbot.botTokenSecretName }}
|
||||
key: {{ .Values.discordbot.botTokenSecretKey | default "token" }}
|
||||
{{- end }}
|
||||
# Command prefix for bot commands (default: "!")
|
||||
{{- if .Values.discordbot.invokeChar }}
|
||||
- name: DISCORD_BOT_INVOKE_CHAR
|
||||
value: {{ .Values.discordbot.invokeChar | quote }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumes }}
|
||||
volumes:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -701,6 +701,44 @@ celery_worker_user_file_processing:
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
# Discord bot for Onyx
|
||||
# The bot offloads message processing to scalable API pods via HTTP requests.
|
||||
discordbot:
|
||||
enabled: false # Disabled by default - requires bot token configuration
|
||||
# Bot token can be provided directly or via a Kubernetes secret
|
||||
# Option 1: Direct token (not recommended for production)
|
||||
botToken: ""
|
||||
# Option 2: Reference a Kubernetes secret (recommended)
|
||||
botTokenSecretName: "" # Name of the secret containing the bot token
|
||||
botTokenSecretKey: "token" # Key within the secret (default: "token")
|
||||
# Command prefix for bot commands (default: "!")
|
||||
invokeChar: "!"
|
||||
image:
|
||||
repository: onyxdotapp/onyx-backend
|
||||
tag: "" # Overrides the image tag whose default is the chart appVersion.
|
||||
podAnnotations: {}
|
||||
podLabels:
|
||||
scope: onyx-backend
|
||||
app: discord-bot
|
||||
deploymentLabels:
|
||||
app: discord-bot
|
||||
podSecurityContext:
|
||||
{}
|
||||
securityContext:
|
||||
{}
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "512Mi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2000Mi"
|
||||
volumes: []
|
||||
volumeMounts: []
|
||||
nodeSelector: {}
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
slackbot:
|
||||
enabled: true
|
||||
replicaCount: 1
|
||||
@@ -1159,6 +1197,8 @@ configMap:
|
||||
ONYX_BOT_DISPLAY_ERROR_MSGS: ""
|
||||
ONYX_BOT_RESPOND_EVERY_CHANNEL: ""
|
||||
NOTIFY_SLACKBOT_NO_ANSWER: ""
|
||||
DISCORD_BOT_TOKEN: ""
|
||||
DISCORD_BOT_INVOKE_CHAR: ""
|
||||
# Logging
|
||||
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
|
||||
DISABLE_TELEMETRY: ""
|
||||
|
||||
21
web/lib/opal/src/icons/branch.tsx
Normal file
21
web/lib/opal/src/icons/branch.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBranch = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M4.75001 5C5.71651 5 6.50001 4.2165 6.50001 3.25C6.50001 2.2835 5.7165 1.5 4.75 1.5C3.78351 1.5 3.00001 2.2835 3.00001 3.25C3.00001 4.2165 3.78351 5 4.75001 5ZM4.75001 5L4.75001 6.24999M4.75 11C3.7835 11 3 11.7835 3 12.75C3 13.7165 3.7835 14.5 4.75 14.5C5.7165 14.5 6.5 13.7165 6.5 12.75C6.5 11.7835 5.71649 11 4.75 11ZM4.75 11L4.75001 6.24999M10.5 8.74997C10.5 9.71646 11.2835 10.5 12.25 10.5C13.2165 10.5 14 9.71646 14 8.74997C14 7.78347 13.2165 7 12.25 7C11.2835 7 10.5 7.78347 10.5 8.74997ZM10.5 8.74997L7.25001 8.74999C5.8693 8.74999 4.75001 7.6307 4.75001 6.24999"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgBranch;
|
||||
16
web/lib/opal/src/icons/circle.tsx
Normal file
16
web/lib/opal/src/icons/circle.tsx
Normal file
@@ -0,0 +1,16 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgCircle = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<circle cx="8" cy="8" r="6" strokeWidth={1.5} />
|
||||
</svg>
|
||||
);
|
||||
export default SvgCircle;
|
||||
21
web/lib/opal/src/icons/download.tsx
Normal file
21
web/lib/opal/src/icons/download.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgDownload = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M14 10V12.6667C14 13.3929 13.3929 14 12.6667 14H3.33333C2.60711 14 2 13.3929 2 12.6667V10M4.66667 6.66667L8 10M8 10L11.3333 6.66667M8 10L8 2"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgDownload;
|
||||
@@ -24,6 +24,7 @@ export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBooksLineSmall } from "@opal/icons/books-line-small";
|
||||
export { default as SvgBooksStackSmall } from "@opal/icons/books-stack-small";
|
||||
export { default as SvgBracketCurly } from "@opal/icons/bracket-curly";
|
||||
export { default as SvgBranch } from "@opal/icons/branch";
|
||||
export { default as SvgBubbleText } from "@opal/icons/bubble-text";
|
||||
export { default as SvgCalendar } from "@opal/icons/calendar";
|
||||
export { default as SvgCheck } from "@opal/icons/check";
|
||||
@@ -36,6 +37,7 @@ export { default as SvgChevronLeft } from "@opal/icons/chevron-left";
|
||||
export { default as SvgChevronRight } from "@opal/icons/chevron-right";
|
||||
export { default as SvgChevronUp } from "@opal/icons/chevron-up";
|
||||
export { default as SvgChevronUpSmall } from "@opal/icons/chevron-up-small";
|
||||
export { default as SvgCircle } from "@opal/icons/circle";
|
||||
export { default as SvgClaude } from "@opal/icons/claude";
|
||||
export { default as SvgClipboard } from "@opal/icons/clipboard";
|
||||
export { default as SvgClock } from "@opal/icons/clock";
|
||||
@@ -46,6 +48,7 @@ export { default as SvgCopy } from "@opal/icons/copy";
|
||||
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
|
||||
export { default as SvgCpu } from "@opal/icons/cpu";
|
||||
export { default as SvgDevKit } from "@opal/icons/dev-kit";
|
||||
export { default as SvgDownload } from "@opal/icons/download";
|
||||
export { default as SvgDiscordMono } from "@opal/icons/DiscordMono";
|
||||
export { default as SvgDownloadCloud } from "@opal/icons/download-cloud";
|
||||
export { default as SvgEdit } from "@opal/icons/edit";
|
||||
@@ -135,6 +138,7 @@ export { default as SvgStep3End } from "@opal/icons/step3-end";
|
||||
export { default as SvgStop } from "@opal/icons/stop";
|
||||
export { default as SvgStopCircle } from "@opal/icons/stop-circle";
|
||||
export { default as SvgSun } from "@opal/icons/sun";
|
||||
export { default as SvgTerminal } from "@opal/icons/terminal";
|
||||
export { default as SvgTerminalSmall } from "@opal/icons/terminal-small";
|
||||
export { default as SvgTextLinesSmall } from "@opal/icons/text-lines-small";
|
||||
export { default as SvgThumbsDown } from "@opal/icons/thumbs-down";
|
||||
|
||||
22
web/lib/opal/src/icons/terminal.tsx
Normal file
22
web/lib/opal/src/icons/terminal.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgTerminal = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2.66667 11.3333L6.66667 7.33331L2.66667 3.33331M8.00001 12.6666H13.3333"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgTerminal;
|
||||
@@ -31,6 +31,7 @@ import { fetchBedrockModels } from "../utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
|
||||
@@ -135,7 +136,7 @@ function BedrockFormInternals({
|
||||
!formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
<Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
|
||||
<SelectorFormField
|
||||
@@ -176,7 +177,7 @@ function BedrockFormInternals({
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4 w-full">
|
||||
<TextFormField
|
||||
name={FIELD_AWS_ACCESS_KEY_ID}
|
||||
label="AWS Access Key ID"
|
||||
@@ -191,7 +192,7 @@ function BedrockFormInternals({
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4 w-full">
|
||||
<PasswordInputTypeInField
|
||||
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
|
||||
label="AWS Bedrock Long-term API Key"
|
||||
|
||||
@@ -643,6 +643,7 @@ export function useChatController({
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
let files = projectFilesToFileDescriptors(currentMessageFiles);
|
||||
let packets: Packet[] = [];
|
||||
let packetsVersion = 0;
|
||||
|
||||
let newUserMessageId: number | null = null;
|
||||
let newAssistantMessageId: number | null = null;
|
||||
@@ -729,7 +730,6 @@ export function useChatController({
|
||||
if (!packet) {
|
||||
continue;
|
||||
}
|
||||
console.debug("Packet:", JSON.stringify(packet));
|
||||
|
||||
// We've processed initial packets and are starting to stream content.
|
||||
// Transition from 'loading' to 'streaming'.
|
||||
@@ -800,8 +800,8 @@ export function useChatController({
|
||||
updateCanContinue(true, frozenSessionId);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "obj")) {
|
||||
console.debug("Object packet:", JSON.stringify(packet));
|
||||
packets.push(packet as Packet);
|
||||
packetsVersion++;
|
||||
|
||||
// Check if the packet contains document information
|
||||
const packetObj = (packet as Packet).obj;
|
||||
@@ -859,6 +859,8 @@ export function useChatController({
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetsVersion: packetsVersion,
|
||||
packetCount: packets.length,
|
||||
},
|
||||
],
|
||||
// Pass the latest map state
|
||||
@@ -885,6 +887,7 @@ export function useChatController({
|
||||
toolCall: null,
|
||||
parentNodeId: parentMessage?.nodeId || SYSTEM_NODE_ID,
|
||||
packets: [],
|
||||
packetCount: 0,
|
||||
},
|
||||
{
|
||||
nodeId: initialAssistantNode.nodeId,
|
||||
@@ -894,6 +897,7 @@ export function useChatController({
|
||||
toolCall: null,
|
||||
parentNodeId: initialUserNode.nodeId,
|
||||
packets: [],
|
||||
packetCount: 0,
|
||||
stackTrace: stackTrace,
|
||||
errorCode: errorCode,
|
||||
isRetryable: isRetryable,
|
||||
|
||||
@@ -139,6 +139,9 @@ export interface Message {
|
||||
|
||||
// new gen
|
||||
packets: Packet[];
|
||||
// Version counter for efficient memo comparison (increments with each packet)
|
||||
packetsVersion?: number;
|
||||
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
|
||||
|
||||
// cached values for easy access
|
||||
documents?: OnyxDocument[] | null;
|
||||
|
||||
@@ -74,6 +74,8 @@ export type RegenerationFactory = (regenerationRequest: {
|
||||
|
||||
export interface AIMessageProps {
|
||||
rawPackets: Packet[];
|
||||
// Version counter for efficient memo comparison (avoids array copying)
|
||||
packetsVersion?: number;
|
||||
chatState: FullChatState;
|
||||
nodeId: number;
|
||||
messageId?: number;
|
||||
@@ -88,8 +90,6 @@ export interface AIMessageProps {
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
// - `rawPackets.length` assumes packets are append-only. Could compare the last
|
||||
// packet or use a shallow comparison if packets can be modified in place.
|
||||
// - `chatState.docs`, `chatState.citations`, and `otherMessagesCanSwitchTo` use
|
||||
// reference equality. Shallow array/object comparison would be more robust if
|
||||
// these are recreated with the same values.
|
||||
@@ -98,7 +98,7 @@ function arePropsEqual(prev: AIMessageProps, next: AIMessageProps): boolean {
|
||||
prev.nodeId === next.nodeId &&
|
||||
prev.messageId === next.messageId &&
|
||||
prev.currentFeedback === next.currentFeedback &&
|
||||
prev.rawPackets.length === next.rawPackets.length &&
|
||||
prev.packetsVersion === next.packetsVersion &&
|
||||
prev.chatState.assistant?.id === next.chatState.assistant?.id &&
|
||||
prev.chatState.docs === next.chatState.docs &&
|
||||
prev.chatState.citations === next.chatState.citations &&
|
||||
|
||||
@@ -11,6 +11,7 @@ import { CitationMap } from "../../interfaces";
|
||||
export enum RenderType {
|
||||
HIGHLIGHT = "highlight",
|
||||
FULL = "full",
|
||||
COMPACT = "compact",
|
||||
}
|
||||
|
||||
export interface FullChatState {
|
||||
@@ -35,6 +36,9 @@ export interface RendererResult {
|
||||
// used for things that should just show text w/o an icon or header
|
||||
// e.g. ReasoningRenderer
|
||||
expandedText?: JSX.Element;
|
||||
|
||||
// Whether this renderer supports compact mode (collapse button shown only when true)
|
||||
supportsCompact?: boolean;
|
||||
}
|
||||
|
||||
export type MessageRenderer<
|
||||
@@ -48,5 +52,7 @@ export type MessageRenderer<
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
/** Whether this is the last step in the timeline (for connector line decisions) */
|
||||
isLastStep?: boolean;
|
||||
children: (result: RendererResult) => JSX.Element;
|
||||
}>;
|
||||
|
||||
@@ -68,10 +68,11 @@ export const CustomToolRenderer: MessageRenderer<CustomToolPacket, {}> = ({
|
||||
|
||||
const icon = FiTool;
|
||||
|
||||
if (renderType === RenderType.HIGHLIGHT) {
|
||||
if (renderType === RenderType.COMPACT) {
|
||||
return children({
|
||||
icon,
|
||||
status: status,
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{isRunning && `${toolName} running...`}
|
||||
@@ -84,6 +85,7 @@ export const CustomToolRenderer: MessageRenderer<CustomToolPacket, {}> = ({
|
||||
return children({
|
||||
icon,
|
||||
status,
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<div className="flex flex-col gap-3">
|
||||
{/* File responses */}
|
||||
|
||||
@@ -72,6 +72,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Generating images...",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="flex flex-col">
|
||||
<div>
|
||||
@@ -89,6 +90,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
status: `Generated ${images.length} image${
|
||||
images.length !== 1 ? "s" : ""
|
||||
}`,
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="flex flex-col my-1">
|
||||
{images.length > 0 ? (
|
||||
@@ -122,6 +124,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: status,
|
||||
supportsCompact: false,
|
||||
content: <div></div>,
|
||||
});
|
||||
}
|
||||
@@ -131,6 +134,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Generating image...",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<div className="flex gap-0.5">
|
||||
@@ -154,6 +158,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Image generation failed",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="text-sm text-red-600 dark:text-red-400">
|
||||
Image generation failed
|
||||
@@ -166,6 +171,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: `Generated ${images.length} image${images.length > 1 ? "s" : ""}`,
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="text-sm text-muted-foreground">
|
||||
Generated {images.length} image
|
||||
@@ -178,6 +184,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Image generation",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="text-sm text-muted-foreground">Image generation</div>
|
||||
),
|
||||
|
||||
@@ -0,0 +1,441 @@
|
||||
"use client";
|
||||
|
||||
import React, { FunctionComponent, useMemo, useCallback } from "react";
|
||||
import { StopReason } from "@/app/chat/services/streamingModels";
|
||||
import { FullChatState } from "../interfaces";
|
||||
import { TurnGroup, TransformedStep } from "./transformers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
|
||||
import { SvgCheckCircle, SvgStopCircle } from "@opal/icons";
|
||||
import { IconProps } from "@opal/types";
|
||||
import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererResult,
|
||||
} from "./TimelineRendererComponent";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { ParallelTimelineTabs } from "./ParallelTimelineTabs";
|
||||
import { StepContainer } from "./StepContainer";
|
||||
import {
|
||||
useTimelineExpansion,
|
||||
useTimelineMetrics,
|
||||
useTimelineHeader,
|
||||
} from "@/app/chat/message/messageComponents/timeline/hooks";
|
||||
import {
|
||||
isResearchAgentPackets,
|
||||
stepSupportsCompact,
|
||||
} from "@/app/chat/message/messageComponents/timeline/packetHelpers";
|
||||
import {
|
||||
StreamingHeader,
|
||||
CollapsedHeader,
|
||||
ExpandedHeader,
|
||||
StoppedHeader,
|
||||
ParallelStreamingHeader,
|
||||
} from "@/app/chat/message/messageComponents/timeline/headers";
|
||||
|
||||
// =============================================================================
|
||||
// TimelineStep Component - Memoized to prevent re-renders
|
||||
// =============================================================================
|
||||
|
||||
interface TimelineStepProps {
|
||||
step: TransformedStep;
|
||||
chatState: FullChatState;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
isLastStep: boolean;
|
||||
isFirstStep: boolean;
|
||||
isSingleStep: boolean;
|
||||
}
|
||||
|
||||
//will be removed on cleanup
|
||||
const noopCallback = () => {};
|
||||
|
||||
const TimelineStep = React.memo(function TimelineStep({
|
||||
step,
|
||||
chatState,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
isLastStep,
|
||||
isFirstStep,
|
||||
isSingleStep,
|
||||
}: TimelineStepProps) {
|
||||
// Stable render callback - doesn't need to change between renders
|
||||
const renderStep = useCallback(
|
||||
({
|
||||
icon,
|
||||
status,
|
||||
content,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
isLastStep: rendererIsLastStep,
|
||||
supportsCompact,
|
||||
}: TimelineRendererResult) =>
|
||||
isResearchAgentPackets(step.packets) ? (
|
||||
content
|
||||
) : (
|
||||
<StepContainer
|
||||
stepIcon={icon as FunctionComponent<IconProps> | undefined}
|
||||
header={status}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={onToggle}
|
||||
collapsible={true}
|
||||
supportsCompact={supportsCompact}
|
||||
isLastStep={rendererIsLastStep}
|
||||
isFirstStep={isFirstStep}
|
||||
hideHeader={isSingleStep}
|
||||
>
|
||||
{content}
|
||||
</StepContainer>
|
||||
),
|
||||
[step.packets, isFirstStep, isSingleStep]
|
||||
);
|
||||
|
||||
return (
|
||||
<TimelineRendererComponent
|
||||
packets={step.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopCallback}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={true}
|
||||
isLastStep={isLastStep}
|
||||
>
|
||||
{renderStep}
|
||||
</TimelineRendererComponent>
|
||||
);
|
||||
});
|
||||
|
||||
// =============================================================================
|
||||
// Main Component
|
||||
// =============================================================================
|
||||
|
||||
export interface AgentTimelineProps {
|
||||
/** Turn groups from usePacketProcessor */
|
||||
turnGroups: TurnGroup[];
|
||||
/** Chat state for rendering content */
|
||||
chatState: FullChatState;
|
||||
/** Whether the stop packet has been seen */
|
||||
stopPacketSeen?: boolean;
|
||||
/** Reason for stopping (if stopped) */
|
||||
stopReason?: StopReason;
|
||||
/** Whether final answer is coming (affects last connector) */
|
||||
finalAnswerComing?: boolean;
|
||||
/** Whether there is display content after timeline */
|
||||
hasDisplayContent?: boolean;
|
||||
/** Content to render after timeline (final message + toolbar) - slot pattern */
|
||||
children?: React.ReactNode;
|
||||
/** Whether the timeline is collapsible */
|
||||
collapsible?: boolean;
|
||||
/** Title of the button to toggle the timeline */
|
||||
buttonTitle?: string;
|
||||
/** Additional class names */
|
||||
className?: string;
|
||||
/** Test ID for e2e testing */
|
||||
"data-testid"?: string;
|
||||
/** Unique tool names (pre-computed for performance) */
|
||||
uniqueToolNames?: string[];
|
||||
}
|
||||
|
||||
export function AgentTimeline({
|
||||
turnGroups,
|
||||
chatState,
|
||||
stopPacketSeen = false,
|
||||
stopReason,
|
||||
finalAnswerComing = false,
|
||||
hasDisplayContent = false,
|
||||
collapsible = true,
|
||||
buttonTitle,
|
||||
className,
|
||||
"data-testid": testId,
|
||||
uniqueToolNames = [],
|
||||
}: AgentTimelineProps) {
|
||||
// Header text and state flags
|
||||
const { headerText, hasPackets, userStopped } = useTimelineHeader(
|
||||
turnGroups,
|
||||
stopReason
|
||||
);
|
||||
|
||||
// Memoized metrics derived from turn groups
|
||||
const {
|
||||
totalSteps,
|
||||
isSingleStep,
|
||||
uniqueTools,
|
||||
lastTurnGroup,
|
||||
lastStep,
|
||||
lastStepIsResearchAgent,
|
||||
lastStepSupportsCompact,
|
||||
} = useTimelineMetrics(turnGroups, uniqueToolNames, userStopped);
|
||||
|
||||
// Expansion state management
|
||||
const { isExpanded, handleToggle, parallelActiveTab, setParallelActiveTab } =
|
||||
useTimelineExpansion(stopPacketSeen, lastTurnGroup);
|
||||
|
||||
// Stable callbacks to avoid creating new functions on every render
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
const renderContentOnly = useCallback(
|
||||
({ content }: TimelineRendererResult) => content,
|
||||
[]
|
||||
);
|
||||
|
||||
// Parallel step analysis for collapsed streaming view
|
||||
const parallelActiveStep = useMemo(() => {
|
||||
if (!lastTurnGroup?.isParallel) return null;
|
||||
return (
|
||||
lastTurnGroup.steps.find((s) => s.key === parallelActiveTab) ??
|
||||
lastTurnGroup.steps[0]
|
||||
);
|
||||
}, [lastTurnGroup, parallelActiveTab]);
|
||||
|
||||
const parallelActiveStepSupportsCompact = useMemo(() => {
|
||||
if (!parallelActiveStep) return false;
|
||||
return (
|
||||
stepSupportsCompact(parallelActiveStep.packets) &&
|
||||
!isResearchAgentPackets(parallelActiveStep.packets)
|
||||
);
|
||||
}, [parallelActiveStep]);
|
||||
|
||||
// Collapsed streaming: show compact content below header
|
||||
const showCollapsedCompact =
|
||||
!stopPacketSeen &&
|
||||
!isExpanded &&
|
||||
lastStep &&
|
||||
!lastTurnGroup?.isParallel &&
|
||||
!lastStepIsResearchAgent &&
|
||||
lastStepSupportsCompact;
|
||||
|
||||
// Parallel tabs in header only when collapsed (expanded view has tabs in content)
|
||||
const showParallelTabs =
|
||||
!stopPacketSeen &&
|
||||
!isExpanded &&
|
||||
lastTurnGroup?.isParallel &&
|
||||
lastTurnGroup.steps.length > 0;
|
||||
|
||||
// Collapsed parallel compact content
|
||||
const showCollapsedParallel =
|
||||
showParallelTabs && !isExpanded && parallelActiveStepSupportsCompact;
|
||||
|
||||
// Done indicator conditions
|
||||
const showDoneIndicator =
|
||||
stopPacketSeen && isExpanded && !userStopped && !lastStepIsResearchAgent;
|
||||
|
||||
// Header selection based on state
|
||||
const renderHeader = () => {
|
||||
if (!stopPacketSeen) {
|
||||
if (showParallelTabs && lastTurnGroup) {
|
||||
return (
|
||||
<ParallelStreamingHeader
|
||||
steps={lastTurnGroup.steps}
|
||||
activeTab={parallelActiveTab}
|
||||
onTabChange={setParallelActiveTab}
|
||||
collapsible={collapsible}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<StreamingHeader
|
||||
headerText={headerText}
|
||||
collapsible={collapsible}
|
||||
buttonTitle={buttonTitle}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (userStopped) {
|
||||
return (
|
||||
<StoppedHeader
|
||||
totalSteps={totalSteps}
|
||||
collapsible={collapsible}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isExpanded) {
|
||||
return (
|
||||
<CollapsedHeader
|
||||
uniqueTools={uniqueTools}
|
||||
totalSteps={totalSteps}
|
||||
collapsible={collapsible}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <ExpandedHeader collapsible={collapsible} onToggle={handleToggle} />;
|
||||
};
|
||||
|
||||
// Empty state: no packets, still streaming
|
||||
if (!hasPackets && !hasDisplayContent) {
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
<div className="flex w-full h-9">
|
||||
<div className="flex justify-center items-center size-9">
|
||||
<AgentAvatar agent={chatState.assistant} size={24} />
|
||||
</div>
|
||||
<div className="flex w-full h-full items-center px-2">
|
||||
<Text
|
||||
as="p"
|
||||
mainUiAction
|
||||
text03
|
||||
className="animate-shimmer bg-[length:200%_100%] bg-[linear-gradient(90deg,var(--shimmer-base)_10%,var(--shimmer-highlight)_40%,var(--shimmer-base)_70%)] bg-clip-text text-transparent"
|
||||
>
|
||||
{headerText}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Display content only (no timeline steps)
|
||||
if (hasDisplayContent && !hasPackets) {
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
<div className="flex w-full h-9">
|
||||
<div className="flex justify-center items-center size-9">
|
||||
<AgentAvatar agent={chatState.assistant} size={24} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
{/* Header row */}
|
||||
<div className="flex w-full h-9">
|
||||
<div className="flex justify-center items-center size-9">
|
||||
<AgentAvatar agent={chatState.assistant} size={24} />
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-full h-full items-center justify-between px-2",
|
||||
(!stopPacketSeen || userStopped || isExpanded) &&
|
||||
"bg-background-tint-00 rounded-t-12",
|
||||
!isExpanded &&
|
||||
!showCollapsedCompact &&
|
||||
!showCollapsedParallel &&
|
||||
"rounded-b-12"
|
||||
)}
|
||||
>
|
||||
{renderHeader()}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Collapsed streaming view - single step compact mode */}
|
||||
{showCollapsedCompact && lastStep && (
|
||||
<div className="flex w-full">
|
||||
<div className="w-9" />
|
||||
<div className="w-full bg-background-tint-00 rounded-b-12 px-2 pb-2">
|
||||
<TimelineRendererComponent
|
||||
key={`${lastStep.key}-compact`}
|
||||
packets={lastStep.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={true}
|
||||
stopPacketSeen={false}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={false}
|
||||
isLastStep={true}
|
||||
>
|
||||
{renderContentOnly}
|
||||
</TimelineRendererComponent>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Collapsed streaming view - parallel tools compact mode */}
|
||||
{showCollapsedParallel && parallelActiveStep && (
|
||||
<div className="flex w-full">
|
||||
<div className="w-9" />
|
||||
<div className="w-full bg-background-tint-00 rounded-b-12 px-2 pb-2">
|
||||
<TimelineRendererComponent
|
||||
key={`${parallelActiveStep.key}-compact`}
|
||||
packets={parallelActiveStep.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={true}
|
||||
stopPacketSeen={false}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={false}
|
||||
isLastStep={true}
|
||||
>
|
||||
{renderContentOnly}
|
||||
</TimelineRendererComponent>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Expanded timeline view */}
|
||||
{isExpanded && (
|
||||
<div className="w-full">
|
||||
{turnGroups.map((turnGroup, turnIdx) =>
|
||||
turnGroup.isParallel ? (
|
||||
<ParallelTimelineTabs
|
||||
key={turnGroup.turnIndex}
|
||||
turnGroup={turnGroup}
|
||||
chatState={chatState}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
isLastTurnGroup={turnIdx === turnGroups.length - 1}
|
||||
/>
|
||||
) : (
|
||||
turnGroup.steps.map((step, stepIdx) => {
|
||||
const stepIsLast =
|
||||
turnIdx === turnGroups.length - 1 &&
|
||||
stepIdx === turnGroup.steps.length - 1 &&
|
||||
!showDoneIndicator &&
|
||||
!userStopped;
|
||||
const stepIsFirst = turnIdx === 0 && stepIdx === 0;
|
||||
|
||||
return (
|
||||
<TimelineStep
|
||||
key={step.key}
|
||||
step={step}
|
||||
chatState={chatState}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
isLastStep={stepIsLast}
|
||||
isFirstStep={stepIsFirst}
|
||||
isSingleStep={isSingleStep}
|
||||
/>
|
||||
);
|
||||
})
|
||||
)
|
||||
)}
|
||||
|
||||
{/* Done indicator */}
|
||||
{stopPacketSeen && isExpanded && !userStopped && (
|
||||
<StepContainer
|
||||
stepIcon={SvgCheckCircle}
|
||||
header="Done"
|
||||
isLastStep={true}
|
||||
isFirstStep={false}
|
||||
>
|
||||
{null}
|
||||
</StepContainer>
|
||||
)}
|
||||
|
||||
{/* Stopped indicator */}
|
||||
{stopPacketSeen && isExpanded && userStopped && (
|
||||
<StepContainer
|
||||
stepIcon={SvgStopCircle}
|
||||
header="Stopped"
|
||||
isLastStep={true}
|
||||
isFirstStep={false}
|
||||
>
|
||||
{null}
|
||||
</StepContainer>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default AgentTimeline;
|
||||
@@ -0,0 +1,130 @@
|
||||
"use client";
|
||||
|
||||
import React, {
|
||||
useState,
|
||||
useMemo,
|
||||
useCallback,
|
||||
FunctionComponent,
|
||||
} from "react";
|
||||
import { StopReason } from "@/app/chat/services/streamingModels";
|
||||
import { FullChatState } from "../interfaces";
|
||||
import { TurnGroup } from "./transformers";
|
||||
import { getToolName, getToolIcon } from "../toolDisplayHelpers";
|
||||
import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererResult,
|
||||
} from "./TimelineRendererComponent";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { SvgBranch } from "@opal/icons";
|
||||
import { StepContainer } from "./StepContainer";
|
||||
import { isResearchAgentPackets } from "@/app/chat/message/messageComponents/timeline/packetHelpers";
|
||||
import { IconProps } from "@/components/icons/icons";
|
||||
|
||||
export interface ParallelTimelineTabsProps {
|
||||
/** Turn group containing parallel steps */
|
||||
turnGroup: TurnGroup;
|
||||
/** Chat state for rendering content */
|
||||
chatState: FullChatState;
|
||||
/** Whether the stop packet has been seen */
|
||||
stopPacketSeen: boolean;
|
||||
/** Reason for stopping (if stopped) */
|
||||
stopReason?: StopReason;
|
||||
/** Whether this is the last turn group (affects connector line) */
|
||||
isLastTurnGroup: boolean;
|
||||
/** Additional class names */
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ParallelTimelineTabs({
|
||||
turnGroup,
|
||||
chatState,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
isLastTurnGroup,
|
||||
className,
|
||||
}: ParallelTimelineTabsProps) {
|
||||
const [activeTab, setActiveTab] = useState(turnGroup.steps[0]?.key ?? "");
|
||||
|
||||
// Find the active step based on selected tab
|
||||
const activeStep = useMemo(
|
||||
() => turnGroup.steps.find((step) => step.key === activeTab),
|
||||
[turnGroup.steps, activeTab]
|
||||
);
|
||||
//will be removed on cleanup
|
||||
// Stable callbacks to avoid creating new functions on every render
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
const renderTabContent = useCallback(
|
||||
({
|
||||
icon,
|
||||
status,
|
||||
content,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
isLastStep,
|
||||
}: TimelineRendererResult) =>
|
||||
isResearchAgentPackets(activeStep?.packets ?? []) ? (
|
||||
content
|
||||
) : (
|
||||
<StepContainer
|
||||
stepIcon={icon as FunctionComponent<IconProps> | undefined}
|
||||
header={status}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={onToggle}
|
||||
collapsible={true}
|
||||
isLastStep={isLastStep}
|
||||
isFirstStep={false}
|
||||
>
|
||||
{content}
|
||||
</StepContainer>
|
||||
),
|
||||
[activeStep?.packets]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tabs value={activeTab} onValueChange={setActiveTab}>
|
||||
<div className="flex flex-col w-full gap-1">
|
||||
<div className="flex w-full">
|
||||
{/* Left column: Icon + connector line */}
|
||||
<div className="flex flex-col items-center w-9 pt-2">
|
||||
<div className="size-4 flex items-center justify-center stroke-text-02">
|
||||
<SvgBranch className="w-4 h-4" />
|
||||
</div>
|
||||
{/* Connector line */}
|
||||
<div className="w-px flex-1 bg-border-01" />
|
||||
</div>
|
||||
|
||||
{/* Right column: Tabs */}
|
||||
<div className="flex-1">
|
||||
<Tabs.List variant="pill">
|
||||
{turnGroup.steps.map((step) => (
|
||||
<Tabs.Trigger key={step.key} value={step.key} variant="pill">
|
||||
<span className="flex items-center gap-1.5">
|
||||
{getToolIcon(step.packets)}
|
||||
{getToolName(step.packets)}
|
||||
</span>
|
||||
</Tabs.Trigger>
|
||||
))}
|
||||
</Tabs.List>
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<TimelineRendererComponent
|
||||
key={activeTab}
|
||||
packets={activeStep?.packets ?? []}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={true}
|
||||
isLastStep={isLastTurnGroup}
|
||||
>
|
||||
{renderTabContent}
|
||||
</TimelineRendererComponent>
|
||||
</div>
|
||||
</div>
|
||||
</Tabs>
|
||||
);
|
||||
}
|
||||
|
||||
export default ParallelTimelineTabs;
|
||||
@@ -0,0 +1,108 @@
|
||||
import React, { FunctionComponent } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { IconProps } from "@opal/types";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface StepContainerProps {
|
||||
/** Main content */
|
||||
children?: React.ReactNode;
|
||||
/** Step icon component */
|
||||
stepIcon?: FunctionComponent<IconProps>;
|
||||
/** Header left slot */
|
||||
header?: React.ReactNode;
|
||||
/** Button title for toggle */
|
||||
buttonTitle?: string;
|
||||
/** Controlled expanded state */
|
||||
isExpanded?: boolean;
|
||||
/** Toggle callback */
|
||||
onToggle?: () => void;
|
||||
/** Whether collapse control is shown */
|
||||
collapsible?: boolean;
|
||||
/** Collapse button shown only when renderer supports compact mode */
|
||||
supportsCompact?: boolean;
|
||||
/** Additional class names */
|
||||
className?: string;
|
||||
/** Last step (no bottom connector) */
|
||||
isLastStep?: boolean;
|
||||
/** First step (top padding instead of connector) */
|
||||
isFirstStep?: boolean;
|
||||
/** Hide header (single-step timelines) */
|
||||
hideHeader?: boolean;
|
||||
}
|
||||
|
||||
/** Visual wrapper for timeline steps - icon, connector line, header, and content */
|
||||
export function StepContainer({
|
||||
children,
|
||||
stepIcon: StepIconComponent,
|
||||
header,
|
||||
buttonTitle,
|
||||
isExpanded = true,
|
||||
onToggle,
|
||||
collapsible = true,
|
||||
supportsCompact = false,
|
||||
isLastStep = false,
|
||||
isFirstStep = false,
|
||||
className,
|
||||
hideHeader = false,
|
||||
}: StepContainerProps) {
|
||||
const showCollapseControls = collapsible && supportsCompact && onToggle;
|
||||
|
||||
return (
|
||||
<div className={cn("flex w-full", className)}>
|
||||
<div
|
||||
className={cn("flex flex-col items-center w-9", isFirstStep && "pt-2")}
|
||||
>
|
||||
{/* Icon */}
|
||||
{!hideHeader && StepIconComponent && (
|
||||
<div className="py-1">
|
||||
<StepIconComponent className="size-4 stroke-text-02" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Connector line */}
|
||||
{!isLastStep && <div className="w-px flex-1 bg-border-01" />}
|
||||
</div>
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
"w-full bg-background-tint-00",
|
||||
isLastStep && "rounded-b-12"
|
||||
)}
|
||||
>
|
||||
{!hideHeader && (
|
||||
<div className="flex items-center justify-between px-2">
|
||||
{header && (
|
||||
<Text as="p" mainUiMuted text03>
|
||||
{header}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{showCollapseControls &&
|
||||
(buttonTitle ? (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={isExpanded ? SvgFold : SvgExpand}
|
||||
>
|
||||
{buttonTitle}
|
||||
</Button>
|
||||
) : (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={isExpanded ? SvgFold : SvgExpand}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="px-2 pb-2">{children}</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default StepContainer;
|
||||
@@ -0,0 +1,116 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, JSX } from "react";
|
||||
import { Packet, StopReason } from "@/app/chat/services/streamingModels";
|
||||
import { FullChatState, RenderType, RendererResult } from "../interfaces";
|
||||
import { findRenderer } from "../renderMessageComponent";
|
||||
|
||||
/** Extended result that includes collapse state */
|
||||
export interface TimelineRendererResult extends RendererResult {
|
||||
/** Current expanded state */
|
||||
isExpanded: boolean;
|
||||
/** Toggle callback */
|
||||
onToggle: () => void;
|
||||
/** Current render type */
|
||||
renderType: RenderType;
|
||||
/** Whether this is the last step (passed through from props) */
|
||||
isLastStep: boolean;
|
||||
}
|
||||
|
||||
export interface TimelineRendererComponentProps {
|
||||
/** Packets to render */
|
||||
packets: Packet[];
|
||||
/** Chat state for rendering */
|
||||
chatState: FullChatState;
|
||||
/** Completion callback */
|
||||
onComplete: () => void;
|
||||
/** Whether to animate streaming */
|
||||
animate: boolean;
|
||||
/** Whether stop packet has been seen */
|
||||
stopPacketSeen: boolean;
|
||||
/** Reason for stopping */
|
||||
stopReason?: StopReason;
|
||||
/** Initial expanded state */
|
||||
defaultExpanded?: boolean;
|
||||
/** Whether this is the last step in the timeline (for connector line decisions) */
|
||||
isLastStep?: boolean;
|
||||
/** Children render function - receives extended result with collapse state */
|
||||
children: (result: TimelineRendererResult) => JSX.Element;
|
||||
}
|
||||
|
||||
// Custom comparison function to prevent unnecessary re-renders
|
||||
// Only re-render if meaningful changes occur
|
||||
function arePropsEqual(
|
||||
prev: TimelineRendererComponentProps,
|
||||
next: TimelineRendererComponentProps
|
||||
): boolean {
|
||||
return (
|
||||
prev.packets.length === next.packets.length &&
|
||||
prev.stopPacketSeen === next.stopPacketSeen &&
|
||||
prev.stopReason === next.stopReason &&
|
||||
prev.animate === next.animate &&
|
||||
prev.isLastStep === next.isLastStep &&
|
||||
prev.defaultExpanded === next.defaultExpanded
|
||||
// Skipping chatState (memoized upstream)
|
||||
);
|
||||
}
|
||||
|
||||
export const TimelineRendererComponent = React.memo(
|
||||
function TimelineRendererComponent({
|
||||
packets,
|
||||
chatState,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
defaultExpanded = true,
|
||||
isLastStep,
|
||||
children,
|
||||
}: TimelineRendererComponentProps) {
|
||||
const [isExpanded, setIsExpanded] = useState(defaultExpanded);
|
||||
const handleToggle = () => setIsExpanded((prev) => !prev);
|
||||
const RendererFn = findRenderer({ packets });
|
||||
const renderType = isExpanded ? RenderType.FULL : RenderType.COMPACT;
|
||||
|
||||
if (!RendererFn) {
|
||||
return children({
|
||||
icon: null,
|
||||
status: null,
|
||||
content: <></>,
|
||||
supportsCompact: false,
|
||||
isExpanded,
|
||||
onToggle: handleToggle,
|
||||
renderType,
|
||||
isLastStep: isLastStep ?? true,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<RendererFn
|
||||
packets={packets as any}
|
||||
state={chatState}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
renderType={renderType}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
isLastStep={isLastStep}
|
||||
>
|
||||
{({ icon, status, content, expandedText, supportsCompact }) =>
|
||||
children({
|
||||
icon,
|
||||
status,
|
||||
content,
|
||||
expandedText,
|
||||
supportsCompact,
|
||||
isExpanded,
|
||||
onToggle: handleToggle,
|
||||
renderType,
|
||||
isLastStep: isLastStep ?? true,
|
||||
})
|
||||
}
|
||||
</RendererFn>
|
||||
);
|
||||
},
|
||||
arePropsEqual
|
||||
);
|
||||
@@ -0,0 +1,49 @@
|
||||
import React from "react";
|
||||
import { SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import type { UniqueTool } from "@/app/chat/message/messageComponents/timeline/hooks";
|
||||
|
||||
export interface CollapsedHeaderProps {
|
||||
uniqueTools: UniqueTool[];
|
||||
totalSteps: number;
|
||||
collapsible: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header when completed + collapsed - tools summary + step count */
|
||||
export const CollapsedHeader = React.memo(function CollapsedHeader({
|
||||
uniqueTools,
|
||||
totalSteps,
|
||||
collapsible,
|
||||
onToggle,
|
||||
}: CollapsedHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<div className="flex items-center gap-2">
|
||||
{uniqueTools.map((tool) => (
|
||||
<div
|
||||
key={tool.key}
|
||||
className="inline-flex items-center gap-1 rounded-08 p-1 bg-background-tint-02"
|
||||
>
|
||||
{tool.icon}
|
||||
<Text as="span" secondaryBody text04>
|
||||
{tool.name}
|
||||
</Text>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
{collapsible && (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={SvgExpand}
|
||||
aria-label="Expand timeline"
|
||||
aria-expanded={false}
|
||||
>
|
||||
{totalSteps} {totalSteps === 1 ? "step" : "steps"}
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,32 @@
|
||||
import React from "react";
|
||||
import { SvgFold } from "@opal/icons";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface ExpandedHeaderProps {
|
||||
collapsible: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header when completed + expanded */
|
||||
export const ExpandedHeader = React.memo(function ExpandedHeader({
|
||||
collapsible,
|
||||
onToggle,
|
||||
}: ExpandedHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<Text as="p" mainUiAction text03>
|
||||
Thought for some time
|
||||
</Text>
|
||||
{collapsible && (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={SvgFold}
|
||||
aria-label="Collapse timeline"
|
||||
aria-expanded={true}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,53 @@
|
||||
import React from "react";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { TurnGroup } from "../transformers";
|
||||
import { getToolIcon, getToolName } from "../../toolDisplayHelpers";
|
||||
|
||||
export interface ParallelStreamingHeaderProps {
|
||||
steps: TurnGroup["steps"];
|
||||
activeTab: string;
|
||||
onTabChange: (tab: string) => void;
|
||||
collapsible: boolean;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header during streaming with parallel tools - tabs only */
|
||||
export const ParallelStreamingHeader = React.memo(
|
||||
function ParallelStreamingHeader({
|
||||
steps,
|
||||
activeTab,
|
||||
onTabChange,
|
||||
collapsible,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: ParallelStreamingHeaderProps) {
|
||||
return (
|
||||
<Tabs value={activeTab} onValueChange={onTabChange}>
|
||||
<div className="flex items-center justify-between w-full gap-2">
|
||||
<Tabs.List variant="pill">
|
||||
{steps.map((step) => (
|
||||
<Tabs.Trigger key={step.key} value={step.key} variant="pill">
|
||||
<span className="flex items-center gap-1.5">
|
||||
{getToolIcon(step.packets)}
|
||||
{getToolName(step.packets)}
|
||||
</span>
|
||||
</Tabs.Trigger>
|
||||
))}
|
||||
</Tabs.List>
|
||||
{collapsible && (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
|
||||
aria-expanded={isExpanded}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Tabs>
|
||||
);
|
||||
}
|
||||
);
|
||||
@@ -0,0 +1,38 @@
|
||||
import React from "react";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface StoppedHeaderProps {
|
||||
totalSteps: number;
|
||||
collapsible: boolean;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header when user stopped/cancelled */
|
||||
export const StoppedHeader = React.memo(function StoppedHeader({
|
||||
totalSteps,
|
||||
collapsible,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: StoppedHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<Text as="p" mainUiAction text03>
|
||||
Stopped Thinking
|
||||
</Text>
|
||||
{collapsible && (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
|
||||
aria-expanded={isExpanded}
|
||||
>
|
||||
{totalSteps} {totalSteps === 1 ? "step" : "steps"}
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,54 @@
|
||||
import React from "react";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface StreamingHeaderProps {
|
||||
headerText: string;
|
||||
collapsible: boolean;
|
||||
buttonTitle?: string;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header during streaming - shimmer text with current activity */
|
||||
export const StreamingHeader = React.memo(function StreamingHeader({
|
||||
headerText,
|
||||
collapsible,
|
||||
buttonTitle,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: StreamingHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<Text
|
||||
as="p"
|
||||
mainUiAction
|
||||
text03
|
||||
className="animate-shimmer bg-[length:200%_100%] bg-[linear-gradient(90deg,var(--shimmer-base)_10%,var(--shimmer-highlight)_40%,var(--shimmer-base)_70%)] bg-clip-text text-transparent"
|
||||
>
|
||||
{headerText}
|
||||
</Text>
|
||||
{collapsible &&
|
||||
(buttonTitle ? (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-expanded={isExpanded}
|
||||
>
|
||||
{buttonTitle}
|
||||
</Button>
|
||||
) : (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
|
||||
aria-expanded={isExpanded}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,14 @@
|
||||
export { StreamingHeader } from "./StreamingHeader";
|
||||
export type { StreamingHeaderProps } from "./StreamingHeader";
|
||||
|
||||
export { CollapsedHeader } from "./CollapsedHeader";
|
||||
export type { CollapsedHeaderProps } from "./CollapsedHeader";
|
||||
|
||||
export { ExpandedHeader } from "./ExpandedHeader";
|
||||
export type { ExpandedHeaderProps } from "./ExpandedHeader";
|
||||
|
||||
export { StoppedHeader } from "./StoppedHeader";
|
||||
export type { StoppedHeaderProps } from "./StoppedHeader";
|
||||
|
||||
export { ParallelStreamingHeader } from "./ParallelStreamingHeader";
|
||||
export type { ParallelStreamingHeaderProps } from "./ParallelStreamingHeader";
|
||||
@@ -0,0 +1,11 @@
|
||||
export { useTimelineExpansion } from "./useTimelineExpansion";
|
||||
export type { TimelineExpansionState } from "./useTimelineExpansion";
|
||||
|
||||
export { useTimelineMetrics } from "./useTimelineMetrics";
|
||||
export type { TimelineMetrics, UniqueTool } from "./useTimelineMetrics";
|
||||
|
||||
export { usePacketProcessor } from "./usePacketProcessor";
|
||||
export type { UsePacketProcessorResult } from "./usePacketProcessor";
|
||||
|
||||
export { useTimelineHeader } from "./useTimelineHeader";
|
||||
export type { TimelineHeaderResult } from "./useTimelineHeader";
|
||||
@@ -0,0 +1,439 @@
|
||||
import {
|
||||
Packet,
|
||||
PacketType,
|
||||
StreamingCitation,
|
||||
StopReason,
|
||||
CitationInfo,
|
||||
SearchToolDocumentsDelta,
|
||||
FetchToolDocuments,
|
||||
TopLevelBranching,
|
||||
Stop,
|
||||
SearchToolStart,
|
||||
CustomToolStart,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { CitationMap } from "@/app/chat/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import {
|
||||
isActualToolCallPacket,
|
||||
isToolPacket,
|
||||
isDisplayPacket,
|
||||
} from "@/app/chat/services/packetUtils";
|
||||
import { parseToolKey } from "@/app/chat/message/messageComponents/toolDisplayHelpers";
|
||||
|
||||
// Re-export parseToolKey for consumers that import from this module
|
||||
export { parseToolKey };
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface ProcessorState {
|
||||
nodeId: number;
|
||||
lastProcessedIndex: number;
|
||||
|
||||
// Citations
|
||||
citations: StreamingCitation[];
|
||||
seenCitationDocIds: Set<string>;
|
||||
citationMap: CitationMap;
|
||||
|
||||
// Documents
|
||||
documentMap: Map<string, OnyxDocument>;
|
||||
|
||||
// Packet grouping
|
||||
groupedPacketsMap: Map<string, Packet[]>;
|
||||
seenGroupKeys: Set<string>;
|
||||
groupKeysWithSectionEnd: Set<string>;
|
||||
expectedBranches: Map<number, number>;
|
||||
|
||||
// Pre-categorized groups (populated during packet processing)
|
||||
toolGroupKeys: Set<string>;
|
||||
displayGroupKeys: Set<string>;
|
||||
|
||||
// Unique tool names tracking (populated during packet processing)
|
||||
uniqueToolNames: Set<string>;
|
||||
|
||||
// Streaming status
|
||||
finalAnswerComing: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason: StopReason | undefined;
|
||||
|
||||
// Result arrays (built at end of processPackets)
|
||||
toolGroups: GroupedPacket[];
|
||||
potentialDisplayGroups: GroupedPacket[];
|
||||
uniqueToolNamesArray: string[];
|
||||
}
|
||||
|
||||
export interface GroupedPacket {
|
||||
turn_index: number;
|
||||
tab_index: number;
|
||||
packets: Packet[];
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// State Creation
|
||||
// ============================================================================
|
||||
|
||||
export function createInitialState(nodeId: number): ProcessorState {
|
||||
return {
|
||||
nodeId,
|
||||
lastProcessedIndex: 0,
|
||||
citations: [],
|
||||
seenCitationDocIds: new Set(),
|
||||
citationMap: {},
|
||||
documentMap: new Map(),
|
||||
groupedPacketsMap: new Map(),
|
||||
seenGroupKeys: new Set(),
|
||||
groupKeysWithSectionEnd: new Set(),
|
||||
expectedBranches: new Map(),
|
||||
toolGroupKeys: new Set(),
|
||||
displayGroupKeys: new Set(),
|
||||
uniqueToolNames: new Set(),
|
||||
finalAnswerComing: false,
|
||||
stopPacketSeen: false,
|
||||
stopReason: undefined,
|
||||
toolGroups: [],
|
||||
potentialDisplayGroups: [],
|
||||
uniqueToolNamesArray: [],
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
function getGroupKey(packet: Packet): string {
|
||||
const turnIndex = packet.placement.turn_index;
|
||||
const tabIndex = packet.placement.tab_index ?? 0;
|
||||
return `${turnIndex}-${tabIndex}`;
|
||||
}
|
||||
|
||||
function injectSectionEnd(state: ProcessorState, groupKey: string): void {
|
||||
if (state.groupKeysWithSectionEnd.has(groupKey)) {
|
||||
return; // Already has SECTION_END
|
||||
}
|
||||
|
||||
const { turn_index, tab_index } = parseToolKey(groupKey);
|
||||
|
||||
const syntheticPacket: Packet = {
|
||||
placement: { turn_index, tab_index },
|
||||
obj: { type: PacketType.SECTION_END },
|
||||
};
|
||||
|
||||
const existingGroup = state.groupedPacketsMap.get(groupKey);
|
||||
if (existingGroup) {
|
||||
existingGroup.push(syntheticPacket);
|
||||
}
|
||||
state.groupKeysWithSectionEnd.add(groupKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Content packet types that indicate a group has meaningful content to display
|
||||
*/
|
||||
const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
PacketType.MESSAGE_START,
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.IMAGE_GENERATION_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
PacketType.REASONING_START,
|
||||
PacketType.DEEP_RESEARCH_PLAN_START,
|
||||
PacketType.RESEARCH_AGENT_START,
|
||||
]);
|
||||
|
||||
function hasContentPackets(packets: Packet[]): boolean {
|
||||
return packets.some((packet) =>
|
||||
CONTENT_PACKET_TYPES_SET.has(packet.obj.type as PacketType)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract tool name from a packet for unique tool tracking.
|
||||
* Returns null for non-tool packets.
|
||||
*/
|
||||
function getToolNameFromPacket(packet: Packet): string | null {
|
||||
switch (packet.obj.type) {
|
||||
case PacketType.SEARCH_TOOL_START: {
|
||||
const searchPacket = packet.obj as SearchToolStart;
|
||||
return searchPacket.is_internet_search ? "Web Search" : "Internal Search";
|
||||
}
|
||||
case PacketType.PYTHON_TOOL_START:
|
||||
return "Code Interpreter";
|
||||
case PacketType.FETCH_TOOL_START:
|
||||
return "Open URLs";
|
||||
case PacketType.CUSTOM_TOOL_START: {
|
||||
const customPacket = packet.obj as CustomToolStart;
|
||||
return customPacket.tool_name || "Custom Tool";
|
||||
}
|
||||
case PacketType.IMAGE_GENERATION_TOOL_START:
|
||||
return "Generate Image";
|
||||
case PacketType.DEEP_RESEARCH_PLAN_START:
|
||||
return "Generate plan";
|
||||
case PacketType.RESEARCH_AGENT_START:
|
||||
return "Research agent";
|
||||
case PacketType.REASONING_START:
|
||||
return "Thinking";
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Packet types that indicate final answer content is coming
|
||||
*/
|
||||
const FINAL_ANSWER_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
PacketType.MESSAGE_START,
|
||||
PacketType.MESSAGE_DELTA,
|
||||
PacketType.IMAGE_GENERATION_TOOL_START,
|
||||
PacketType.IMAGE_GENERATION_TOOL_DELTA,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_DELTA,
|
||||
]);
|
||||
|
||||
// ============================================================================
|
||||
// Packet Handlers
|
||||
// ============================================================================
|
||||
|
||||
function handleTopLevelBranching(state: ProcessorState, packet: Packet): void {
|
||||
const branchingPacket = packet.obj as TopLevelBranching;
|
||||
state.expectedBranches.set(
|
||||
packet.placement.turn_index,
|
||||
branchingPacket.num_parallel_branches
|
||||
);
|
||||
}
|
||||
|
||||
function handleTurnTransition(state: ProcessorState, packet: Packet): void {
|
||||
const currentTurnIndex = packet.placement.turn_index;
|
||||
|
||||
// Get all previous turn indices from seen group keys
|
||||
const previousTurnIndices = new Set(
|
||||
Array.from(state.seenGroupKeys).map((key) => parseToolKey(key).turn_index)
|
||||
);
|
||||
|
||||
const isNewTurnIndex = !previousTurnIndices.has(currentTurnIndex);
|
||||
|
||||
// If we see a new turn_index (not just tab_index), inject SECTION_END for previous groups
|
||||
if (isNewTurnIndex && state.seenGroupKeys.size > 0) {
|
||||
state.seenGroupKeys.forEach((prevGroupKey) => {
|
||||
if (!state.groupKeysWithSectionEnd.has(prevGroupKey)) {
|
||||
injectSectionEnd(state, prevGroupKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleCitationPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (packet.obj.type !== PacketType.CITATION_INFO) {
|
||||
return;
|
||||
}
|
||||
|
||||
const citationInfo = packet.obj as CitationInfo;
|
||||
|
||||
// Add to citation map immediately for rendering
|
||||
state.citationMap[citationInfo.citation_number] = citationInfo.document_id;
|
||||
|
||||
// Also add to citations array for CitedSourcesToggle (deduplicated)
|
||||
if (!state.seenCitationDocIds.has(citationInfo.document_id)) {
|
||||
state.seenCitationDocIds.add(citationInfo.document_id);
|
||||
state.citations.push({
|
||||
citation_num: citationInfo.citation_number,
|
||||
document_id: citationInfo.document_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleDocumentPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (packet.obj.type === PacketType.SEARCH_TOOL_DOCUMENTS_DELTA) {
|
||||
const docDelta = packet.obj as SearchToolDocumentsDelta;
|
||||
if (docDelta.documents) {
|
||||
for (const doc of docDelta.documents) {
|
||||
if (doc.document_id) {
|
||||
state.documentMap.set(doc.document_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (packet.obj.type === PacketType.FETCH_TOOL_DOCUMENTS) {
|
||||
const fetchDocuments = packet.obj as FetchToolDocuments;
|
||||
if (fetchDocuments.documents) {
|
||||
for (const doc of fetchDocuments.documents) {
|
||||
if (doc.document_id) {
|
||||
state.documentMap.set(doc.document_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamingStatusPacket(
|
||||
state: ProcessorState,
|
||||
packet: Packet
|
||||
): void {
|
||||
// Check if final answer is coming
|
||||
if (FINAL_ANSWER_PACKET_TYPES_SET.has(packet.obj.type as PacketType)) {
|
||||
state.finalAnswerComing = true;
|
||||
}
|
||||
}
|
||||
|
||||
function handleStopPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (packet.obj.type !== PacketType.STOP || state.stopPacketSeen) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.stopPacketSeen = true;
|
||||
|
||||
// Extract and store the stop reason
|
||||
const stopPacket = packet.obj as Stop;
|
||||
state.stopReason = stopPacket.stop_reason;
|
||||
|
||||
// Inject SECTION_END for all group keys that don't have one
|
||||
state.seenGroupKeys.forEach((groupKey) => {
|
||||
if (!state.groupKeysWithSectionEnd.has(groupKey)) {
|
||||
injectSectionEnd(state, groupKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleToolAfterMessagePacket(
|
||||
state: ProcessorState,
|
||||
packet: Packet
|
||||
): void {
|
||||
// Handles case where we get a Message packet from Claude, and then tool
|
||||
// calling packets. We use isActualToolCallPacket instead of isToolPacket
|
||||
// to exclude reasoning packets - reasoning is just the model thinking,
|
||||
// not an actual tool call that would produce new content.
|
||||
if (
|
||||
state.finalAnswerComing &&
|
||||
!state.stopPacketSeen &&
|
||||
isActualToolCallPacket(packet)
|
||||
) {
|
||||
state.finalAnswerComing = false;
|
||||
}
|
||||
}
|
||||
|
||||
function addPacketToGroup(
|
||||
state: ProcessorState,
|
||||
packet: Packet,
|
||||
groupKey: string
|
||||
): void {
|
||||
const existingGroup = state.groupedPacketsMap.get(groupKey);
|
||||
if (existingGroup) {
|
||||
existingGroup.push(packet);
|
||||
} else {
|
||||
state.groupedPacketsMap.set(groupKey, [packet]);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Processing Function
|
||||
// ============================================================================
|
||||
|
||||
function processPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (!packet) return;
|
||||
|
||||
// Handle TopLevelBranching packets - these tell us how many parallel branches to expect
|
||||
if (packet.obj.type === PacketType.TOP_LEVEL_BRANCHING) {
|
||||
handleTopLevelBranching(state, packet);
|
||||
// Don't add this packet to any group, it's just metadata
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle turn transitions (inject SECTION_END for previous groups)
|
||||
handleTurnTransition(state, packet);
|
||||
|
||||
// Track group key
|
||||
const groupKey = getGroupKey(packet);
|
||||
state.seenGroupKeys.add(groupKey);
|
||||
|
||||
// Track SECTION_END and ERROR packets (both indicate completion)
|
||||
if (
|
||||
packet.obj.type === PacketType.SECTION_END ||
|
||||
packet.obj.type === PacketType.ERROR
|
||||
) {
|
||||
state.groupKeysWithSectionEnd.add(groupKey);
|
||||
}
|
||||
|
||||
// Check if this is the first packet in the group (before adding)
|
||||
const existingGroup = state.groupedPacketsMap.get(groupKey);
|
||||
const isFirstPacket = !existingGroup;
|
||||
|
||||
// Add packet to group
|
||||
addPacketToGroup(state, packet, groupKey);
|
||||
|
||||
// Categorize on first packet of each group
|
||||
if (isFirstPacket) {
|
||||
if (isToolPacket(packet, false)) {
|
||||
state.toolGroupKeys.add(groupKey);
|
||||
// Track unique tool name
|
||||
const toolName = getToolNameFromPacket(packet);
|
||||
if (toolName) {
|
||||
state.uniqueToolNames.add(toolName);
|
||||
}
|
||||
}
|
||||
if (isDisplayPacket(packet)) {
|
||||
state.displayGroupKeys.add(groupKey);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle specific packet types
|
||||
handleCitationPacket(state, packet);
|
||||
handleDocumentPacket(state, packet);
|
||||
handleStreamingStatusPacket(state, packet);
|
||||
handleStopPacket(state, packet);
|
||||
handleToolAfterMessagePacket(state, packet);
|
||||
}
|
||||
|
||||
export function processPackets(
|
||||
state: ProcessorState,
|
||||
rawPackets: Packet[]
|
||||
): ProcessorState {
|
||||
// Handle reset (packets array shrunk - upstream replaced with shorter list)
|
||||
if (state.lastProcessedIndex > rawPackets.length) {
|
||||
state = createInitialState(state.nodeId);
|
||||
}
|
||||
|
||||
// Process only new packets
|
||||
for (let i = state.lastProcessedIndex; i < rawPackets.length; i++) {
|
||||
const packet = rawPackets[i];
|
||||
if (packet) {
|
||||
processPacket(state, packet);
|
||||
}
|
||||
}
|
||||
|
||||
state.lastProcessedIndex = rawPackets.length;
|
||||
|
||||
// Build result arrays after processing
|
||||
state.toolGroups = buildGroupsFromKeys(state, state.toolGroupKeys);
|
||||
state.potentialDisplayGroups = buildGroupsFromKeys(
|
||||
state,
|
||||
state.displayGroupKeys
|
||||
);
|
||||
state.uniqueToolNamesArray = Array.from(state.uniqueToolNames);
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build GroupedPacket array from a set of group keys.
|
||||
* Filters to only include groups with meaningful content and sorts by turn/tab index.
|
||||
*/
|
||||
function buildGroupsFromKeys(
|
||||
state: ProcessorState,
|
||||
keys: Set<string>
|
||||
): GroupedPacket[] {
|
||||
return Array.from(keys)
|
||||
.map((key) => {
|
||||
const { turn_index, tab_index } = parseToolKey(key);
|
||||
const packets = state.groupedPacketsMap.get(key);
|
||||
// Spread to create new array reference - ensures React detects changes for re-renders
|
||||
return packets ? { turn_index, tab_index, packets: [...packets] } : null;
|
||||
})
|
||||
.filter(
|
||||
(g): g is GroupedPacket => g !== null && hasContentPackets(g.packets)
|
||||
)
|
||||
.sort((a, b) => {
|
||||
if (a.turn_index !== b.turn_index) {
|
||||
return a.turn_index - b.turn_index;
|
||||
}
|
||||
return a.tab_index - b.tab_index;
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
import { useRef, useState, useMemo, useCallback } from "react";
|
||||
import {
|
||||
Packet,
|
||||
StreamingCitation,
|
||||
StopReason,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { CitationMap } from "@/app/chat/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import {
|
||||
ProcessorState,
|
||||
GroupedPacket,
|
||||
createInitialState,
|
||||
processPackets,
|
||||
} from "@/app/chat/message/messageComponents/timeline/hooks/packetProcessor";
|
||||
import {
|
||||
transformPacketGroups,
|
||||
groupStepsByTurn,
|
||||
TurnGroup,
|
||||
} from "@/app/chat/message/messageComponents/timeline/transformers";
|
||||
|
||||
export interface UsePacketProcessorResult {
|
||||
// Data
|
||||
toolGroups: GroupedPacket[];
|
||||
displayGroups: GroupedPacket[];
|
||||
toolTurnGroups: TurnGroup[];
|
||||
citations: StreamingCitation[];
|
||||
citationMap: CitationMap;
|
||||
documentMap: Map<string, OnyxDocument>;
|
||||
|
||||
// Status (derived from packets)
|
||||
stopPacketSeen: boolean;
|
||||
stopReason: StopReason | undefined;
|
||||
hasSteps: boolean;
|
||||
expectedBranchesPerTurn: Map<number, number>;
|
||||
uniqueToolNames: string[];
|
||||
|
||||
// Completion: stopPacketSeen && renderComplete
|
||||
isComplete: boolean;
|
||||
|
||||
// Callbacks
|
||||
onRenderComplete: () => void;
|
||||
markAllToolsDisplayed: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for processing streaming packets in AgentMessage.
|
||||
*
|
||||
* Architecture:
|
||||
* - Processor state in ref: incremental processing, synchronous, no double render
|
||||
* - Only true UI state: renderComplete (set by callback), forceShowAnswer (override)
|
||||
* - Everything else derived from packets
|
||||
*
|
||||
* Key insight: finalAnswerComing and stopPacketSeen are DERIVED from packets,
|
||||
* not independent state. Only renderComplete needs useState.
|
||||
*/
|
||||
export function usePacketProcessor(
|
||||
rawPackets: Packet[],
|
||||
nodeId: number
|
||||
): UsePacketProcessorResult {
|
||||
// Processor in ref: incremental, synchronous, no double render
|
||||
const stateRef = useRef<ProcessorState>(createInitialState(nodeId));
|
||||
|
||||
// Only TRUE UI state: "has renderer finished?"
|
||||
const [renderComplete, setRenderComplete] = useState(false);
|
||||
|
||||
// Optional override to force showing answer
|
||||
const [forceShowAnswer, setForceShowAnswer] = useState(false);
|
||||
|
||||
// Reset on nodeId change
|
||||
if (stateRef.current.nodeId !== nodeId) {
|
||||
stateRef.current = createInitialState(nodeId);
|
||||
setRenderComplete(false);
|
||||
setForceShowAnswer(false);
|
||||
}
|
||||
|
||||
// Track for transition detection
|
||||
const prevLastProcessed = stateRef.current.lastProcessedIndex;
|
||||
const prevFinalAnswerComing = stateRef.current.finalAnswerComing;
|
||||
|
||||
// Detect stream reset (packets shrunk)
|
||||
if (prevLastProcessed > rawPackets.length) {
|
||||
stateRef.current = createInitialState(nodeId);
|
||||
setRenderComplete(false);
|
||||
setForceShowAnswer(false);
|
||||
}
|
||||
|
||||
// Process packets synchronously (incremental) - only if new packets arrived
|
||||
if (rawPackets.length > stateRef.current.lastProcessedIndex) {
|
||||
stateRef.current = processPackets(stateRef.current, rawPackets);
|
||||
}
|
||||
|
||||
// Reset renderComplete on tool-after-message transition
|
||||
if (prevFinalAnswerComing && !stateRef.current.finalAnswerComing) {
|
||||
setRenderComplete(false);
|
||||
}
|
||||
|
||||
// Access state directly (result arrays are built in processPackets)
|
||||
const state = stateRef.current;
|
||||
|
||||
// Derive displayGroups (not state!)
|
||||
const effectiveFinalAnswerComing = state.finalAnswerComing || forceShowAnswer;
|
||||
const displayGroups = useMemo(() => {
|
||||
if (effectiveFinalAnswerComing || state.toolGroups.length === 0) {
|
||||
return state.potentialDisplayGroups;
|
||||
}
|
||||
return [];
|
||||
}, [
|
||||
effectiveFinalAnswerComing,
|
||||
state.toolGroups.length,
|
||||
state.potentialDisplayGroups,
|
||||
]);
|
||||
|
||||
// Transform toolGroups to timeline format
|
||||
const toolTurnGroups = useMemo(() => {
|
||||
const allSteps = transformPacketGroups(state.toolGroups);
|
||||
return groupStepsByTurn(allSteps);
|
||||
}, [state.toolGroups]);
|
||||
|
||||
// Callback reads from ref: always current value, no ref needed in component
|
||||
const onRenderComplete = useCallback(() => {
|
||||
if (stateRef.current.finalAnswerComing) {
|
||||
setRenderComplete(true);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const markAllToolsDisplayed = useCallback(() => {
|
||||
setForceShowAnswer(true);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
// Data
|
||||
toolGroups: state.toolGroups,
|
||||
displayGroups,
|
||||
toolTurnGroups,
|
||||
citations: state.citations,
|
||||
citationMap: state.citationMap,
|
||||
documentMap: state.documentMap,
|
||||
|
||||
// Status (derived from packets)
|
||||
stopPacketSeen: state.stopPacketSeen,
|
||||
stopReason: state.stopReason,
|
||||
hasSteps: toolTurnGroups.length > 0,
|
||||
expectedBranchesPerTurn: state.expectedBranches,
|
||||
uniqueToolNames: state.uniqueToolNamesArray,
|
||||
|
||||
// Completion: stopPacketSeen && renderComplete
|
||||
isComplete: state.stopPacketSeen && renderComplete,
|
||||
|
||||
// Callbacks
|
||||
onRenderComplete,
|
||||
markAllToolsDisplayed,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { TurnGroup } from "../transformers";
|
||||
|
||||
export interface TimelineExpansionState {
|
||||
isExpanded: boolean;
|
||||
handleToggle: () => void;
|
||||
parallelActiveTab: string;
|
||||
setParallelActiveTab: (tab: string) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages expansion state for the timeline.
|
||||
* Auto-collapses when streaming completes and syncs parallel tab selection.
|
||||
*/
|
||||
export function useTimelineExpansion(
|
||||
stopPacketSeen: boolean,
|
||||
lastTurnGroup: TurnGroup | undefined
|
||||
): TimelineExpansionState {
|
||||
const [isExpanded, setIsExpanded] = useState(!stopPacketSeen);
|
||||
const [parallelActiveTab, setParallelActiveTab] = useState<string>("");
|
||||
|
||||
const handleToggle = useCallback(() => {
|
||||
setIsExpanded((prev) => !prev);
|
||||
}, []);
|
||||
|
||||
// Auto-collapse when streaming completes
|
||||
useEffect(() => {
|
||||
if (stopPacketSeen) {
|
||||
setIsExpanded(false);
|
||||
}
|
||||
}, [stopPacketSeen]);
|
||||
|
||||
// Sync active tab when parallel turn group changes
|
||||
useEffect(() => {
|
||||
if (lastTurnGroup?.isParallel && lastTurnGroup.steps.length > 0) {
|
||||
const validTabs = lastTurnGroup.steps.map((s) => s.key);
|
||||
const firstStep = lastTurnGroup.steps[0];
|
||||
if (firstStep && !validTabs.includes(parallelActiveTab)) {
|
||||
setParallelActiveTab(firstStep.key);
|
||||
}
|
||||
}
|
||||
}, [lastTurnGroup, parallelActiveTab]);
|
||||
|
||||
return {
|
||||
isExpanded,
|
||||
handleToggle,
|
||||
parallelActiveTab,
|
||||
setParallelActiveTab,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
import { useMemo } from "react";
|
||||
import { TurnGroup } from "../transformers";
|
||||
import {
|
||||
PacketType,
|
||||
SearchToolPacket,
|
||||
StopReason,
|
||||
CustomToolStart,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { constructCurrentSearchState } from "@/app/chat/message/messageComponents/timeline/renderers/search/searchStateUtils";
|
||||
|
||||
export interface TimelineHeaderResult {
|
||||
headerText: string;
|
||||
hasPackets: boolean;
|
||||
userStopped: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook that determines timeline header state based on current activity.
|
||||
* Returns header text, whether there are packets, and whether user stopped.
|
||||
*/
|
||||
export function useTimelineHeader(
|
||||
turnGroups: TurnGroup[],
|
||||
stopReason?: StopReason
|
||||
): TimelineHeaderResult {
|
||||
return useMemo(() => {
|
||||
const hasPackets = turnGroups.length > 0;
|
||||
const userStopped = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
if (!hasPackets) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
// Get the last (current) turn group
|
||||
const currentTurn = turnGroups[turnGroups.length - 1];
|
||||
if (!currentTurn) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
const currentStep = currentTurn.steps[0];
|
||||
if (!currentStep?.packets?.length) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
const firstPacket = currentStep.packets[0];
|
||||
if (!firstPacket) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
const packetType = firstPacket.obj.type;
|
||||
|
||||
// Determine header based on packet type
|
||||
if (packetType === PacketType.SEARCH_TOOL_START) {
|
||||
const searchState = constructCurrentSearchState(
|
||||
currentStep.packets as SearchToolPacket[]
|
||||
);
|
||||
const headerText = searchState.isInternetSearch
|
||||
? "Searching web"
|
||||
: "Searching internal documents";
|
||||
return { headerText, hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.FETCH_TOOL_START) {
|
||||
return { headerText: "Opening URLs", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.PYTHON_TOOL_START) {
|
||||
return { headerText: "Executing code", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.IMAGE_GENERATION_TOOL_START) {
|
||||
return { headerText: "Generating images", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.CUSTOM_TOOL_START) {
|
||||
const toolName = (firstPacket.obj as CustomToolStart).tool_name;
|
||||
return {
|
||||
headerText: toolName ? `Executing ${toolName}` : "Executing tool",
|
||||
hasPackets,
|
||||
userStopped,
|
||||
};
|
||||
}
|
||||
|
||||
if (packetType === PacketType.REASONING_START) {
|
||||
return { headerText: "Thinking", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.DEEP_RESEARCH_PLAN_START) {
|
||||
return { headerText: "Generating plan", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.RESEARCH_AGENT_START) {
|
||||
return { headerText: "Researching", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}, [turnGroups, stopReason]);
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
import { useMemo } from "react";
|
||||
import {
|
||||
TurnGroup,
|
||||
TransformedStep,
|
||||
} from "@/app/chat/message/messageComponents/timeline/transformers";
|
||||
import { getToolIconByName } from "@/app/chat/message/messageComponents/toolDisplayHelpers";
|
||||
import {
|
||||
isResearchAgentPackets,
|
||||
stepSupportsCompact,
|
||||
} from "@/app/chat/message/messageComponents/timeline/packetHelpers";
|
||||
|
||||
export interface UniqueTool {
|
||||
key: string;
|
||||
name: string;
|
||||
icon: React.JSX.Element;
|
||||
}
|
||||
|
||||
export interface TimelineMetrics {
|
||||
totalSteps: number;
|
||||
isSingleStep: boolean;
|
||||
uniqueTools: UniqueTool[];
|
||||
lastTurnGroup: TurnGroup | undefined;
|
||||
lastStep: TransformedStep | undefined;
|
||||
lastStepIsResearchAgent: boolean;
|
||||
lastStepSupportsCompact: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Memoizes derived metrics from turn groups to avoid recomputation on every render.
|
||||
* Single-pass computation where possible for performance with large packet counts.
|
||||
*/
|
||||
export function useTimelineMetrics(
|
||||
turnGroups: TurnGroup[],
|
||||
uniqueToolNames: string[],
|
||||
userStopped: boolean
|
||||
): TimelineMetrics {
|
||||
return useMemo(() => {
|
||||
// Compute in single pass
|
||||
let totalSteps = 0;
|
||||
for (const tg of turnGroups) {
|
||||
totalSteps += tg.steps.length;
|
||||
}
|
||||
|
||||
const lastTurnGroup = turnGroups[turnGroups.length - 1];
|
||||
const lastStep = lastTurnGroup?.steps[lastTurnGroup.steps.length - 1];
|
||||
|
||||
// Analyze last step packets once
|
||||
const lastStepIsResearchAgent = lastStep
|
||||
? isResearchAgentPackets(lastStep.packets)
|
||||
: false;
|
||||
const lastStepSupportsCompact = lastStep
|
||||
? stepSupportsCompact(lastStep.packets)
|
||||
: false;
|
||||
|
||||
return {
|
||||
totalSteps,
|
||||
isSingleStep: totalSteps === 1 && !userStopped,
|
||||
uniqueTools: uniqueToolNames.map((name) => ({
|
||||
key: name,
|
||||
name,
|
||||
icon: getToolIconByName(name),
|
||||
})),
|
||||
lastTurnGroup,
|
||||
lastStep,
|
||||
lastStepIsResearchAgent,
|
||||
lastStepSupportsCompact,
|
||||
};
|
||||
}, [turnGroups, uniqueToolNames, userStopped]);
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
import { Packet, PacketType } from "@/app/chat/services/streamingModels";
|
||||
|
||||
// Packet types with renderers supporting compact mode
|
||||
export const COMPACT_SUPPORTED_PACKET_TYPES = new Set<PacketType>([
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
]);
|
||||
|
||||
// Check if packets belong to a research agent (handles its own Done indicator)
|
||||
export const isResearchAgentPackets = (packets: Packet[]): boolean =>
|
||||
packets.some((p) => p.obj.type === PacketType.RESEARCH_AGENT_START);
|
||||
|
||||
// Check if step supports compact rendering mode
|
||||
export const stepSupportsCompact = (packets: Packet[]): boolean =>
|
||||
packets.some((p) =>
|
||||
COMPACT_SUPPORTED_PACKET_TYPES.has(p.obj.type as PacketType)
|
||||
);
|
||||
@@ -0,0 +1,199 @@
|
||||
import { useEffect, useMemo } from "react";
|
||||
import {
|
||||
PacketType,
|
||||
PythonToolPacket,
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
SectionEnd,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
RenderType,
|
||||
} from "@/app/chat/message/messageComponents/interfaces";
|
||||
import { CodeBlock } from "@/app/chat/message/CodeBlock";
|
||||
import hljs from "highlight.js/lib/core";
|
||||
import python from "highlight.js/lib/languages/python";
|
||||
import { SvgTerminal } from "@opal/icons";
|
||||
import FadingEdgeContainer from "@/refresh-components/FadingEdgeContainer";
|
||||
|
||||
// Register Python language for highlighting
|
||||
hljs.registerLanguage("python", python);
|
||||
|
||||
// Component to render syntax-highlighted Python code
|
||||
function HighlightedPythonCode({ code }: { code: string }) {
|
||||
const highlightedHtml = useMemo(() => {
|
||||
try {
|
||||
return hljs.highlight(code, { language: "python" }).value;
|
||||
} catch {
|
||||
return code;
|
||||
}
|
||||
}, [code]);
|
||||
|
||||
return (
|
||||
<span
|
||||
dangerouslySetInnerHTML={{ __html: highlightedHtml }}
|
||||
className="hljs"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Helper function to construct current Python execution state
|
||||
function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
const pythonStart = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.PYTHON_TOOL_START
|
||||
)?.obj as PythonToolStart | null;
|
||||
const pythonDeltas = packets
|
||||
.filter((packet) => packet.obj.type === PacketType.PYTHON_TOOL_DELTA)
|
||||
.map((packet) => packet.obj as PythonToolDelta);
|
||||
const pythonEnd = packets.find(
|
||||
(packet) =>
|
||||
packet.obj.type === PacketType.SECTION_END ||
|
||||
packet.obj.type === PacketType.ERROR
|
||||
)?.obj as SectionEnd | null;
|
||||
|
||||
const code = pythonStart?.code || "";
|
||||
const stdout = pythonDeltas
|
||||
.map((delta) => delta?.stdout || "")
|
||||
.filter((s) => s)
|
||||
.join("");
|
||||
const stderr = pythonDeltas
|
||||
.map((delta) => delta?.stderr || "")
|
||||
.filter((s) => s)
|
||||
.join("");
|
||||
const fileIds = pythonDeltas.flatMap((delta) => delta?.file_ids || []);
|
||||
const isExecuting = pythonStart && !pythonEnd;
|
||||
const isComplete = pythonStart && pythonEnd;
|
||||
const hasError = stderr.length > 0;
|
||||
|
||||
return {
|
||||
code,
|
||||
stdout,
|
||||
stderr,
|
||||
fileIds,
|
||||
isExecuting,
|
||||
isComplete,
|
||||
hasError,
|
||||
};
|
||||
}
|
||||
|
||||
export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
packets,
|
||||
onComplete,
|
||||
renderType,
|
||||
children,
|
||||
}) => {
|
||||
const { code, stdout, stderr, fileIds, isExecuting, isComplete, hasError } =
|
||||
constructCurrentPythonState(packets);
|
||||
|
||||
useEffect(() => {
|
||||
if (isComplete) {
|
||||
onComplete();
|
||||
}
|
||||
}, [isComplete, onComplete]);
|
||||
|
||||
const status = useMemo(() => {
|
||||
if (isExecuting) {
|
||||
return "Executing Python code...";
|
||||
}
|
||||
if (hasError) {
|
||||
return "Python execution failed";
|
||||
}
|
||||
if (isComplete) {
|
||||
return "Python execution completed";
|
||||
}
|
||||
return "Python execution";
|
||||
}, [isComplete, isExecuting, hasError]);
|
||||
|
||||
// Shared content for all states - used by both FULL and compact modes
|
||||
const content = (
|
||||
<div className="flex flex-col mb-1 space-y-2">
|
||||
{/* Loading indicator when executing */}
|
||||
{isExecuting && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<div className="flex gap-0.5">
|
||||
<div className="w-1 h-1 bg-current rounded-full animate-pulse"></div>
|
||||
<div
|
||||
className="w-1 h-1 bg-current rounded-full animate-pulse"
|
||||
style={{ animationDelay: "0.1s" }}
|
||||
></div>
|
||||
<div
|
||||
className="w-1 h-1 bg-current rounded-full animate-pulse"
|
||||
style={{ animationDelay: "0.2s" }}
|
||||
></div>
|
||||
</div>
|
||||
<span>Running code...</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Code block */}
|
||||
{code && (
|
||||
<div className="prose max-w-full">
|
||||
<CodeBlock className="language-python" codeText={code.trim()}>
|
||||
<HighlightedPythonCode code={code.trim()} />
|
||||
</CodeBlock>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Output */}
|
||||
{stdout && (
|
||||
<div className="rounded-md bg-gray-100 dark:bg-gray-800 p-3">
|
||||
<div className="text-xs font-semibold mb-1 text-gray-600 dark:text-gray-400">
|
||||
Output:
|
||||
</div>
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-gray-900 dark:text-gray-100">
|
||||
{stdout}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error */}
|
||||
{stderr && (
|
||||
<div className="rounded-md bg-red-50 dark:bg-red-900/20 p-3 border border-red-200 dark:border-red-800">
|
||||
<div className="text-xs font-semibold mb-1 text-red-600 dark:text-red-400">
|
||||
Error:
|
||||
</div>
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-red-900 dark:text-red-100">
|
||||
{stderr}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* File count */}
|
||||
{fileIds.length > 0 && (
|
||||
<div className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Generated {fileIds.length} file{fileIds.length !== 1 ? "s" : ""}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* No output fallback - only when complete with no output */}
|
||||
{isComplete && !stdout && !stderr && (
|
||||
<div className="py-2 text-center text-gray-500 dark:text-gray-400">
|
||||
<SvgTerminal className="w-4 h-4 mx-auto mb-1 opacity-50" />
|
||||
<p className="text-xs">No output</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
// FULL mode: render content directly
|
||||
if (renderType === RenderType.FULL) {
|
||||
return children({
|
||||
icon: SvgTerminal,
|
||||
status,
|
||||
content,
|
||||
supportsCompact: true,
|
||||
});
|
||||
}
|
||||
|
||||
// Compact mode: wrap content in FadeDiv
|
||||
return children({
|
||||
icon: SvgTerminal,
|
||||
status,
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<FadingEdgeContainer direction="bottom" className="h-24">
|
||||
{content}
|
||||
</FadingEdgeContainer>
|
||||
),
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,104 @@
|
||||
import React, { useMemo, useCallback } from "react";
|
||||
import { FiList } from "react-icons/fi";
|
||||
|
||||
import {
|
||||
DeepResearchPlanPacket,
|
||||
PacketType,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
FullChatState,
|
||||
} from "@/app/chat/message/messageComponents/interfaces";
|
||||
import { usePacketAnimationAndCollapse } from "@/app/chat/message/messageComponents/hooks/usePacketAnimationAndCollapse";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import ExpandableTextDisplay from "@/refresh-components/texts/ExpandableTextDisplay";
|
||||
import { mutedTextMarkdownComponents } from "@/app/chat/message/messageComponents/timeline/renderers/sharedMarkdownComponents";
|
||||
|
||||
/**
|
||||
* Renderer for deep research plan packets.
|
||||
* Streams the research plan content with a list icon.
|
||||
* Collapsible and auto-collapses when plan generation is complete.
|
||||
*/
|
||||
export const DeepResearchPlanRenderer: MessageRenderer<
|
||||
DeepResearchPlanPacket,
|
||||
FullChatState
|
||||
> = ({
|
||||
packets,
|
||||
state,
|
||||
onComplete,
|
||||
renderType,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
children,
|
||||
}) => {
|
||||
// Check if plan generation is complete (has SECTION_END)
|
||||
const isComplete = packets.some((p) => p.obj.type === PacketType.SECTION_END);
|
||||
|
||||
// Use shared hook for animation logic (collapse behavior no longer needed)
|
||||
const { displayedPacketCount } = usePacketAnimationAndCollapse({
|
||||
packets,
|
||||
animate,
|
||||
isComplete,
|
||||
onComplete,
|
||||
});
|
||||
|
||||
// Get the full content from all packets
|
||||
const fullContent = useMemo(
|
||||
() =>
|
||||
packets
|
||||
.map((packet) => {
|
||||
if (packet.obj.type === PacketType.DEEP_RESEARCH_PLAN_DELTA) {
|
||||
return packet.obj.content;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join(""),
|
||||
[packets]
|
||||
);
|
||||
|
||||
// Animated content for collapsed view (respects streaming animation)
|
||||
const animatedContent = useMemo(() => {
|
||||
if (!animate || displayedPacketCount === -1) {
|
||||
return fullContent;
|
||||
}
|
||||
return packets
|
||||
.slice(0, displayedPacketCount)
|
||||
.map((packet) => {
|
||||
if (packet.obj.type === PacketType.DEEP_RESEARCH_PLAN_DELTA) {
|
||||
return packet.obj.content;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
}, [animate, displayedPacketCount, fullContent, packets]);
|
||||
|
||||
// Markdown renderer callback for ExpandableTextDisplay
|
||||
const renderMarkdown = useCallback(
|
||||
(text: string) => (
|
||||
<MinimalMarkdown
|
||||
content={text}
|
||||
components={mutedTextMarkdownComponents}
|
||||
/>
|
||||
),
|
||||
[]
|
||||
);
|
||||
|
||||
const statusText = isComplete ? "Generated plan" : "Generating plan";
|
||||
|
||||
const planContent = (
|
||||
<ExpandableTextDisplay
|
||||
title="Deep research plan"
|
||||
content={fullContent}
|
||||
displayContent={animatedContent}
|
||||
maxLines={5}
|
||||
renderContent={renderMarkdown}
|
||||
/>
|
||||
);
|
||||
|
||||
return children({
|
||||
icon: FiList,
|
||||
status: statusText,
|
||||
content: planContent,
|
||||
expandedText: planContent,
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,223 @@
|
||||
import React, {
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useCallback,
|
||||
FunctionComponent,
|
||||
} from "react";
|
||||
import { FiTarget } from "react-icons/fi";
|
||||
import { SvgCircle, SvgCheckCircle } from "@opal/icons";
|
||||
import { IconProps } from "@opal/types";
|
||||
|
||||
import {
|
||||
PacketType,
|
||||
Packet,
|
||||
ResearchAgentPacket,
|
||||
ResearchAgentStart,
|
||||
IntermediateReportDelta,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
FullChatState,
|
||||
} from "@/app/chat/message/messageComponents/interfaces";
|
||||
import { getToolName } from "@/app/chat/message/messageComponents/toolDisplayHelpers";
|
||||
import { StepContainer } from "@/app/chat/message/messageComponents/timeline/StepContainer";
|
||||
import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererResult,
|
||||
} from "@/app/chat/message/messageComponents/timeline/TimelineRendererComponent";
|
||||
import ExpandableTextDisplay from "@/refresh-components/texts/ExpandableTextDisplay";
|
||||
import { useMarkdownRenderer } from "@/app/chat/message/messageComponents/markdownUtils";
|
||||
|
||||
interface NestedToolGroup {
|
||||
sub_turn_index: number;
|
||||
toolType: string;
|
||||
status: string;
|
||||
isComplete: boolean;
|
||||
packets: Packet[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Renderer for research agent steps in deep research.
|
||||
* Segregates packets by tool and uses StepContainer + TimelineRendererComponent.
|
||||
*/
|
||||
export const ResearchAgentRenderer: MessageRenderer<
|
||||
ResearchAgentPacket,
|
||||
FullChatState
|
||||
> = ({
|
||||
packets,
|
||||
state,
|
||||
onComplete,
|
||||
stopPacketSeen,
|
||||
isLastStep = true,
|
||||
children,
|
||||
}) => {
|
||||
// Extract the research task from the start packet
|
||||
const startPacket = packets.find(
|
||||
(p) => p.obj.type === PacketType.RESEARCH_AGENT_START
|
||||
);
|
||||
const researchTask = startPacket
|
||||
? (startPacket.obj as ResearchAgentStart).research_task
|
||||
: "";
|
||||
|
||||
// Separate parent packets from nested tool packets
|
||||
const { parentPackets, nestedToolGroups } = useMemo(() => {
|
||||
const parent: Packet[] = [];
|
||||
const nestedBySubTurn = new Map<number, Packet[]>();
|
||||
|
||||
packets.forEach((packet) => {
|
||||
const subTurnIndex = packet.placement.sub_turn_index;
|
||||
if (subTurnIndex === undefined || subTurnIndex === null) {
|
||||
parent.push(packet);
|
||||
} else {
|
||||
if (!nestedBySubTurn.has(subTurnIndex)) {
|
||||
nestedBySubTurn.set(subTurnIndex, []);
|
||||
}
|
||||
nestedBySubTurn.get(subTurnIndex)!.push(packet);
|
||||
}
|
||||
});
|
||||
|
||||
// Convert nested packets to groups with metadata
|
||||
const groups: NestedToolGroup[] = Array.from(nestedBySubTurn.entries())
|
||||
.sort(([a], [b]) => a - b)
|
||||
.map(([subTurnIndex, toolPackets]) => {
|
||||
const name = getToolName(toolPackets);
|
||||
const isComplete = toolPackets.some(
|
||||
(p) =>
|
||||
p.obj.type === PacketType.SECTION_END ||
|
||||
p.obj.type === PacketType.REASONING_DONE
|
||||
);
|
||||
return {
|
||||
sub_turn_index: subTurnIndex,
|
||||
toolType: name,
|
||||
status: isComplete ? "Complete" : "Running",
|
||||
isComplete,
|
||||
packets: toolPackets,
|
||||
};
|
||||
});
|
||||
|
||||
return { parentPackets: parent, nestedToolGroups: groups };
|
||||
}, [packets]);
|
||||
|
||||
// Check completion from parent packets
|
||||
const isComplete = parentPackets.some(
|
||||
(p) => p.obj.type === PacketType.SECTION_END
|
||||
);
|
||||
const hasCalledCompleteRef = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (isComplete && !hasCalledCompleteRef.current) {
|
||||
hasCalledCompleteRef.current = true;
|
||||
onComplete();
|
||||
}
|
||||
}, [isComplete, onComplete]);
|
||||
|
||||
// Build report content from parent packets
|
||||
const fullReportContent = parentPackets
|
||||
.map((packet) => {
|
||||
if (packet.obj.type === PacketType.INTERMEDIATE_REPORT_DELTA) {
|
||||
return (packet.obj as IntermediateReportDelta).content;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
|
||||
// Markdown renderer for ExpandableTextDisplay
|
||||
const { renderedContent } = useMarkdownRenderer(
|
||||
fullReportContent,
|
||||
state,
|
||||
"text-text-03 font-main-ui-body"
|
||||
);
|
||||
|
||||
// Stable callbacks to avoid creating new functions on every render
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
const renderReport = useCallback(() => renderedContent, [renderedContent]);
|
||||
|
||||
// Build content using StepContainer pattern
|
||||
const researchAgentContent = (
|
||||
<div className="flex flex-col">
|
||||
{/* Research Task - using StepContainer (collapsible) */}
|
||||
{researchTask && (
|
||||
<StepContainer
|
||||
stepIcon={FiTarget as FunctionComponent<IconProps>}
|
||||
header="Research Task"
|
||||
collapsible={true}
|
||||
isLastStep={
|
||||
nestedToolGroups.length === 0 && !fullReportContent && !isComplete
|
||||
}
|
||||
>
|
||||
<div className="text-text-600 text-sm">{researchTask}</div>
|
||||
</StepContainer>
|
||||
)}
|
||||
|
||||
{/* Nested tool calls - using TimelineRendererComponent + StepContainer */}
|
||||
{nestedToolGroups.map((group, index) => {
|
||||
const isLastNestedStep =
|
||||
index === nestedToolGroups.length - 1 &&
|
||||
!fullReportContent &&
|
||||
!isComplete;
|
||||
|
||||
return (
|
||||
<TimelineRendererComponent
|
||||
key={group.sub_turn_index}
|
||||
packets={group.packets}
|
||||
chatState={state}
|
||||
onComplete={noopComplete}
|
||||
animate={!stopPacketSeen && !group.isComplete}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
defaultExpanded={true}
|
||||
isLastStep={isLastNestedStep}
|
||||
>
|
||||
{({ icon, status, content, isExpanded, onToggle }) => (
|
||||
<StepContainer
|
||||
stepIcon={icon as FunctionComponent<IconProps> | undefined}
|
||||
header={status}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={onToggle}
|
||||
collapsible={true}
|
||||
isLastStep={isLastNestedStep}
|
||||
isFirstStep={!researchTask && index === 0}
|
||||
>
|
||||
{content}
|
||||
</StepContainer>
|
||||
)}
|
||||
</TimelineRendererComponent>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Intermediate report - using ExpandableTextDisplay */}
|
||||
{fullReportContent && (
|
||||
<StepContainer
|
||||
stepIcon={SvgCircle as FunctionComponent<IconProps>}
|
||||
header="Research Report"
|
||||
isLastStep={!isComplete}
|
||||
isFirstStep={!researchTask && nestedToolGroups.length === 0}
|
||||
>
|
||||
<ExpandableTextDisplay
|
||||
title="Research Report"
|
||||
content={fullReportContent}
|
||||
maxLines={5}
|
||||
renderContent={renderReport}
|
||||
/>
|
||||
</StepContainer>
|
||||
)}
|
||||
|
||||
{/* Done indicator at end of research agent */}
|
||||
{isComplete && !isLastStep && (
|
||||
<StepContainer
|
||||
stepIcon={SvgCheckCircle}
|
||||
header="Done"
|
||||
isLastStep={isLastStep}
|
||||
isFirstStep={false}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
// Return simplified result (no icon, no status)
|
||||
return children({
|
||||
icon: null,
|
||||
status: null,
|
||||
content: researchAgentContent,
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,135 @@
|
||||
import React from "react";
|
||||
import { FiLink } from "react-icons/fi";
|
||||
import { FetchToolPacket } from "@/app/chat/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
RenderType,
|
||||
} from "@/app/chat/message/messageComponents/interfaces";
|
||||
import { BlinkingDot } from "@/app/chat/message/BlinkingDot";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { SearchChipList, SourceInfo } from "../search/SearchChipList";
|
||||
import { getMetadataTags } from "../search";
|
||||
import {
|
||||
constructCurrentFetchState,
|
||||
INITIAL_URLS_TO_SHOW,
|
||||
URLS_PER_EXPANSION,
|
||||
} from "./fetchStateUtils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
const urlToSourceInfo = (url: string, index: number): SourceInfo => ({
|
||||
id: `url-${index}`,
|
||||
title: url,
|
||||
sourceType: ValidSources.Web,
|
||||
sourceUrl: url,
|
||||
});
|
||||
|
||||
const documentToSourceInfo = (doc: OnyxDocument): SourceInfo => ({
|
||||
id: doc.document_id,
|
||||
title: doc.semantic_identifier || doc.link || "",
|
||||
sourceType: doc.source_type || ValidSources.Web,
|
||||
sourceUrl: doc.link,
|
||||
description: doc.blurb,
|
||||
metadata: {
|
||||
date: doc.updated_at || undefined,
|
||||
tags: getMetadataTags(doc.metadata),
|
||||
},
|
||||
});
|
||||
|
||||
export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
|
||||
packets,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
renderType,
|
||||
children,
|
||||
}) => {
|
||||
const fetchState = constructCurrentFetchState(packets);
|
||||
const { urls, documents, hasStarted, isLoading, isComplete } = fetchState;
|
||||
const isCompact = renderType === RenderType.COMPACT;
|
||||
|
||||
if (!hasStarted) {
|
||||
return children({
|
||||
icon: FiLink,
|
||||
status: null,
|
||||
content: <div />,
|
||||
supportsCompact: true,
|
||||
});
|
||||
}
|
||||
|
||||
const displayDocuments = documents.length > 0;
|
||||
const displayUrls = !displayDocuments && isComplete && urls.length > 0;
|
||||
|
||||
return children({
|
||||
icon: FiLink,
|
||||
status: "Opening URLs:",
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<div className="flex flex-col">
|
||||
{!isCompact &&
|
||||
(displayDocuments ? (
|
||||
<SearchChipList
|
||||
items={documents}
|
||||
initialCount={INITIAL_URLS_TO_SHOW}
|
||||
expansionCount={URLS_PER_EXPANSION}
|
||||
getKey={(doc: OnyxDocument) => doc.document_id}
|
||||
toSourceInfo={(doc: OnyxDocument) => documentToSourceInfo(doc)}
|
||||
onClick={(doc: OnyxDocument) => {
|
||||
if (doc.link) window.open(doc.link, "_blank");
|
||||
}}
|
||||
emptyState={<BlinkingDot />}
|
||||
/>
|
||||
) : displayUrls ? (
|
||||
<SearchChipList
|
||||
items={urls}
|
||||
initialCount={INITIAL_URLS_TO_SHOW}
|
||||
expansionCount={URLS_PER_EXPANSION}
|
||||
getKey={(url: string) => url}
|
||||
toSourceInfo={urlToSourceInfo}
|
||||
onClick={(url: string) => window.open(url, "_blank")}
|
||||
emptyState={<BlinkingDot />}
|
||||
/>
|
||||
) : (
|
||||
<div className="flex flex-wrap gap-x-2 gap-y-2 ml-1">
|
||||
<BlinkingDot />
|
||||
</div>
|
||||
))}
|
||||
|
||||
{(displayDocuments || displayUrls) && (
|
||||
<>
|
||||
{!isCompact && (
|
||||
<Text as="p" mainUiMuted text03>
|
||||
Reading results:
|
||||
</Text>
|
||||
)}
|
||||
{displayDocuments ? (
|
||||
<SearchChipList
|
||||
items={documents}
|
||||
initialCount={INITIAL_URLS_TO_SHOW}
|
||||
expansionCount={URLS_PER_EXPANSION}
|
||||
getKey={(doc: OnyxDocument) => `reading-${doc.document_id}`}
|
||||
toSourceInfo={(doc: OnyxDocument) => documentToSourceInfo(doc)}
|
||||
onClick={(doc: OnyxDocument) => {
|
||||
if (doc.link) window.open(doc.link, "_blank");
|
||||
}}
|
||||
emptyState={<BlinkingDot />}
|
||||
/>
|
||||
) : (
|
||||
<SearchChipList
|
||||
items={urls}
|
||||
initialCount={INITIAL_URLS_TO_SHOW}
|
||||
expansionCount={URLS_PER_EXPANSION}
|
||||
getKey={(url: string, index: number) =>
|
||||
`reading-${url}-${index}`
|
||||
}
|
||||
toSourceInfo={urlToSourceInfo}
|
||||
onClick={(url: string) => window.open(url, "_blank")}
|
||||
emptyState={<BlinkingDot />}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,48 @@
|
||||
import {
|
||||
PacketType,
|
||||
FetchToolPacket,
|
||||
FetchToolUrls,
|
||||
FetchToolDocuments,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
|
||||
export const INITIAL_URLS_TO_SHOW = 3;
|
||||
export const URLS_PER_EXPANSION = 5;
|
||||
export const READING_MIN_DURATION_MS = 1000;
|
||||
export const READ_MIN_DURATION_MS = 1000;
|
||||
|
||||
export interface FetchState {
|
||||
urls: string[];
|
||||
documents: OnyxDocument[];
|
||||
hasStarted: boolean;
|
||||
isLoading: boolean;
|
||||
isComplete: boolean;
|
||||
}
|
||||
|
||||
/** Constructs the current fetch state from fetch tool packets. */
|
||||
export const constructCurrentFetchState = (
|
||||
packets: FetchToolPacket[]
|
||||
): FetchState => {
|
||||
const startPacket = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.FETCH_TOOL_START
|
||||
);
|
||||
const urlsPacket = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.FETCH_TOOL_URLS
|
||||
)?.obj as FetchToolUrls | undefined;
|
||||
const documentsPacket = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.FETCH_TOOL_DOCUMENTS
|
||||
)?.obj as FetchToolDocuments | undefined;
|
||||
const sectionEnd = packets.find(
|
||||
(packet) =>
|
||||
packet.obj.type === PacketType.SECTION_END ||
|
||||
packet.obj.type === PacketType.ERROR
|
||||
);
|
||||
|
||||
const urls = urlsPacket?.urls || [];
|
||||
const documents = documentsPacket?.documents || [];
|
||||
const hasStarted = Boolean(startPacket);
|
||||
const isLoading = hasStarted && !documentsPacket;
|
||||
const isComplete = Boolean(startPacket && sectionEnd);
|
||||
|
||||
return { urls, documents, hasStarted, isLoading, isComplete };
|
||||
};
|
||||
@@ -0,0 +1,10 @@
|
||||
export { FetchToolRenderer } from "./FetchToolRenderer";
|
||||
|
||||
export {
|
||||
constructCurrentFetchState,
|
||||
type FetchState,
|
||||
INITIAL_URLS_TO_SHOW,
|
||||
URLS_PER_EXPANSION,
|
||||
READING_MIN_DURATION_MS,
|
||||
READ_MIN_DURATION_MS,
|
||||
} from "./fetchStateUtils";
|
||||
@@ -0,0 +1,140 @@
|
||||
import React, {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
import {
|
||||
PacketType,
|
||||
ReasoningDelta,
|
||||
ReasoningPacket,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
FullChatState,
|
||||
} from "@/app/chat/message/messageComponents/interfaces";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import ExpandableTextDisplay from "@/refresh-components/texts/ExpandableTextDisplay";
|
||||
import { mutedTextMarkdownComponents } from "@/app/chat/message/messageComponents/timeline/renderers/sharedMarkdownComponents";
|
||||
import { SvgCircle } from "@opal/icons";
|
||||
|
||||
const THINKING_MIN_DURATION_MS = 500; // 0.5 second minimum for "Thinking" state
|
||||
|
||||
const THINKING_STATUS = "Thinking";
|
||||
|
||||
function constructCurrentReasoningState(packets: ReasoningPacket[]) {
|
||||
const hasStart = packets.some(
|
||||
(p) => p.obj.type === PacketType.REASONING_START
|
||||
);
|
||||
const hasEnd = packets.some(
|
||||
(p) =>
|
||||
p.obj.type === PacketType.SECTION_END ||
|
||||
p.obj.type === PacketType.ERROR ||
|
||||
// Support reasoning_done from backend
|
||||
(p.obj as any).type === PacketType.REASONING_DONE
|
||||
);
|
||||
const deltas = packets
|
||||
.filter((p) => p.obj.type === PacketType.REASONING_DELTA)
|
||||
.map((p) => p.obj as ReasoningDelta);
|
||||
|
||||
const content = deltas.map((d) => d.reasoning).join("");
|
||||
|
||||
return {
|
||||
hasStart,
|
||||
hasEnd,
|
||||
content,
|
||||
};
|
||||
}
|
||||
|
||||
export const ReasoningRenderer: MessageRenderer<
|
||||
ReasoningPacket,
|
||||
FullChatState
|
||||
> = ({ packets, onComplete, animate, children }) => {
|
||||
const { hasStart, hasEnd, content } = useMemo(
|
||||
() => constructCurrentReasoningState(packets),
|
||||
[packets]
|
||||
);
|
||||
|
||||
// Track reasoning timing for minimum display duration
|
||||
const [reasoningStartTime, setReasoningStartTime] = useState<number | null>(
|
||||
null
|
||||
);
|
||||
const timeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const completionHandledRef = useRef(false);
|
||||
|
||||
// Track when reasoning starts
|
||||
useEffect(() => {
|
||||
if ((hasStart || hasEnd) && reasoningStartTime === null) {
|
||||
setReasoningStartTime(Date.now());
|
||||
}
|
||||
}, [hasStart, hasEnd, reasoningStartTime]);
|
||||
|
||||
// Handle reasoning completion with minimum duration
|
||||
useEffect(() => {
|
||||
if (
|
||||
hasEnd &&
|
||||
reasoningStartTime !== null &&
|
||||
!completionHandledRef.current
|
||||
) {
|
||||
completionHandledRef.current = true;
|
||||
const elapsedTime = Date.now() - reasoningStartTime;
|
||||
const minimumThinkingDuration = animate ? THINKING_MIN_DURATION_MS : 0;
|
||||
|
||||
if (elapsedTime >= minimumThinkingDuration) {
|
||||
// Enough time has passed, complete immediately
|
||||
onComplete();
|
||||
} else {
|
||||
// Not enough time has passed, delay completion
|
||||
const remainingTime = minimumThinkingDuration - elapsedTime;
|
||||
timeoutRef.current = setTimeout(() => {
|
||||
onComplete();
|
||||
}, remainingTime);
|
||||
}
|
||||
}
|
||||
}, [hasEnd, reasoningStartTime, animate, onComplete]);
|
||||
|
||||
// Cleanup timeout on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (timeoutRef.current) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Markdown renderer callback for ExpandableTextDisplay
|
||||
const renderMarkdown = useCallback(
|
||||
(text: string) => (
|
||||
<MinimalMarkdown
|
||||
content={text}
|
||||
components={mutedTextMarkdownComponents}
|
||||
/>
|
||||
),
|
||||
[]
|
||||
);
|
||||
|
||||
if (!hasStart && !hasEnd && content.length === 0) {
|
||||
return children({ icon: SvgCircle, status: null, content: <></> });
|
||||
}
|
||||
|
||||
const reasoningContent = (
|
||||
<ExpandableTextDisplay
|
||||
title="Thinking"
|
||||
content={content}
|
||||
displayContent={content}
|
||||
maxLines={5}
|
||||
renderContent={renderMarkdown}
|
||||
/>
|
||||
);
|
||||
|
||||
return children({
|
||||
icon: SvgCircle,
|
||||
status: THINKING_STATUS,
|
||||
content: reasoningContent,
|
||||
expandedText: reasoningContent,
|
||||
});
|
||||
};
|
||||
|
||||
export default ReasoningRenderer;
|
||||
@@ -0,0 +1,144 @@
|
||||
import React, { JSX, useState, useEffect, useRef } from "react";
|
||||
import { SourceTag, SourceInfo } from "@/refresh-components/buttons/source-tag";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type { SourceInfo };
|
||||
|
||||
const ANIMATION_DELAY_MS = 30;
|
||||
|
||||
export interface SearchChipListProps<T> {
|
||||
items: T[];
|
||||
initialCount: number;
|
||||
expansionCount: number;
|
||||
getKey: (item: T, index: number) => string | number;
|
||||
toSourceInfo: (item: T, index: number) => SourceInfo;
|
||||
onClick?: (item: T) => void;
|
||||
emptyState?: React.ReactNode;
|
||||
className?: string;
|
||||
showDetailsCard?: boolean;
|
||||
}
|
||||
|
||||
type DisplayEntry<T> =
|
||||
| { type: "chip"; item: T; index: number }
|
||||
| { type: "more"; batchId: number };
|
||||
|
||||
export function SearchChipList<T>({
|
||||
items,
|
||||
initialCount,
|
||||
expansionCount,
|
||||
getKey,
|
||||
toSourceInfo,
|
||||
onClick,
|
||||
emptyState,
|
||||
className = "",
|
||||
showDetailsCard,
|
||||
}: SearchChipListProps<T>): JSX.Element {
|
||||
const [displayList, setDisplayList] = useState<DisplayEntry<T>[]>([]);
|
||||
const [batchId, setBatchId] = useState(0);
|
||||
const animatedKeysRef = useRef<Set<string>>(new Set());
|
||||
|
||||
const getEntryKey = (entry: DisplayEntry<T>): string => {
|
||||
if (entry.type === "more") return `more-button-${entry.batchId}`;
|
||||
return String(getKey(entry.item, entry.index));
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const initial: DisplayEntry<T>[] = items
|
||||
.slice(0, initialCount)
|
||||
.map((item, i) => ({ type: "chip" as const, item, index: i }));
|
||||
|
||||
if (items.length > initialCount) {
|
||||
initial.push({ type: "more", batchId: 0 });
|
||||
}
|
||||
|
||||
setDisplayList(initial);
|
||||
setBatchId(0);
|
||||
}, [items, initialCount]);
|
||||
|
||||
const chipCount = displayList.filter((e) => e.type === "chip").length;
|
||||
const remainingCount = items.length - chipCount;
|
||||
const remainingItems = items.slice(chipCount);
|
||||
|
||||
const handleShowMore = () => {
|
||||
const nextBatchId = batchId + 1;
|
||||
|
||||
setDisplayList((prev) => {
|
||||
const withoutButton = prev.filter((e) => e.type !== "more");
|
||||
const currentCount = withoutButton.length;
|
||||
const newCount = Math.min(currentCount + expansionCount, items.length);
|
||||
const newItems: DisplayEntry<T>[] = items
|
||||
.slice(currentCount, newCount)
|
||||
.map((item, i) => ({
|
||||
type: "chip" as const,
|
||||
item,
|
||||
index: currentCount + i,
|
||||
}));
|
||||
|
||||
const updated = [...withoutButton, ...newItems];
|
||||
if (newCount < items.length) {
|
||||
updated.push({ type: "more", batchId: nextBatchId });
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
|
||||
setBatchId(nextBatchId);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const timer = setTimeout(() => {
|
||||
displayList.forEach((entry) =>
|
||||
animatedKeysRef.current.add(getEntryKey(entry))
|
||||
);
|
||||
}, 0);
|
||||
return () => clearTimeout(timer);
|
||||
}, [displayList]);
|
||||
|
||||
let newItemCounter = 0;
|
||||
|
||||
return (
|
||||
<div className={cn("flex flex-wrap gap-x-2 gap-y-2", className)}>
|
||||
{displayList.map((entry) => {
|
||||
const key = getEntryKey(entry);
|
||||
const isNew = !animatedKeysRef.current.has(key);
|
||||
const delay = isNew ? newItemCounter++ * ANIMATION_DELAY_MS : 0;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={key}
|
||||
className={cn("text-xs", {
|
||||
"animate-in fade-in slide-in-from-left-2 duration-150": isNew,
|
||||
})}
|
||||
style={
|
||||
isNew
|
||||
? {
|
||||
animationDelay: `${delay}ms`,
|
||||
animationFillMode: "backwards",
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{entry.type === "chip" ? (
|
||||
<SourceTag
|
||||
displayName={toSourceInfo(entry.item, entry.index).title}
|
||||
sources={[toSourceInfo(entry.item, entry.index)]}
|
||||
onSourceClick={onClick ? () => onClick(entry.item) : undefined}
|
||||
showDetailsCard={showDetailsCard}
|
||||
/>
|
||||
) : (
|
||||
<SourceTag
|
||||
displayName={`+${remainingCount} more`}
|
||||
sources={remainingItems.map((item, i) =>
|
||||
toSourceInfo(item, chipCount + i)
|
||||
)}
|
||||
onSourceClick={() => handleShowMore()}
|
||||
showDetailsCard={showDetailsCard}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
{items.length === 0 && emptyState}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
import React from "react";
|
||||
import { SvgSearch, SvgGlobe, SvgSearchMenu } from "@opal/icons";
|
||||
import { SearchToolPacket } from "@/app/chat/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
RenderType,
|
||||
} from "@/app/chat/message/messageComponents/interfaces";
|
||||
import { BlinkingDot } from "@/app/chat/message/BlinkingDot";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { SearchChipList, SourceInfo } from "./SearchChipList";
|
||||
import {
|
||||
constructCurrentSearchState,
|
||||
INITIAL_QUERIES_TO_SHOW,
|
||||
QUERIES_PER_EXPANSION,
|
||||
INITIAL_RESULTS_TO_SHOW,
|
||||
RESULTS_PER_EXPANSION,
|
||||
getMetadataTags,
|
||||
} from "./searchStateUtils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
const queryToSourceInfo = (query: string, index: number): SourceInfo => ({
|
||||
id: `query-${index}`,
|
||||
title: query,
|
||||
sourceType: ValidSources.Web,
|
||||
icon: SvgSearch,
|
||||
});
|
||||
|
||||
const resultToSourceInfo = (doc: OnyxDocument): SourceInfo => ({
|
||||
id: doc.document_id,
|
||||
title: doc.semantic_identifier || "",
|
||||
sourceType: doc.source_type,
|
||||
sourceUrl: doc.link,
|
||||
description: doc.blurb,
|
||||
metadata: {
|
||||
date: doc.updated_at || undefined,
|
||||
tags: getMetadataTags(doc.metadata),
|
||||
},
|
||||
});
|
||||
|
||||
export const SearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
|
||||
packets,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
renderType,
|
||||
children,
|
||||
}) => {
|
||||
const searchState = constructCurrentSearchState(packets);
|
||||
const { queries, results, isSearching, isComplete, isInternetSearch } =
|
||||
searchState;
|
||||
|
||||
const isCompact = renderType === RenderType.COMPACT;
|
||||
|
||||
const icon = isInternetSearch ? SvgGlobe : SvgSearchMenu;
|
||||
const queriesHeader = isInternetSearch
|
||||
? "Searching the web for:"
|
||||
: "Searching internal documents for:";
|
||||
|
||||
if (queries.length === 0) {
|
||||
return children({
|
||||
icon,
|
||||
status: null,
|
||||
content: <div />,
|
||||
supportsCompact: true,
|
||||
});
|
||||
}
|
||||
|
||||
return children({
|
||||
icon,
|
||||
status: queriesHeader,
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<div className="flex flex-col">
|
||||
{!isCompact && (
|
||||
<SearchChipList
|
||||
items={queries}
|
||||
initialCount={INITIAL_QUERIES_TO_SHOW}
|
||||
expansionCount={QUERIES_PER_EXPANSION}
|
||||
getKey={(_, index) => index}
|
||||
toSourceInfo={queryToSourceInfo}
|
||||
emptyState={<BlinkingDot />}
|
||||
showDetailsCard={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{(results.length > 0 || queries.length > 0) && (
|
||||
<>
|
||||
{!isCompact && (
|
||||
<Text as="p" mainUiMuted text03>
|
||||
Reading results:
|
||||
</Text>
|
||||
)}
|
||||
<SearchChipList
|
||||
items={results}
|
||||
initialCount={INITIAL_RESULTS_TO_SHOW}
|
||||
expansionCount={RESULTS_PER_EXPANSION}
|
||||
getKey={(doc: OnyxDocument) => doc.document_id}
|
||||
toSourceInfo={(doc: OnyxDocument) => resultToSourceInfo(doc)}
|
||||
onClick={(doc: OnyxDocument) => {
|
||||
if (doc.link) {
|
||||
window.open(doc.link, "_blank");
|
||||
}
|
||||
}}
|
||||
emptyState={<BlinkingDot />}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,18 @@
|
||||
export { SearchToolRenderer } from "./SearchToolRenderer";
|
||||
|
||||
export {
|
||||
constructCurrentSearchState,
|
||||
type SearchState,
|
||||
MAX_TITLE_LENGTH,
|
||||
INITIAL_QUERIES_TO_SHOW,
|
||||
QUERIES_PER_EXPANSION,
|
||||
INITIAL_RESULTS_TO_SHOW,
|
||||
RESULTS_PER_EXPANSION,
|
||||
getMetadataTags,
|
||||
} from "./searchStateUtils";
|
||||
|
||||
export {
|
||||
SearchChipList,
|
||||
type SearchChipListProps,
|
||||
type SourceInfo,
|
||||
} from "./SearchChipList";
|
||||
@@ -0,0 +1,97 @@
|
||||
import {
|
||||
PacketType,
|
||||
SearchToolPacket,
|
||||
SearchToolStart,
|
||||
SearchToolQueriesDelta,
|
||||
SearchToolDocumentsDelta,
|
||||
SectionEnd,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
|
||||
export const MAX_TITLE_LENGTH = 25;
|
||||
|
||||
export const getMetadataTags = (metadata?: {
|
||||
[key: string]: string;
|
||||
}): string[] | undefined => {
|
||||
if (!metadata) return undefined;
|
||||
const tags = Object.values(metadata)
|
||||
.filter((value) => typeof value === "string" && value.length > 0)
|
||||
.slice(0, 2)
|
||||
.map((value) => `# ${value}`);
|
||||
return tags.length > 0 ? tags : undefined;
|
||||
};
|
||||
|
||||
export const INITIAL_QUERIES_TO_SHOW = 3;
|
||||
export const QUERIES_PER_EXPANSION = 5;
|
||||
export const INITIAL_RESULTS_TO_SHOW = 3;
|
||||
export const RESULTS_PER_EXPANSION = 10;
|
||||
|
||||
export interface SearchState {
|
||||
queries: string[];
|
||||
results: OnyxDocument[];
|
||||
isSearching: boolean;
|
||||
hasResults: boolean;
|
||||
isComplete: boolean;
|
||||
isInternetSearch: boolean;
|
||||
}
|
||||
|
||||
/** Constructs the current search state from search tool packets. */
|
||||
export const constructCurrentSearchState = (
|
||||
packets: SearchToolPacket[]
|
||||
): SearchState => {
|
||||
const searchStart = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.SEARCH_TOOL_START
|
||||
)?.obj as SearchToolStart | null;
|
||||
|
||||
const queryDeltas = packets
|
||||
.filter(
|
||||
(packet) => packet.obj.type === PacketType.SEARCH_TOOL_QUERIES_DELTA
|
||||
)
|
||||
.map((packet) => packet.obj as SearchToolQueriesDelta);
|
||||
|
||||
const documentDeltas = packets
|
||||
.filter(
|
||||
(packet) => packet.obj.type === PacketType.SEARCH_TOOL_DOCUMENTS_DELTA
|
||||
)
|
||||
.map((packet) => packet.obj as SearchToolDocumentsDelta);
|
||||
|
||||
const searchEnd = packets.find(
|
||||
(packet) =>
|
||||
packet.obj.type === PacketType.SECTION_END ||
|
||||
packet.obj.type === PacketType.ERROR
|
||||
)?.obj as SectionEnd | null;
|
||||
|
||||
// Deduplicate queries using Set for O(n) instead of indexOf which is O(n²)
|
||||
const seenQueries = new Set<string>();
|
||||
const queries = queryDeltas
|
||||
.flatMap((delta) => delta?.queries || [])
|
||||
.filter((query) => {
|
||||
if (seenQueries.has(query)) return false;
|
||||
seenQueries.add(query);
|
||||
return true;
|
||||
});
|
||||
|
||||
const seenDocIds = new Set<string>();
|
||||
const results = documentDeltas
|
||||
.flatMap((delta) => delta?.documents || [])
|
||||
.filter((doc) => {
|
||||
if (!doc || !doc.document_id) return false;
|
||||
if (seenDocIds.has(doc.document_id)) return false;
|
||||
seenDocIds.add(doc.document_id);
|
||||
return true;
|
||||
});
|
||||
|
||||
const isSearching = Boolean(searchStart && !searchEnd);
|
||||
const hasResults = results.length > 0;
|
||||
const isComplete = Boolean(searchStart && searchEnd);
|
||||
const isInternetSearch = searchStart?.is_internet_search || false;
|
||||
|
||||
return {
|
||||
queries,
|
||||
results,
|
||||
isSearching,
|
||||
hasResults,
|
||||
isComplete,
|
||||
isInternetSearch,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,15 @@
|
||||
import type { Components } from "react-markdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export const mutedTextMarkdownComponents = {
|
||||
p: ({ children }: { children?: React.ReactNode }) => (
|
||||
<Text as="p" text03 mainUiMuted className="!my-1">
|
||||
{children}
|
||||
</Text>
|
||||
),
|
||||
li: ({ children }: { children?: React.ReactNode }) => (
|
||||
<Text as="li" text03 mainUiMuted className="!my-0 !py-0 leading-normal">
|
||||
{children}
|
||||
</Text>
|
||||
),
|
||||
} satisfies Partial<Components>;
|
||||
@@ -0,0 +1,78 @@
|
||||
import { GroupedPacket } from "./hooks/packetProcessor";
|
||||
|
||||
/**
|
||||
* Transformed step data ready for rendering
|
||||
*/
|
||||
export interface TransformedStep {
|
||||
/** Unique key for React rendering */
|
||||
key: string;
|
||||
/** Turn index from packet placement */
|
||||
turnIndex: number;
|
||||
/** Tab index for parallel tools */
|
||||
tabIndex: number;
|
||||
/** Raw packets for content rendering */
|
||||
packets: GroupedPacket["packets"];
|
||||
}
|
||||
|
||||
/**
|
||||
* Group steps by turn_index for detecting parallel tools
|
||||
*/
|
||||
export interface TurnGroup {
|
||||
turnIndex: number;
|
||||
steps: TransformedStep[];
|
||||
/** True if multiple steps have the same turn_index (parallel execution) */
|
||||
isParallel: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a single GroupedPacket into step data
|
||||
*/
|
||||
export function transformPacketGroup(group: GroupedPacket): TransformedStep {
|
||||
return {
|
||||
key: `${group.turn_index}-${group.tab_index}`,
|
||||
turnIndex: group.turn_index,
|
||||
tabIndex: group.tab_index,
|
||||
packets: group.packets,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform all packet groups into step data
|
||||
*/
|
||||
export function transformPacketGroups(
|
||||
groups: GroupedPacket[]
|
||||
): TransformedStep[] {
|
||||
return groups.map(transformPacketGroup);
|
||||
}
|
||||
|
||||
/**
|
||||
* Group transformed steps by turn_index to detect parallel tools
|
||||
*/
|
||||
export function groupStepsByTurn(steps: TransformedStep[]): TurnGroup[] {
|
||||
const turnMap = new Map<number, TransformedStep[]>();
|
||||
|
||||
for (const step of steps) {
|
||||
const existing = turnMap.get(step.turnIndex);
|
||||
if (existing) {
|
||||
existing.push(step);
|
||||
} else {
|
||||
turnMap.set(step.turnIndex, [step]);
|
||||
}
|
||||
}
|
||||
|
||||
const result: TurnGroup[] = [];
|
||||
const sortedTurnIndices = Array.from(turnMap.keys()).sort((a, b) => a - b);
|
||||
|
||||
for (const turnIndex of sortedTurnIndices) {
|
||||
const stepsForTurn = turnMap.get(turnIndex)!;
|
||||
stepsForTurn.sort((a, b) => a.tabIndex - b.tabIndex);
|
||||
|
||||
result.push({
|
||||
turnIndex,
|
||||
steps: stepsForTurn,
|
||||
isParallel: stepsForTurn.length > 1,
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -142,3 +142,31 @@ export function getToolIcon(packets: Packet[]): JSX.Element {
|
||||
return <FiCircle className="w-3.5 h-3.5" />;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tool icon by tool name string.
|
||||
* Used when we have pre-computed tool names (e.g., from packet processor).
|
||||
*/
|
||||
export function getToolIconByName(name: string): JSX.Element {
|
||||
switch (name) {
|
||||
case "Web Search":
|
||||
return <FiGlobe className="w-3.5 h-3.5" />;
|
||||
case "Internal Search":
|
||||
return <FiSearch className="w-3.5 h-3.5" />;
|
||||
case "Code Interpreter":
|
||||
return <FiCode className="w-3.5 h-3.5" />;
|
||||
case "Open URLs":
|
||||
return <FiLink className="w-3.5 h-3.5" />;
|
||||
case "Generate Image":
|
||||
return <FiImage className="w-3.5 h-3.5" />;
|
||||
case "Generate plan":
|
||||
return <FiList className="w-3.5 h-3.5" />;
|
||||
case "Research agent":
|
||||
return <FiUsers className="w-3.5 h-3.5" />;
|
||||
case "Thinking":
|
||||
return <BrainIcon className="w-3.5 h-3.5" />;
|
||||
default:
|
||||
// Custom tools or unknown
|
||||
return <FiTool className="w-3.5 h-3.5" />;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,6 +110,11 @@
|
||||
.prose > :first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
/* Remove bottom margin from last child to avoid extra space */
|
||||
.prose > :last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
@layer utilities {
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface FadeDivProps {
|
||||
className?: string;
|
||||
fadeClassName?: string;
|
||||
footerClassName?: string;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
const FadeDiv: React.FC<FadeDivProps> = ({
|
||||
className,
|
||||
fadeClassName,
|
||||
footerClassName,
|
||||
children,
|
||||
}) => (
|
||||
<div className={cn("relative w-full", className)}>
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-x-0 -top-8 h-8 bg-gradient-to-b from-transparent to-background pointer-events-none",
|
||||
fadeClassName
|
||||
)}
|
||||
/>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-end w-full pt-2 px-2",
|
||||
footerClassName
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default FadeDiv;
|
||||
@@ -195,6 +195,7 @@ const MessageList = React.memo(
|
||||
>
|
||||
<AIMessage
|
||||
rawPackets={message.packets}
|
||||
packetsVersion={message.packetsVersion}
|
||||
chatState={chatStateData}
|
||||
nodeId={message.nodeId}
|
||||
messageId={message.messageId}
|
||||
|
||||
@@ -427,10 +427,7 @@ function ActionsTool({
|
||||
*/
|
||||
function ActionsToolSkeleton() {
|
||||
return Array.from({ length: 3 }).map((_, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="w-full p-3 rounded-12 border bg-background-tint-00"
|
||||
>
|
||||
<Card key={index} padding={1.5}>
|
||||
<LineItemLayout
|
||||
// We provide dummy values here.
|
||||
// The `loading` prop will always render a pulsing box instead, so the dummy-values will actually NOT be rendered at all.
|
||||
@@ -439,7 +436,7 @@ function ActionsToolSkeleton() {
|
||||
rightChildren={<></>}
|
||||
loading
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ function Footer() {
|
||||
}](https://www.onyx.app/) - Open Source AI Platform`;
|
||||
|
||||
return (
|
||||
<footer className="w-full flex flex-row justify-center items-center gap-2 pb-2">
|
||||
<footer className="w-full flex flex-row justify-center items-center gap-2 pb-2 mt-auto">
|
||||
<MinimalMarkdown
|
||||
content={customFooterContent}
|
||||
className={cn("max-w-full text-center")}
|
||||
|
||||
@@ -169,14 +169,14 @@ function LabelLayout({
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent={center ? "center" : "start"}
|
||||
gap={0}
|
||||
gap={0.25}
|
||||
>
|
||||
<Text mainContentEmphasis text04>
|
||||
{title}
|
||||
</Text>
|
||||
{optional && (
|
||||
<Text text03 mainContentMuted>
|
||||
{" (Optional)"}
|
||||
(Optional)
|
||||
</Text>
|
||||
)}
|
||||
</Section>
|
||||
|
||||
@@ -23,11 +23,14 @@ export function checkUserOwnsAssistant(
|
||||
/**
|
||||
* Checks if the given user ID owns the specified assistant.
|
||||
*
|
||||
* Returns true if any of the following conditions are met (and the assistant is not built-in):
|
||||
* - No user ID is provided
|
||||
* - The user is a no-auth user
|
||||
* Returns true if a valid user ID is provided and any of the following conditions
|
||||
* are met (and the assistant is not built-in):
|
||||
* - The user is a no-auth user (authentication is disabled)
|
||||
* - The user ID matches the assistant owner's ID
|
||||
*
|
||||
* Returns false if userId is undefined (e.g., user is loading or unauthenticated)
|
||||
* to prevent granting ownership access prematurely.
|
||||
*
|
||||
* @param userId - The user ID to check ownership for
|
||||
* @param assistant - The assistant to check ownership of
|
||||
* @returns true if the user owns the assistant, false otherwise
|
||||
@@ -37,9 +40,8 @@ export function checkUserIdOwnsAssistant(
|
||||
assistant: MinimalPersonaSnapshot | Persona
|
||||
) {
|
||||
return (
|
||||
(!userId ||
|
||||
checkUserIsNoAuthUser(userId) ||
|
||||
assistant.owner?.id === userId) &&
|
||||
!!userId &&
|
||||
(checkUserIsNoAuthUser(userId) || assistant.owner?.id === userId) &&
|
||||
!assistant.builtin_persona
|
||||
);
|
||||
}
|
||||
|
||||
@@ -169,3 +169,21 @@ export function hasNonImageFiles(
|
||||
): boolean {
|
||||
return files.some((file) => !isImageFile(file.name));
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges multiple refs into a single callback ref.
|
||||
* Useful when a component needs both an internal ref and a forwarded ref.
|
||||
*/
|
||||
export function mergeRefs<T>(
|
||||
...refs: (React.Ref<T> | undefined)[]
|
||||
): React.RefCallback<T> {
|
||||
return (node: T | null) => {
|
||||
refs.forEach((ref) => {
|
||||
if (typeof ref === "function") {
|
||||
ref(node);
|
||||
} else if (ref) {
|
||||
(ref as React.MutableRefObject<T | null>).current = node;
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
const isOwnedByUser = checkUserOwnsAssistant(user, agent);
|
||||
const [hovered, setHovered] = React.useState(false);
|
||||
const shareAgentModal = useCreateModal();
|
||||
const { refresh: refreshAgent } = useAgent(agent.id);
|
||||
const { agent: fullAgent, refresh: refreshAgent } = useAgent(agent.id);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
// Start chat and auto-pin unpinned agents to the sidebar
|
||||
@@ -93,6 +93,7 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
} else {
|
||||
// Revalidate the agent data to reflect the changes
|
||||
refreshAgent();
|
||||
shareAgentModal.toggle(false);
|
||||
}
|
||||
},
|
||||
[agent.id, isPaidEnterpriseFeaturesEnabled, refreshAgent, setPopup]
|
||||
@@ -103,7 +104,13 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
{popup}
|
||||
|
||||
<shareAgentModal.Provider>
|
||||
<ShareAgentModal agent={agent} onShare={handleShare} />
|
||||
<ShareAgentModal
|
||||
agentId={agent.id}
|
||||
userIds={fullAgent?.users?.map((u) => u.id) ?? []}
|
||||
groupIds={fullAgent?.groups ?? []}
|
||||
isPublic={fullAgent?.is_public ?? false}
|
||||
onShare={handleShare}
|
||||
/>
|
||||
</shareAgentModal.Provider>
|
||||
|
||||
<Card
|
||||
|
||||
64
web/src/refresh-components/FadingEdgeContainer.tsx
Normal file
64
web/src/refresh-components/FadingEdgeContainer.tsx
Normal file
@@ -0,0 +1,64 @@
|
||||
import React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface FadingEdgeContainerProps {
|
||||
/** Classes applied to the inner scrollable container */
|
||||
className?: string;
|
||||
/** Classes to customize the fade gradient (e.g., height, color) */
|
||||
fadeClassName?: string;
|
||||
children: React.ReactNode;
|
||||
/** Which edge to show the fade on */
|
||||
direction?: "top" | "bottom";
|
||||
}
|
||||
|
||||
/**
|
||||
* A container that adds a gradient fade overlay at the top or bottom edge.
|
||||
*
|
||||
* Use this component to wrap scrollable content where you want to visually
|
||||
* indicate that more content exists beyond the visible area. The fade stays
|
||||
* fixed relative to the container bounds, not the scroll content.
|
||||
*
|
||||
* @example
|
||||
* // Bottom fade for a scrollable list
|
||||
* <FadingEdgeContainer
|
||||
* direction="bottom"
|
||||
* className="max-h-[300px] overflow-y-auto"
|
||||
* >
|
||||
* {items.map(item => <Item key={item.id} />)}
|
||||
* </FadingEdgeContainer>
|
||||
*
|
||||
* @example
|
||||
* // Top fade with custom fade styling
|
||||
* <FadingEdgeContainer
|
||||
* direction="top"
|
||||
* className="max-h-[200px] overflow-y-auto"
|
||||
* fadeClassName="h-12"
|
||||
* >
|
||||
* {content}
|
||||
* </FadingEdgeContainer>
|
||||
*/
|
||||
const FadingEdgeContainer: React.FC<FadingEdgeContainerProps> = ({
|
||||
className,
|
||||
fadeClassName,
|
||||
children,
|
||||
direction = "top",
|
||||
}) => {
|
||||
const isTop = direction === "top";
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<div className={className}>{children}</div>
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-x-0 h-8 pointer-events-none z-10",
|
||||
isTop
|
||||
? "top-0 bg-gradient-to-b from-background to-transparent"
|
||||
: "bottom-0 bg-gradient-to-t from-background to-transparent",
|
||||
fadeClassName
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default FadingEdgeContainer;
|
||||
@@ -77,6 +77,7 @@ const useModalContext = () => {
|
||||
const widthClasses = {
|
||||
lg: "w-[80dvw]",
|
||||
md: "w-[60rem]",
|
||||
"md-sm": "w-[40rem]",
|
||||
sm: "w-[32rem]",
|
||||
};
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
export interface SimpleTooltipProps
|
||||
extends React.ComponentPropsWithoutRef<typeof TooltipContent> {
|
||||
disabled?: boolean;
|
||||
tooltip?: React.ReactNode | string;
|
||||
tooltip?: React.ReactNode;
|
||||
children?: React.ReactNode;
|
||||
delayDuration?: number;
|
||||
}
|
||||
|
||||
@@ -1,45 +1,189 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import React, { useRef, useState, useEffect, useMemo } from "react";
|
||||
import * as TabsPrimitive from "@radix-ui/react-tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cn, mergeRefs } from "@/lib/utils";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import { WithoutStyles } from "@/types";
|
||||
import { Section, SectionProps } from "@/layouts/general-layouts";
|
||||
import { IconProps } from "@opal/types";
|
||||
import Text from "./texts/Text";
|
||||
|
||||
/* =============================================================================
|
||||
CONTEXT
|
||||
============================================================================= */
|
||||
|
||||
interface TabsContextValue {
|
||||
variant: "contained" | "pill";
|
||||
}
|
||||
|
||||
const TabsContext = React.createContext<TabsContextValue | undefined>(
|
||||
undefined
|
||||
);
|
||||
|
||||
const useTabsContext = () => {
|
||||
const context = React.useContext(TabsContext);
|
||||
return context; // Returns undefined if used outside Tabs.List (allows explicit override)
|
||||
};
|
||||
|
||||
/**
|
||||
* TABS COMPONENT VARIANTS
|
||||
*
|
||||
* Contained (default):
|
||||
* ┌─────────────────────────────────────────────────┐
|
||||
* │ ┌──────────┐ ╔══════════╗ ┌──────────┐ │
|
||||
* │ │ Tab 1 │ ║ Tab 2 ║ │ Tab 3 │ │ ← gray background
|
||||
* │ └──────────┘ ╚══════════╝ └──────────┘ │
|
||||
* └─────────────────────────────────────────────────┘
|
||||
* ↑ active tab (white bg, shadow)
|
||||
*
|
||||
* Pill:
|
||||
* Tab 1 Tab 2 Tab 3 [Action]
|
||||
* ╔═════╗
|
||||
* ║ ║ ↑ optional rightContent
|
||||
* ────────────╨═════╨─────────────────────────────
|
||||
* ↑ sliding indicator under active tab
|
||||
*
|
||||
* @example
|
||||
* <Tabs defaultValue="tab1">
|
||||
* <Tabs.List variant="pill">
|
||||
* <Tabs.Trigger value="tab1">Overview</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="tab2">Details</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
* <Tabs.Content value="tab1">Overview content</Tabs.Content>
|
||||
* <Tabs.Content value="tab2">Details content</Tabs.Content>
|
||||
* </Tabs>
|
||||
*/
|
||||
|
||||
/* =============================================================================
|
||||
VARIANT STYLES
|
||||
Centralized styling definitions for tabs variants.
|
||||
============================================================================= */
|
||||
|
||||
/** Style classes for TabsList variants */
|
||||
const listVariants = {
|
||||
contained: "grid w-full rounded-08 bg-background-tint-03",
|
||||
pill: "relative flex items-center pb-[4px] bg-background-tint-00",
|
||||
} as const;
|
||||
|
||||
/** Base style classes for TabsTrigger variants */
|
||||
const triggerBaseStyles = {
|
||||
contained: "p-2 gap-2",
|
||||
pill: "p-1.5 font-secondary-action transition-all duration-200 ease-out",
|
||||
} as const;
|
||||
|
||||
/** Icon style classes for TabsTrigger variants */
|
||||
const iconVariants = {
|
||||
contained: "stroke-text-03",
|
||||
pill: "stroke-current",
|
||||
} as const;
|
||||
|
||||
/* =============================================================================
|
||||
HOOKS
|
||||
============================================================================= */
|
||||
|
||||
/** Style properties for the pill indicator position */
|
||||
interface IndicatorStyle {
|
||||
left: number;
|
||||
width: number;
|
||||
opacity: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to track and animate a sliding indicator under the active tab.
|
||||
*
|
||||
* Uses MutationObserver to detect when the active tab changes (via data-state
|
||||
* attribute updates from Radix UI) and calculates the indicator position.
|
||||
*
|
||||
* @param listRef - Ref to the TabsList container element
|
||||
* @param enabled - Whether indicator tracking is enabled (only true for pill variant)
|
||||
* @returns Style object with left, width, and opacity for the indicator element
|
||||
*/
|
||||
function usePillIndicator(
|
||||
listRef: React.RefObject<HTMLElement | null>,
|
||||
enabled: boolean
|
||||
): IndicatorStyle {
|
||||
const [style, setStyle] = useState<IndicatorStyle>({
|
||||
left: 0,
|
||||
width: 0,
|
||||
opacity: 0,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!enabled) return;
|
||||
|
||||
const updateIndicator = () => {
|
||||
const list = listRef.current;
|
||||
if (!list) return;
|
||||
|
||||
const activeTab = list.querySelector<HTMLElement>(
|
||||
'[data-state="active"]'
|
||||
);
|
||||
if (activeTab) {
|
||||
const listRect = list.getBoundingClientRect();
|
||||
const tabRect = activeTab.getBoundingClientRect();
|
||||
setStyle({
|
||||
left: tabRect.left - listRect.left,
|
||||
width: tabRect.width,
|
||||
opacity: 1,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
updateIndicator();
|
||||
|
||||
const observer = new MutationObserver(updateIndicator);
|
||||
if (listRef.current) {
|
||||
observer.observe(listRef.current, {
|
||||
attributes: true,
|
||||
subtree: true,
|
||||
attributeFilter: ["data-state"],
|
||||
});
|
||||
}
|
||||
|
||||
return () => observer.disconnect();
|
||||
}, [enabled, listRef]);
|
||||
|
||||
return style;
|
||||
}
|
||||
|
||||
/* =============================================================================
|
||||
SUB-COMPONENTS
|
||||
============================================================================= */
|
||||
|
||||
/**
|
||||
* Renders the bottom line and sliding indicator for the pill variant.
|
||||
* The indicator animates smoothly when switching between tabs.
|
||||
*/
|
||||
function PillIndicator({ style }: { style: IndicatorStyle }) {
|
||||
return (
|
||||
<>
|
||||
<div className="absolute bottom-0 left-0 right-0 h-px bg-border-02 pointer-events-none" />
|
||||
<div
|
||||
className="absolute bottom-0 h-[2px] bg-background-tint-inverted-03 z-10 transition-all duration-200 ease-out pointer-events-none"
|
||||
style={{
|
||||
left: style.left,
|
||||
width: style.width,
|
||||
opacity: style.opacity,
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
/* =============================================================================
|
||||
MAIN COMPONENTS
|
||||
============================================================================= */
|
||||
|
||||
/**
|
||||
* Tabs Root Component
|
||||
*
|
||||
* Container for tab navigation and content. Manages the active tab state.
|
||||
* Supports both controlled and uncontrolled modes.
|
||||
*
|
||||
* @param defaultValue - The tab value that should be active by default (uncontrolled)
|
||||
* @param defaultValue - The tab value that should be active by default (uncontrolled mode)
|
||||
* @param value - The controlled active tab value
|
||||
* @param onValueChange - Callback when the active tab changes
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Uncontrolled tabs (state managed internally)
|
||||
* <Tabs defaultValue="account">
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="account">Account</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="password">Password</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
* <Tabs.Content value="account">Account settings content</Tabs.Content>
|
||||
* <Tabs.Content value="password">Password settings content</Tabs.Content>
|
||||
* </Tabs>
|
||||
*
|
||||
* // Controlled tabs (explicit state management)
|
||||
* <Tabs value={activeTab} onValueChange={setActiveTab}>
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="tab1">Tab 1</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="tab2">Tab 2</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
* <Tabs.Content value="tab1">Content 1</Tabs.Content>
|
||||
* <Tabs.Content value="tab2">Content 2</Tabs.Content>
|
||||
* </Tabs>
|
||||
* ```
|
||||
* @param onValueChange - Callback fired when the active tab changes
|
||||
*/
|
||||
const TabsRoot = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.Root>,
|
||||
@@ -49,44 +193,96 @@ const TabsRoot = React.forwardRef<
|
||||
));
|
||||
TabsRoot.displayName = TabsPrimitive.Root.displayName;
|
||||
|
||||
/* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* Tabs List Props
|
||||
*/
|
||||
interface TabsListProps
|
||||
extends WithoutStyles<
|
||||
React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
|
||||
> {
|
||||
/**
|
||||
* Visual variant of the tabs list.
|
||||
*
|
||||
* - `contained` (default): Rounded background with equal-width tabs in a grid.
|
||||
* Best for primary navigation where tabs should fill available space.
|
||||
*
|
||||
* - `pill`: Transparent background with a sliding underline indicator.
|
||||
* Best for secondary navigation or filter-style tabs with flexible widths.
|
||||
*/
|
||||
variant?: "contained" | "pill";
|
||||
|
||||
/**
|
||||
* Content to render on the right side of the tab list.
|
||||
* Only applies to the `pill` variant (ignored for `contained`).
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Tabs.List variant="pill" rightContent={<Button size="sm">Add New</Button>}>
|
||||
* <Tabs.Trigger value="all">All</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="active">Active</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
* ```
|
||||
*/
|
||||
rightContent?: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tabs List Component
|
||||
*
|
||||
* Container for tab triggers. Renders as a horizontal list with pill-style background.
|
||||
* Automatically manages keyboard navigation (arrow keys) and accessibility attributes.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Tabs defaultValue="overview">
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="overview">Overview</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="analytics">Analytics</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="settings">Settings</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
* <Tabs.Content value="overview">...</Tabs.Content>
|
||||
* <Tabs.Content value="analytics">...</Tabs.Content>
|
||||
* <Tabs.Content value="settings">...</Tabs.Content>
|
||||
* </Tabs>
|
||||
* ```
|
||||
* Container for tab triggers. Renders as a horizontal list with automatic
|
||||
* keyboard navigation (arrow keys, Home/End) and accessibility attributes.
|
||||
*
|
||||
* @remarks
|
||||
* - Default styling: rounded pill background with padding
|
||||
* - Height: 2.5rem (h-10)
|
||||
* - Supports keyboard navigation (Left/Right arrows, Home/End keys)
|
||||
* - Custom className can be added for additional styling if needed
|
||||
* - **Contained**: Uses CSS Grid for equal-width tabs with rounded background
|
||||
* - **Pill**: Uses Flexbox for content-width tabs with animated bottom indicator
|
||||
* - The `variant` prop is automatically propagated to child `Tabs.Trigger` components via context
|
||||
*/
|
||||
const TabsList = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.List>,
|
||||
WithoutStyles<React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>>
|
||||
>((props, ref) => (
|
||||
<TabsPrimitive.List
|
||||
ref={ref}
|
||||
className="flex w-full rounded-08 bg-background-tint-03"
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TabsListProps
|
||||
>(({ variant = "contained", rightContent, children, ...props }, ref) => {
|
||||
const listRef = useRef<HTMLDivElement>(null);
|
||||
const isPill = variant === "pill";
|
||||
const indicatorStyle = usePillIndicator(listRef, isPill);
|
||||
const contextValue = useMemo(() => ({ variant }), [variant]);
|
||||
|
||||
return (
|
||||
<TabsPrimitive.List
|
||||
ref={mergeRefs(listRef, ref)}
|
||||
className={cn(listVariants[variant])}
|
||||
style={
|
||||
variant === "contained"
|
||||
? {
|
||||
gridTemplateColumns: `repeat(${React.Children.count(
|
||||
children
|
||||
)}, 1fr)`,
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
{...props}
|
||||
>
|
||||
<TabsContext.Provider value={contextValue}>
|
||||
{isPill ? (
|
||||
<div className="flex items-center gap-2">{children}</div>
|
||||
) : (
|
||||
children
|
||||
)}
|
||||
|
||||
{isPill && rightContent && (
|
||||
<div className="ml-auto pl-2">{rightContent}</div>
|
||||
)}
|
||||
|
||||
{isPill && <PillIndicator style={indicatorStyle} />}
|
||||
</TabsContext.Provider>
|
||||
</TabsPrimitive.List>
|
||||
);
|
||||
});
|
||||
TabsList.displayName = TabsPrimitive.List.displayName;
|
||||
|
||||
/* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* Tabs Trigger Props
|
||||
*/
|
||||
@@ -97,72 +293,74 @@ interface TabsTriggerProps
|
||||
"children"
|
||||
>
|
||||
> {
|
||||
/**
|
||||
* Visual variant of the tab trigger.
|
||||
* Automatically inherited from the parent `Tabs.List` variant via context.
|
||||
* Can be explicitly set to override the inherited value.
|
||||
*
|
||||
* - `contained` (default): White background with shadow when active
|
||||
* - `pill`: Dark pill background when active, transparent when inactive
|
||||
*/
|
||||
variant?: "contained" | "pill";
|
||||
|
||||
/** Optional tooltip text to display on hover */
|
||||
tooltip?: string;
|
||||
/** Side where tooltip appears. Default: "top" */
|
||||
|
||||
/** Side where tooltip appears. @default "top" */
|
||||
tooltipSide?: "top" | "bottom" | "left" | "right";
|
||||
|
||||
/** Optional icon component to render before the label */
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
children?: string;
|
||||
|
||||
/** Tab label - can be string or ReactNode for custom content */
|
||||
children?: React.ReactNode;
|
||||
|
||||
/** Show loading spinner after label */
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tabs Trigger Component
|
||||
*
|
||||
* Individual tab button that switches the active tab when clicked.
|
||||
* Supports tooltips and disabled state with special tooltip handling.
|
||||
*
|
||||
* @param value - Unique value identifying this tab (required)
|
||||
* @param tooltip - Optional tooltip text shown on hover
|
||||
* @param tooltipSide - Side where tooltip appears (top, bottom, left, right). Default: "top"
|
||||
* @param disabled - Whether the tab is disabled
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Basic tabs
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="home">Home</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="profile">Profile</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
*
|
||||
* // With tooltips
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="edit" tooltip="Edit document">
|
||||
* <SvgEdit />
|
||||
* </Tabs.Trigger>
|
||||
* <Tabs.Trigger value="share" tooltip="Share with others" tooltipSide="bottom">
|
||||
* <SvgShare />
|
||||
* </Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
*
|
||||
* // With disabled state and tooltip
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="admin" disabled tooltip="Admin access required">
|
||||
* Admin Panel
|
||||
* </Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
* ```
|
||||
* Supports icons, tooltips, loading states, and disabled state.
|
||||
*
|
||||
* @remarks
|
||||
* - Active state: white background with shadow
|
||||
* - Inactive state: transparent with hover effect
|
||||
* - Disabled state: reduced opacity, no pointer events
|
||||
* - Tooltips work on both enabled and disabled triggers
|
||||
* - Disabled triggers require special tooltip wrapping to show tooltips
|
||||
* - Automatic focus management and keyboard navigation
|
||||
* - **Contained active**: White background with subtle shadow
|
||||
* - **Pill active**: Dark inverted background
|
||||
* - Tooltips work on disabled triggers via wrapper span technique
|
||||
* - Loading spinner appears after the label text
|
||||
*/
|
||||
const TabsTrigger = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.Trigger>,
|
||||
TabsTriggerProps
|
||||
>(
|
||||
(
|
||||
{ tooltip, tooltipSide = "top", icon: Icon, children, disabled, ...props },
|
||||
{
|
||||
variant: variantProp,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
icon: Icon,
|
||||
children,
|
||||
disabled,
|
||||
isLoading,
|
||||
...props
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
const context = useTabsContext();
|
||||
const variant = variantProp ?? context?.variant ?? "contained";
|
||||
|
||||
const inner = (
|
||||
<>
|
||||
{Icon && <Icon size={16} className="stroke-text-03" />}
|
||||
<Text>{children}</Text>
|
||||
{Icon && <Icon size={14} className={cn(iconVariants[variant])} />}
|
||||
{typeof children === "string" ? <Text>{children}</Text> : children}
|
||||
{isLoading && (
|
||||
<span
|
||||
className="inline-block w-3 h-3 border-2 border-text-03 border-t-transparent rounded-full animate-spin"
|
||||
aria-label="Loading"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -171,11 +369,29 @@ const TabsTrigger = React.forwardRef<
|
||||
ref={ref}
|
||||
disabled={disabled}
|
||||
className={cn(
|
||||
"flex-1 inline-flex items-center justify-center whitespace-nowrap rounded-08 p-2 gap-2",
|
||||
|
||||
// active/inactive states:
|
||||
"data-[state=active]:bg-background-neutral-00 data-[state=active]:text-text-04 data-[state=active]:shadow-01 data-[state=active]:border",
|
||||
"data-[state=inactive]:text-text-03 data-[state=inactive]:bg-transparent data-[state=inactive]:border data-[state=inactive]:border-transparent"
|
||||
"inline-flex items-center justify-center whitespace-nowrap rounded-08",
|
||||
triggerBaseStyles[variant],
|
||||
variant === "contained" && [
|
||||
"data-[state=active]:bg-background-neutral-00",
|
||||
"data-[state=active]:text-text-04",
|
||||
"data-[state=active]:shadow-01",
|
||||
"data-[state=active]:border",
|
||||
"data-[state=active]:border-border-01",
|
||||
],
|
||||
variant === "pill" && [
|
||||
"data-[state=active]:bg-background-tint-inverted-03",
|
||||
"data-[state=active]:text-text-inverted-05",
|
||||
],
|
||||
variant === "contained" && [
|
||||
"data-[state=inactive]:text-text-03",
|
||||
"data-[state=inactive]:bg-transparent",
|
||||
"data-[state=inactive]:border",
|
||||
"data-[state=inactive]:border-transparent",
|
||||
],
|
||||
variant === "pill" && [
|
||||
"data-[state=inactive]:bg-transparent",
|
||||
"data-[state=inactive]:text-text-03",
|
||||
]
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
@@ -189,9 +405,9 @@ const TabsTrigger = React.forwardRef<
|
||||
</TabsPrimitive.Trigger>
|
||||
);
|
||||
|
||||
// Disabled native buttons don't emit pointer/focus events, so tooltips inside
|
||||
// them won't trigger. Wrap the *entire* trigger with a neutral span only when
|
||||
// disabled so layout stays unchanged for the enabled case.
|
||||
// Disabled native buttons don't emit pointer/focus events, so tooltips
|
||||
// inside them won't trigger. Wrap the entire trigger with a neutral span
|
||||
// only when disabled so layout stays unchanged for the enabled case.
|
||||
if (tooltip && disabled) {
|
||||
return (
|
||||
<SimpleTooltip tooltip={tooltip} side={tooltipSide}>
|
||||
@@ -207,6 +423,8 @@ const TabsTrigger = React.forwardRef<
|
||||
);
|
||||
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;
|
||||
|
||||
/* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* Tabs Content Component
|
||||
*
|
||||
@@ -214,34 +432,6 @@ TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;
|
||||
* Only the content for the active tab is rendered and visible.
|
||||
*
|
||||
* @param value - The tab value this content is associated with (must match a Tabs.Trigger value)
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Tabs defaultValue="details">
|
||||
* <Tabs.List>
|
||||
* <Tabs.Trigger value="details">Details</Tabs.Trigger>
|
||||
* <Tabs.Trigger value="logs">Logs</Tabs.Trigger>
|
||||
* </Tabs.List>
|
||||
*
|
||||
* <Tabs.Content value="details">
|
||||
* <Section>
|
||||
* <Text>Detailed information goes here</Text>
|
||||
* </Section>
|
||||
* </Tabs.Content>
|
||||
*
|
||||
* <Tabs.Content value="logs">
|
||||
* <Section>
|
||||
* <LogViewer logs={logs} />
|
||||
* </Section>
|
||||
* </Tabs.Content>
|
||||
* </Tabs>
|
||||
* ```
|
||||
*
|
||||
* @remarks
|
||||
* - Content is only mounted/visible when its associated tab is active
|
||||
* - Default top margin of 0.5rem (mt-2) to separate from tabs
|
||||
* - Supports focus management for accessibility
|
||||
* - Custom className can override default styling
|
||||
*/
|
||||
const TabsContent = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.Content>,
|
||||
@@ -259,6 +449,10 @@ const TabsContent = React.forwardRef<
|
||||
));
|
||||
TabsContent.displayName = TabsPrimitive.Content.displayName;
|
||||
|
||||
/* =============================================================================
|
||||
EXPORTS
|
||||
============================================================================= */
|
||||
|
||||
export default Object.assign(TabsRoot, {
|
||||
List: TabsList,
|
||||
Trigger: TabsTrigger,
|
||||
|
||||
237
web/src/refresh-components/buttons/source-tag/SourceTag.tsx
Normal file
237
web/src/refresh-components/buttons/source-tag/SourceTag.tsx
Normal file
@@ -0,0 +1,237 @@
|
||||
"use client";
|
||||
|
||||
import { memo, useState, useMemo, useCallback } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { WebResultIcon } from "@/components/WebResultIcon";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import SourceTagDetailsCard, {
|
||||
SourceInfo,
|
||||
} from "@/refresh-components/buttons/source-tag/SourceTagDetailsCard";
|
||||
|
||||
export type { SourceInfo };
|
||||
|
||||
// Variant-specific styles
|
||||
const sizeClasses = {
|
||||
inlineCitation: {
|
||||
container: "rounded-04 p-0.5 gap-0.5",
|
||||
},
|
||||
tag: {
|
||||
container: "rounded-08 p-1 gap-1",
|
||||
},
|
||||
} as const;
|
||||
|
||||
const getIconKey = (source: SourceInfo): string => {
|
||||
if (source.icon) return source.icon.name || "custom";
|
||||
if (source.sourceType === ValidSources.Web && source.sourceUrl) {
|
||||
try {
|
||||
return new URL(source.sourceUrl).hostname;
|
||||
} catch {
|
||||
return source.sourceUrl;
|
||||
}
|
||||
}
|
||||
return source.sourceType;
|
||||
};
|
||||
|
||||
export interface SourceTagProps {
|
||||
/** Use inline citation size (smaller, for use within text) */
|
||||
inlineCitation?: boolean;
|
||||
|
||||
/** Display name shown on the tag (e.g., "Google Drive", "Business Insider") */
|
||||
displayName: string;
|
||||
|
||||
/** URL to display below name (for site type - shows domain) */
|
||||
displayUrl?: string;
|
||||
|
||||
/** Array of sources for navigation in details card */
|
||||
sources: SourceInfo[];
|
||||
|
||||
/** Callback when a source is clicked in the details card */
|
||||
onSourceClick?: () => void;
|
||||
|
||||
/** Whether to show the details card on hover (defaults to true) */
|
||||
showDetailsCard?: boolean;
|
||||
|
||||
/** Additional CSS classes */
|
||||
className?: string;
|
||||
}
|
||||
|
||||
const SourceTagInner = ({
|
||||
inlineCitation,
|
||||
displayName,
|
||||
displayUrl,
|
||||
sources,
|
||||
onSourceClick,
|
||||
showDetailsCard = true,
|
||||
className,
|
||||
}: SourceTagProps) => {
|
||||
const [currentIndex, setCurrentIndex] = useState(0);
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
const uniqueSources = useMemo(
|
||||
() =>
|
||||
sources.filter(
|
||||
(source, index, arr) =>
|
||||
arr.findIndex((s) => getIconKey(s) === getIconKey(source)) === index
|
||||
),
|
||||
[sources]
|
||||
);
|
||||
|
||||
const showCount = sources.length > 1;
|
||||
const extraCount = sources.length - 1;
|
||||
|
||||
const size = inlineCitation ? "inlineCitation" : "tag";
|
||||
const styles = sizeClasses[size];
|
||||
|
||||
const handlePrev = useCallback(() => {
|
||||
setCurrentIndex((prev) => Math.max(0, prev - 1));
|
||||
}, []);
|
||||
|
||||
const handleNext = useCallback(() => {
|
||||
setCurrentIndex((prev) => Math.min(sources.length - 1, prev + 1));
|
||||
}, [sources.length]);
|
||||
|
||||
// Reset to first source when tooltip closes
|
||||
const handleOpenChange = useCallback((open: boolean) => {
|
||||
setIsOpen(open);
|
||||
if (!open) {
|
||||
setCurrentIndex(0);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const buttonContent = (
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
"group inline-flex items-center cursor-pointer transition-all duration-150",
|
||||
"appearance-none border-none bg-background-tint-02",
|
||||
isOpen && "bg-background-tint-inverted-03",
|
||||
!showDetailsCard && "hover:bg-background-tint-inverted-03",
|
||||
styles.container,
|
||||
className
|
||||
)}
|
||||
onClick={() => onSourceClick?.()}
|
||||
>
|
||||
{/* Stacked icons container - only for tag variant */}
|
||||
{!inlineCitation && (
|
||||
<div className="flex items-center -space-x-1.5">
|
||||
{uniqueSources.slice(0, 3).map((source, index) => (
|
||||
<div
|
||||
key={source.id}
|
||||
className={cn(
|
||||
"relative flex items-center justify-center p-0.5 rounded-04",
|
||||
"bg-background-tint-00 border transition-colors duration-150",
|
||||
isOpen
|
||||
? "border-background-tint-inverted-03"
|
||||
: "border-background-tint-02",
|
||||
!showDetailsCard &&
|
||||
"group-hover:border-background-tint-inverted-03"
|
||||
)}
|
||||
style={{ zIndex: uniqueSources.slice(0, 3).length - index }}
|
||||
>
|
||||
{source.icon ? (
|
||||
<source.icon size={12} />
|
||||
) : source.sourceType === ValidSources.Web && source.sourceUrl ? (
|
||||
<WebResultIcon url={source.sourceUrl} size={12} />
|
||||
) : (
|
||||
<SourceIcon
|
||||
sourceType={
|
||||
source.sourceType === ValidSources.Web
|
||||
? ValidSources.Web
|
||||
: source.sourceType
|
||||
}
|
||||
iconSize={12}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className={cn("flex items-baseline", !inlineCitation && "pr-0.5")}>
|
||||
<Text
|
||||
figureSmallValue={inlineCitation && !isOpen}
|
||||
figureSmallLabel={inlineCitation && isOpen}
|
||||
secondaryBody={!inlineCitation}
|
||||
text05={isOpen}
|
||||
text03={!isOpen && inlineCitation}
|
||||
text04={!isOpen && !inlineCitation}
|
||||
inverted={isOpen}
|
||||
className={cn(
|
||||
"max-w-[10rem] truncate transition-colors duration-150",
|
||||
!showDetailsCard && "group-hover:text-text-inverted-05"
|
||||
)}
|
||||
>
|
||||
{displayName}
|
||||
</Text>
|
||||
|
||||
{/* Count - for inline citation */}
|
||||
{inlineCitation && showCount && (
|
||||
<Text
|
||||
figureSmallValue
|
||||
text05={isOpen}
|
||||
text03={!isOpen}
|
||||
inverted={isOpen}
|
||||
className={cn(
|
||||
"transition-colors duration-150",
|
||||
!showDetailsCard && "group-hover:text-text-inverted-05"
|
||||
)}
|
||||
>
|
||||
+{extraCount}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{/* URL - for tag variant */}
|
||||
{!inlineCitation && displayUrl && (
|
||||
<Text
|
||||
figureSmallValue
|
||||
text05={isOpen}
|
||||
text02={!isOpen}
|
||||
inverted={isOpen}
|
||||
className={cn(
|
||||
"max-w-[10rem] truncate transition-colors duration-150",
|
||||
!showDetailsCard && "group-hover:text-text-inverted-05"
|
||||
)}
|
||||
>
|
||||
{displayUrl}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
|
||||
if (!showDetailsCard) {
|
||||
return buttonContent;
|
||||
}
|
||||
|
||||
return (
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip open={isOpen} onOpenChange={handleOpenChange}>
|
||||
<TooltipTrigger asChild>{buttonContent}</TooltipTrigger>
|
||||
<TooltipContent
|
||||
side="bottom"
|
||||
align="start"
|
||||
sideOffset={4}
|
||||
className="bg-transparent p-0 shadow-none border-none"
|
||||
>
|
||||
<SourceTagDetailsCard
|
||||
sources={sources}
|
||||
currentIndex={currentIndex}
|
||||
onPrev={handlePrev}
|
||||
onNext={handleNext}
|
||||
/>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
);
|
||||
};
|
||||
|
||||
const SourceTag = memo(SourceTagInner);
|
||||
export default SourceTag;
|
||||
@@ -0,0 +1,172 @@
|
||||
"use client";
|
||||
|
||||
import React, { memo } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { SvgArrowLeft, SvgArrowRight, SvgUser } from "@opal/icons";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { WebResultIcon } from "@/components/WebResultIcon";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { timeAgo } from "@/lib/time";
|
||||
import { IconProps } from "@/components/icons/icons";
|
||||
|
||||
export interface SourceInfo {
|
||||
id: string;
|
||||
title: string;
|
||||
sourceType: ValidSources;
|
||||
sourceUrl?: string;
|
||||
description?: string;
|
||||
metadata?: {
|
||||
author?: string;
|
||||
date?: string | Date;
|
||||
tags?: string[];
|
||||
};
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
}
|
||||
|
||||
interface SourceTagDetailsCardProps {
|
||||
sources: SourceInfo[];
|
||||
currentIndex: number;
|
||||
onPrev: () => void;
|
||||
onNext: () => void;
|
||||
}
|
||||
|
||||
interface MetadataChipProps {
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
text: string;
|
||||
}
|
||||
|
||||
const MetadataChip = memo(function MetadataChip({
|
||||
icon: Icon,
|
||||
text,
|
||||
}: MetadataChipProps) {
|
||||
return (
|
||||
<div className="flex items-center gap-0 bg-background-tint-02 rounded-08 p-1">
|
||||
{Icon && (
|
||||
<div className="flex items-center justify-center p-0.5 w-4 h-4">
|
||||
<Icon className="w-3 h-3 stroke-text-03" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Text secondaryBody text03 className="px-0.5 max-w-[10rem] truncate">
|
||||
{text}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const SourceTagDetailsCardInner = ({
|
||||
sources,
|
||||
currentIndex,
|
||||
onPrev,
|
||||
onNext,
|
||||
}: SourceTagDetailsCardProps) => {
|
||||
const currentSource = sources[currentIndex];
|
||||
if (!currentSource) return null;
|
||||
|
||||
const showNavigation = sources.length > 1;
|
||||
const isFirst = currentIndex === 0;
|
||||
const isLast = currentIndex === sources.length - 1;
|
||||
const isWebSource = currentSource.sourceType === "web";
|
||||
const relativeDate = timeAgo(
|
||||
currentSource.metadata?.date instanceof Date
|
||||
? currentSource.metadata.date.toISOString()
|
||||
: currentSource.metadata?.date
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="w-[17.5rem] bg-background-neutral-00 border border-border-01 rounded-12 shadow-01 overflow-hidden">
|
||||
{/* Navigation header - only shown for multiple sources */}
|
||||
{showNavigation && (
|
||||
<div className="flex items-center justify-between p-2 bg-background-tint-01 border-b border-border-01">
|
||||
<div className="flex items-center gap-1">
|
||||
<IconButton
|
||||
main
|
||||
internal
|
||||
icon={SvgArrowLeft}
|
||||
onClick={onPrev}
|
||||
disabled={isFirst}
|
||||
className="!p-0.5"
|
||||
/>
|
||||
<IconButton
|
||||
main
|
||||
internal
|
||||
icon={SvgArrowRight}
|
||||
onClick={onNext}
|
||||
disabled={isLast}
|
||||
className="!p-0.5"
|
||||
/>
|
||||
</div>
|
||||
<Text secondaryBody text03 className="px-1">
|
||||
{currentIndex + 1}/{sources.length}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="p-1 flex flex-col gap-1">
|
||||
{/* Header with icon and title */}
|
||||
<div className="flex items-start gap-1 p-0.5 min-h-[1.75rem] w-full text-left hover:bg-background-tint-01 rounded-08 transition-colors">
|
||||
<div className="flex items-center justify-center p-0.5 shrink-0 w-5 h-5">
|
||||
{currentSource.icon ? (
|
||||
<currentSource.icon size={16} />
|
||||
) : isWebSource && currentSource.sourceUrl ? (
|
||||
<WebResultIcon url={currentSource.sourceUrl} size={16} />
|
||||
) : (
|
||||
<SourceIcon
|
||||
sourceType={
|
||||
currentSource.sourceType === "web"
|
||||
? ValidSources.Web
|
||||
: currentSource.sourceType
|
||||
}
|
||||
iconSize={16}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex-1 min-w-0 px-0.5">
|
||||
<Text
|
||||
mainUiAction
|
||||
text04
|
||||
className="truncate w-full block leading-5"
|
||||
>
|
||||
{currentSource.title}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Metadata row */}
|
||||
{(currentSource.metadata?.author ||
|
||||
currentSource.metadata?.tags?.length ||
|
||||
relativeDate) && (
|
||||
<div className="flex flex-row items-center gap-2 ">
|
||||
<div className="flex flex-wrap gap-1 items-center">
|
||||
{currentSource.metadata?.author && (
|
||||
<MetadataChip
|
||||
icon={SvgUser}
|
||||
text={currentSource.metadata.author}
|
||||
/>
|
||||
)}
|
||||
{currentSource.metadata?.tags
|
||||
?.slice(0, 2)
|
||||
.map((tag) => <MetadataChip key={tag} text={tag} />)}
|
||||
{relativeDate && (
|
||||
<Text secondaryBody text02>
|
||||
{relativeDate}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Description */}
|
||||
{currentSource.description && (
|
||||
<Text secondaryBody text03 as="span" className="line-clamp-4">
|
||||
{currentSource.description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const SourceTagDetailsCard = memo(SourceTagDetailsCardInner);
|
||||
export default SourceTagDetailsCard;
|
||||
6
web/src/refresh-components/buttons/source-tag/index.ts
Normal file
6
web/src/refresh-components/buttons/source-tag/index.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
export {
|
||||
default as SourceTag,
|
||||
type SourceTagProps,
|
||||
type SourceInfo,
|
||||
} from "./SourceTag";
|
||||
export { default as SourceTagDetailsCard } from "./SourceTagDetailsCard";
|
||||
180
web/src/refresh-components/texts/ExpandableTextDisplay.tsx
Normal file
180
web/src/refresh-components/texts/ExpandableTextDisplay.tsx
Normal file
@@ -0,0 +1,180 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo } from "react";
|
||||
import * as DialogPrimitive from "@radix-ui/react-dialog";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import FadingEdgeContainer from "@/refresh-components/FadingEdgeContainer";
|
||||
import { SvgDownload, SvgMaximize2, SvgX } from "@opal/icons";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface ExpandableTextDisplayProps {
|
||||
/** Title shown in header and modal */
|
||||
title: string;
|
||||
/** The full text content to display (used in modal and for copy/download) */
|
||||
content: string;
|
||||
/** Optional content to display in collapsed view (e.g., for streaming animation). Falls back to `content`. */
|
||||
displayContent?: string;
|
||||
/** Subtitle text (e.g., file size). If not provided, calculates from content */
|
||||
subtitle?: string;
|
||||
/** Maximum lines to show in collapsed state (1-6). Values outside this range default to 5. */
|
||||
maxLines?: 1 | 2 | 3 | 4 | 5 | 6;
|
||||
/** Additional className for the container */
|
||||
className?: string;
|
||||
/** Optional custom renderer for content (e.g., markdown). Falls back to plain text. */
|
||||
renderContent?: (content: string) => React.ReactNode;
|
||||
}
|
||||
|
||||
/** Calculate content size in human-readable format */
|
||||
function getContentSize(text: string): string {
|
||||
const bytes = new Blob([text]).size;
|
||||
if (bytes < 1024) return `${bytes} B`;
|
||||
return `${(bytes / 1024).toFixed(2)} KB`;
|
||||
}
|
||||
|
||||
/** Count lines in text */
|
||||
function getLineCount(text: string): number {
|
||||
return text.split("\n").length;
|
||||
}
|
||||
|
||||
/** Download content as a .txt file */
|
||||
function downloadAsTxt(content: string, filename: string) {
|
||||
const blob = new Blob([content], { type: "text/plain" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
try {
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `${filename}.txt`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
} finally {
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
}
|
||||
|
||||
export default function ExpandableTextDisplay({
|
||||
title,
|
||||
content,
|
||||
displayContent,
|
||||
subtitle,
|
||||
maxLines = 5,
|
||||
className,
|
||||
renderContent,
|
||||
}: ExpandableTextDisplayProps) {
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
|
||||
const lineCount = useMemo(() => getLineCount(content), [content]);
|
||||
const contentSize = useMemo(() => getContentSize(content), [content]);
|
||||
const displaySubtitle = subtitle ?? contentSize;
|
||||
|
||||
const handleDownload = () => {
|
||||
const sanitizedTitle = title.replace(/[^a-z0-9]/gi, "_").toLowerCase();
|
||||
downloadAsTxt(content, sanitizedTitle);
|
||||
};
|
||||
|
||||
const lineClampClassMap: Record<number, string> = {
|
||||
1: "line-clamp-1",
|
||||
2: "line-clamp-2",
|
||||
3: "line-clamp-3",
|
||||
4: "line-clamp-4",
|
||||
5: "line-clamp-5",
|
||||
6: "line-clamp-6",
|
||||
};
|
||||
const lineClampClass = lineClampClassMap[maxLines] ?? "line-clamp-5";
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Collapsed View */}
|
||||
<div className={cn("w-full", className)}>
|
||||
<div
|
||||
className={cn(
|
||||
lineClampClass,
|
||||
!renderContent && "whitespace-pre-wrap"
|
||||
)}
|
||||
>
|
||||
{renderContent ? (
|
||||
renderContent(displayContent ?? content)
|
||||
) : (
|
||||
<Text as="p" mainUiMuted text03>
|
||||
{displayContent ?? content}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Expand button */}
|
||||
<div className="flex justify-end mt-1">
|
||||
<IconButton
|
||||
internal
|
||||
icon={SvgMaximize2}
|
||||
tooltip="View Full Text"
|
||||
onClick={() => setIsModalOpen(true)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Expanded Modal */}
|
||||
<Modal open={isModalOpen} onOpenChange={setIsModalOpen}>
|
||||
<Modal.Content height="lg" width="md-sm" preventAccidentalClose={false}>
|
||||
{/* Header */}
|
||||
<div className="flex items-start justify-between px-4 py-3">
|
||||
<div className="flex flex-col">
|
||||
<DialogPrimitive.Title asChild>
|
||||
<Text as="span" text04 headingH3>
|
||||
{title}
|
||||
</Text>
|
||||
</DialogPrimitive.Title>
|
||||
<DialogPrimitive.Description asChild>
|
||||
<Text as="span" text03 secondaryBody>
|
||||
{displaySubtitle}
|
||||
</Text>
|
||||
</DialogPrimitive.Description>
|
||||
</div>
|
||||
<DialogPrimitive.Close asChild>
|
||||
<IconButton
|
||||
icon={SvgX}
|
||||
internal
|
||||
onClick={() => setIsModalOpen(false)}
|
||||
/>
|
||||
</DialogPrimitive.Close>
|
||||
</div>
|
||||
|
||||
{/* Body */}
|
||||
<Modal.Body>
|
||||
{renderContent ? (
|
||||
renderContent(content)
|
||||
) : (
|
||||
<Text as="p" mainUiMuted text03 className="whitespace-pre-wrap">
|
||||
{content}
|
||||
</Text>
|
||||
)}
|
||||
</Modal.Body>
|
||||
|
||||
{/* Footer */}
|
||||
<div className="flex items-center justify-between p-2 bg-background-tint-01">
|
||||
<div className="px-2">
|
||||
<Text as="span" mainUiMuted text03>
|
||||
{lineCount} {lineCount === 1 ? "line" : "lines"}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="flex items-center gap-1 bg-background-tint-00 p-1 rounded-12">
|
||||
<CopyIconButton
|
||||
internal
|
||||
getCopyText={() => content}
|
||||
tooltip="Copy"
|
||||
/>
|
||||
<IconButton
|
||||
internal
|
||||
icon={SvgDownload}
|
||||
tooltip="Download"
|
||||
onClick={handleDownload}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -883,11 +883,10 @@ export default function AgentEditorPage({
|
||||
message: "Agent deleted successfully",
|
||||
});
|
||||
|
||||
deleteAgentModal.toggle(false);
|
||||
await refreshAgents();
|
||||
router.push("/chat/agents");
|
||||
}
|
||||
|
||||
deleteAgentModal.toggle(false);
|
||||
}
|
||||
|
||||
// FilePickerPopover callbacks - defined outside render to avoid inline functions
|
||||
@@ -1048,15 +1047,18 @@ export default function AgentEditorPage({
|
||||
|
||||
<shareAgentModal.Provider>
|
||||
<ShareAgentModal
|
||||
agent={existingAgent}
|
||||
agentId={existingAgent?.id}
|
||||
userIds={values.shared_user_ids}
|
||||
groupIds={values.shared_group_ids}
|
||||
isPublic={values.is_public}
|
||||
onShare={(userIds, groupIds, isPublic) => {
|
||||
setFieldValue("shared_user_ids", userIds);
|
||||
setFieldValue("shared_group_ids", groupIds);
|
||||
setFieldValue("is_public", isPublic);
|
||||
shareAgentModal.toggle(false);
|
||||
}}
|
||||
/>
|
||||
</shareAgentModal.Provider>
|
||||
|
||||
<deleteAgentModal.Provider>
|
||||
{deleteAgentModal.isOpen && (
|
||||
<ConfirmationModalLayout
|
||||
@@ -1303,6 +1305,7 @@ export default function AgentEditorPage({
|
||||
wrap
|
||||
gap={0.5}
|
||||
justifyContent="start"
|
||||
alignItems="start"
|
||||
>
|
||||
{values.user_file_ids.map((fileId) => {
|
||||
const file = allRecentFiles.find(
|
||||
@@ -1479,7 +1482,7 @@ export default function AgentEditorPage({
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Share This Agent"
|
||||
description="Share this agent with other users, groups, or everyone in your organization. "
|
||||
description="Share this agent with other users, groups, or everyone in your organization."
|
||||
center
|
||||
>
|
||||
<Button
|
||||
|
||||
@@ -4,7 +4,7 @@ import React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import FadeDiv from "@/components/FadeDiv";
|
||||
import FadingEdgeContainer from "@/refresh-components/FadingEdgeContainer";
|
||||
import ToolItemSkeleton from "@/sections/actions/skeleton/ToolItemSkeleton";
|
||||
import EnabledCount from "@/refresh-components/EnabledCount";
|
||||
import { SvgEye, SvgXCircle } from "@opal/icons";
|
||||
@@ -57,19 +57,18 @@ const ToolsList: React.FC<ToolsListProps> = ({
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
<FadingEdgeContainer
|
||||
direction="bottom"
|
||||
className={cn(
|
||||
"flex flex-col gap-1 items-start max-h-[30vh] overflow-y-auto w-full",
|
||||
"flex flex-col gap-1 items-start max-h-[30vh] overflow-y-auto",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{isFetching ? (
|
||||
// Show 5 skeleton items while loading
|
||||
Array.from({ length: 5 }).map((_, index) => (
|
||||
<ToolItemSkeleton key={`skeleton-${index}`} />
|
||||
))
|
||||
) : isEmpty ? (
|
||||
// Empty state
|
||||
<div className="flex items-center justify-center w-full py-8">
|
||||
<Text as="p" text03 mainUiBody>
|
||||
{searchQuery ? emptySearchMessage : emptyMessage}
|
||||
@@ -78,11 +77,11 @@ const ToolsList: React.FC<ToolsListProps> = ({
|
||||
) : (
|
||||
children
|
||||
)}
|
||||
</div>
|
||||
</FadingEdgeContainer>
|
||||
|
||||
{/* Footer showing enabled tool count with filter toggle */}
|
||||
{showFooter && !(totalCount === 0) && !isFetching && (
|
||||
<FadeDiv>
|
||||
<div className="pt-2 px-2">
|
||||
<div className="flex items-center justify-between gap-2 w-full">
|
||||
{/* Left action area */}
|
||||
{leftAction}
|
||||
@@ -128,7 +127,7 @@ const ToolsList: React.FC<ToolsListProps> = ({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</FadeDiv>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import {
|
||||
@@ -27,9 +26,6 @@ import { useUser } from "@/components/user/UserProvider";
|
||||
import { Formik, useFormikContext } from "formik";
|
||||
import { useAgent } from "@/hooks/useAgents";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { User } from "@/lib/types";
|
||||
import { UserGroup } from "@/lib/types";
|
||||
import { FullPersona } from "@/app/admin/assistants/interfaces";
|
||||
|
||||
const YOUR_ORGANIZATION_TAB = "Your Organization";
|
||||
const USERS_AND_GROUPS_TAB = "Users & Groups";
|
||||
@@ -44,58 +40,70 @@ interface ShareAgentFormValues {
|
||||
isPublic: boolean;
|
||||
}
|
||||
|
||||
interface ComboBoxOption {
|
||||
value: string;
|
||||
label: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ShareAgentFormContent
|
||||
// ============================================================================
|
||||
|
||||
interface ShareAgentFormContentProps {
|
||||
agent?: MinimalPersonaSnapshot;
|
||||
fullAgent: FullPersona | null;
|
||||
usersData: User[];
|
||||
groupsData: UserGroup[];
|
||||
currentUserId: string | undefined;
|
||||
comboBoxOptions: ComboBoxOption[];
|
||||
onClose: () => void;
|
||||
onCopyLink: () => void;
|
||||
agentId?: number;
|
||||
}
|
||||
|
||||
function ShareAgentFormContent({
|
||||
agent,
|
||||
fullAgent,
|
||||
usersData,
|
||||
groupsData,
|
||||
currentUserId,
|
||||
comboBoxOptions,
|
||||
onClose,
|
||||
onCopyLink,
|
||||
}: ShareAgentFormContentProps) {
|
||||
function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
const { values, setFieldValue, handleSubmit, dirty } =
|
||||
useFormikContext<ShareAgentFormValues>();
|
||||
const { data: usersData } = useUsers({ includeApiKeys: false });
|
||||
const { data: groupsData } = useGroups();
|
||||
const { user: currentUser } = useUser();
|
||||
const { agent: fullAgent } = useAgent(agentId ?? null);
|
||||
const shareAgentModal = useModal();
|
||||
|
||||
const acceptedUsers = usersData?.accepted ?? [];
|
||||
const groups = groupsData ?? [];
|
||||
|
||||
// Create options for InputComboBox from all accepted users and groups
|
||||
const comboBoxOptions = useMemo(() => {
|
||||
const userOptions = acceptedUsers.map((user) => ({
|
||||
value: `user-${user.id}`,
|
||||
label: user.email,
|
||||
}));
|
||||
|
||||
const groupOptions = groups.map((group) => ({
|
||||
value: `group-${group.id}`,
|
||||
label: group.name,
|
||||
}));
|
||||
|
||||
return [...userOptions, ...groupOptions];
|
||||
}, [acceptedUsers, groups]);
|
||||
|
||||
// Compute owner and displayed users
|
||||
const ownerId = fullAgent?.owner?.id;
|
||||
const owner = ownerId
|
||||
? usersData.find((user) => user.id === ownerId)
|
||||
: usersData.find((user) => user.id === currentUserId);
|
||||
? acceptedUsers.find((user) => user.id === ownerId)
|
||||
: acceptedUsers.find((user) => user.id === currentUser?.id);
|
||||
const otherUsers = owner
|
||||
? usersData.filter(
|
||||
? acceptedUsers.filter(
|
||||
(user) =>
|
||||
user.id !== owner.id && values.selectedUserIds.includes(user.id)
|
||||
)
|
||||
: usersData;
|
||||
: acceptedUsers;
|
||||
const displayedUsers = [...(owner ? [owner] : []), ...otherUsers];
|
||||
|
||||
// Compute displayed groups based on current form values
|
||||
const displayedGroups = groupsData.filter((group) =>
|
||||
const displayedGroups = groups.filter((group) =>
|
||||
values.selectedGroupIds.includes(group.id)
|
||||
);
|
||||
|
||||
// Handlers
|
||||
function handleClose() {
|
||||
shareAgentModal.toggle(false);
|
||||
}
|
||||
|
||||
function handleCopyLink() {
|
||||
if (!agentId) return;
|
||||
const url = `${window.location.origin}/chat?assistantId=${agentId}`;
|
||||
navigator.clipboard.writeText(url);
|
||||
}
|
||||
|
||||
function handleComboBoxSelect(selectedValue: string) {
|
||||
if (selectedValue.startsWith("user-")) {
|
||||
const userId = selectedValue.replace("user-", "");
|
||||
@@ -129,7 +137,7 @@ function ShareAgentFormContent({
|
||||
|
||||
return (
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header icon={SvgShare} title="Share Agent" onClose={onClose} />
|
||||
<Modal.Header icon={SvgShare} title="Share Agent" onClose={handleClose} />
|
||||
|
||||
<Modal.Body padding={0.5}>
|
||||
<Card variant="borderless" padding={0.5}>
|
||||
@@ -165,7 +173,7 @@ function ShareAgentFormContent({
|
||||
{/* Shared Users */}
|
||||
{displayedUsers.map((user) => {
|
||||
const isOwner = fullAgent?.owner?.id === user.id;
|
||||
const isCurrentUser = currentUserId === user.id;
|
||||
const isCurrentUser = currentUser?.id === user.id;
|
||||
|
||||
return (
|
||||
<LineItem
|
||||
@@ -173,7 +181,7 @@ function ShareAgentFormContent({
|
||||
icon={SvgUser}
|
||||
description={isCurrentUser ? "You" : undefined}
|
||||
rightChildren={
|
||||
isOwner || (isCurrentUser && !agent) ? (
|
||||
isOwner || (isCurrentUser && !agentId) ? (
|
||||
// Owner will always have the agent "shared" with it.
|
||||
// Therefore, we never render any `IconButton SvgX` to remove it.
|
||||
//
|
||||
@@ -235,14 +243,14 @@ function ShareAgentFormContent({
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
left={
|
||||
agent ? (
|
||||
<Button secondary leftIcon={SvgLink} onClick={onCopyLink}>
|
||||
agentId ? (
|
||||
<Button secondary leftIcon={SvgLink} onClick={handleCopyLink}>
|
||||
Copy Link
|
||||
</Button>
|
||||
) : undefined
|
||||
}
|
||||
cancel={
|
||||
<Button secondary onClick={onClose}>
|
||||
<Button secondary onClick={handleClose}>
|
||||
Done
|
||||
</Button>
|
||||
}
|
||||
@@ -262,54 +270,30 @@ function ShareAgentFormContent({
|
||||
// ============================================================================
|
||||
|
||||
export interface ShareAgentModalProps {
|
||||
agent?: MinimalPersonaSnapshot;
|
||||
agentId?: number;
|
||||
userIds: string[];
|
||||
groupIds: number[];
|
||||
isPublic: boolean;
|
||||
onShare?: (userIds: string[], groupIds: number[], isPublic: boolean) => void;
|
||||
}
|
||||
|
||||
export default function ShareAgentModal({
|
||||
agent,
|
||||
agentId,
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
onShare,
|
||||
}: ShareAgentModalProps) {
|
||||
const { data: usersData } = useUsers({ includeApiKeys: false });
|
||||
const { data: groupsData } = useGroups();
|
||||
const { user: currentUser } = useUser();
|
||||
const shareAgentModal = useModal();
|
||||
const { agent: fullAgent } = useAgent(agent?.id ?? null);
|
||||
|
||||
// Create options for InputComboBox from all accepted users and groups
|
||||
const comboBoxOptions = useMemo(() => {
|
||||
const userOptions = (usersData?.accepted ?? []).map((user) => ({
|
||||
value: `user-${user.id}`,
|
||||
label: user.email,
|
||||
}));
|
||||
|
||||
const groupOptions = (groupsData ?? []).map((group) => ({
|
||||
value: `group-${group.id}`,
|
||||
label: group.name,
|
||||
}));
|
||||
|
||||
return [...userOptions, ...groupOptions];
|
||||
}, [usersData?.accepted, groupsData]);
|
||||
|
||||
const initialValues: ShareAgentFormValues = {
|
||||
selectedUserIds: fullAgent?.users?.map((u) => u.id) ?? [],
|
||||
selectedGroupIds: fullAgent?.groups ?? [],
|
||||
isPublic: fullAgent?.is_public ?? true,
|
||||
selectedUserIds: userIds,
|
||||
selectedGroupIds: groupIds,
|
||||
isPublic: isPublic,
|
||||
};
|
||||
|
||||
function handleSubmit(values: ShareAgentFormValues) {
|
||||
onShare?.(values.selectedUserIds, values.selectedGroupIds, values.isPublic);
|
||||
shareAgentModal.toggle(false);
|
||||
}
|
||||
|
||||
function handleClose() {
|
||||
shareAgentModal.toggle(false);
|
||||
}
|
||||
|
||||
function handleCopyLink() {
|
||||
if (!agent?.id) return;
|
||||
const url = `${window.location.origin}/chat?assistantId=${agent.id}`;
|
||||
navigator.clipboard.writeText(url);
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -319,16 +303,7 @@ export default function ShareAgentModal({
|
||||
onSubmit={handleSubmit}
|
||||
enableReinitialize
|
||||
>
|
||||
<ShareAgentFormContent
|
||||
agent={agent}
|
||||
fullAgent={fullAgent}
|
||||
usersData={usersData?.accepted ?? []}
|
||||
groupsData={groupsData ?? []}
|
||||
currentUserId={currentUser?.id}
|
||||
comboBoxOptions={comboBoxOptions}
|
||||
onClose={handleClose}
|
||||
onCopyLink={handleCopyLink}
|
||||
/>
|
||||
<ShareAgentFormContent agentId={agentId} />
|
||||
</Formik>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
@@ -435,7 +435,7 @@ const ChatButton = memo(
|
||||
>
|
||||
<Popover.Anchor>
|
||||
<SidebarTab
|
||||
href={`/chat?chatId=${chatSession.id}`}
|
||||
href={isDragging ? undefined : `/chat?chatId=${chatSession.id}`}
|
||||
onClick={handleClick}
|
||||
transient={active}
|
||||
rightChildren={rightMenu}
|
||||
|
||||
@@ -16,6 +16,10 @@ module.exports = {
|
||||
spacing: "margin, padding",
|
||||
},
|
||||
keyframes: {
|
||||
shimmer: {
|
||||
"0%": { backgroundPosition: "100% 0" },
|
||||
"100%": { backgroundPosition: "-100% 0" },
|
||||
},
|
||||
"subtle-pulse": {
|
||||
"0%, 100%": { opacity: 0.9 },
|
||||
"50%": { opacity: 0.5 },
|
||||
@@ -42,6 +46,7 @@ module.exports = {
|
||||
},
|
||||
},
|
||||
animation: {
|
||||
shimmer: "shimmer 1.8s ease-out infinite",
|
||||
"fade-in-up": "fadeInUp 0.5s ease-out",
|
||||
"subtle-pulse": "subtle-pulse 2s ease-in-out infinite",
|
||||
pulse: "pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite",
|
||||
|
||||
Reference in New Issue
Block a user