mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-26 01:52:45 +00:00
Compare commits
17 Commits
multi-mode
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2433a9a4c5 | ||
|
|
60bc8fcac6 | ||
|
|
1ddc958a51 | ||
|
|
de37acbe07 | ||
|
|
08cd2f2c3e | ||
|
|
fc29f20914 | ||
|
|
c43cb80a7a | ||
|
|
56f0be2ec8 | ||
|
|
42f9ddf247 | ||
|
|
a10a85c73c | ||
|
|
31d8ae9718 | ||
|
|
00a0a99842 | ||
|
|
90040f8973 | ||
|
|
4f5d081f26 | ||
|
|
c51a6dbd0d | ||
|
|
8b90ecc189 | ||
|
|
865c893a09 |
4
.github/workflows/deployment.yml
vendored
4
.github/workflows/deployment.yml
vendored
@@ -615,6 +615,7 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1263,8 +1264,6 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=craft-latest
|
||||
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
|
||||
# to keep tagging strategy consistent across all backend images
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
@@ -1488,6 +1487,7 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""add preferred_response_id and model_display_name to chat_message
|
||||
|
||||
Revision ID: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-22
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3f8b2c1d4e5"
|
||||
down_revision = "25a5501dc766"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"preferred_response_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("model_display_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "model_display_name")
|
||||
op.drop_column("chat_message", "preferred_response_id")
|
||||
@@ -250,20 +250,24 @@ def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
|
||||
raise e
|
||||
|
||||
|
||||
def _is_public_item(drive_item: DriveItem) -> bool:
|
||||
is_public = False
|
||||
def _is_public_item(
|
||||
drive_item: DriveItem,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> bool:
|
||||
if not treat_sharing_link_as_public:
|
||||
return False
|
||||
|
||||
try:
|
||||
permissions = sleep_and_retry(
|
||||
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
|
||||
)
|
||||
for permission in permissions:
|
||||
if permission.link and (
|
||||
permission.link.scope == "anonymous"
|
||||
or permission.link.scope == "organization"
|
||||
if permission.link and permission.link.scope in (
|
||||
"anonymous",
|
||||
"organization",
|
||||
):
|
||||
is_public = True
|
||||
break
|
||||
return is_public
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
|
||||
return False
|
||||
@@ -504,6 +508,7 @@ def get_external_access_from_sharepoint(
|
||||
drive_item: DriveItem | None,
|
||||
site_page: dict[str, Any] | None,
|
||||
add_prefix: bool = False,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get external access information from SharePoint.
|
||||
@@ -563,8 +568,7 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
|
||||
if drive_item and drive_name:
|
||||
# Here we check if the item have have any public links, if so we return early
|
||||
is_public = _is_public_item(drive_item)
|
||||
is_public = _is_public_item(drive_item, treat_sharing_link_as_public)
|
||||
if is_public:
|
||||
logger.info(f"Item {drive_item.id} is public")
|
||||
return ExternalAccess(
|
||||
|
||||
@@ -44,12 +44,14 @@ def _run_single_search(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
hybrid_alpha: float | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
limit=num_hits,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
@@ -74,7 +76,7 @@ def stream_search_query(
|
||||
Core search function that yields streaming packets.
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
# Get document index.
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
@@ -119,6 +121,7 @@ def stream_search_query(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
num_hits=request.num_hits,
|
||||
hybrid_alpha=request.hybrid_alpha,
|
||||
)
|
||||
else:
|
||||
# Multiple queries - run in parallel and merge with RRF
|
||||
@@ -133,6 +136,7 @@ def stream_search_query(
|
||||
user,
|
||||
db_session,
|
||||
request.num_hits,
|
||||
request.hybrid_alpha,
|
||||
),
|
||||
)
|
||||
for query in all_executed_queries
|
||||
|
||||
@@ -27,15 +27,17 @@ class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any
|
||||
# changes to it should be reviewed and approved by an experienced team member.
|
||||
# It is very important to 1. avoid bloat and 2. that this remains backwards
|
||||
# compatible across versions.
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 30
|
||||
|
||||
hybrid_alpha: float | None = None
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
@@ -67,8 +68,10 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any
|
||||
# changes to it should be reviewed and approved by an experienced team member.
|
||||
# It is very important to 1. avoid bloat and 2. that this remains backwards
|
||||
# compatible across versions.
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
@@ -80,13 +83,19 @@ def handle_send_search_message(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
Execute a search query with optional streaming.
|
||||
Executes a search query with optional streaming.
|
||||
|
||||
When stream=True: Returns StreamingResponse with SSE
|
||||
When stream=False: Returns SearchFullResponse
|
||||
If hybrid_alpha is unset and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
is True, executes pure keyword search.
|
||||
|
||||
Returns:
|
||||
StreamingResponse with SSE if stream=True, otherwise SearchFullResponse.
|
||||
"""
|
||||
logger.debug(f"Received search query: {request.search_query}")
|
||||
|
||||
if request.hybrid_alpha is None and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH:
|
||||
request.hybrid_alpha = 0.0
|
||||
|
||||
# Non-streaming path
|
||||
if not request.stream:
|
||||
try:
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -36,13 +35,7 @@ class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamingError
|
||||
| CreateChatSessionID
|
||||
)
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
@@ -4,11 +4,9 @@ An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import queue
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
@@ -30,7 +28,6 @@ from onyx.chat.compression import calculate_total_history_tokens
|
||||
from onyx.chat.compression import compress_chat_history
|
||||
from onyx.chat.compression import find_summary_for_branch
|
||||
from onyx.chat.compression import get_compression_params
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import EmptyLLMResponseError
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
@@ -62,8 +59,6 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -91,20 +86,16 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
@@ -1078,568 +1069,6 @@ def handle_stream_message_objects(
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def _build_model_display_name(override: LLMOverride) -> str:
|
||||
"""Build a human-readable display name from an LLM override."""
|
||||
if override.display_name:
|
||||
return override.display_name
|
||||
if override.model_version:
|
||||
return override.model_version
|
||||
if override.model_provider:
|
||||
return override.model_provider
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Sentinel placed on the merged queue when a model thread finishes.
|
||||
_MODEL_DONE = object()
|
||||
|
||||
|
||||
class _ModelIndexEmitter(Emitter):
|
||||
"""Emitter that tags packets with model_index and forwards directly to a shared queue.
|
||||
|
||||
Unlike the standard Emitter (which accumulates in a local bus), this puts
|
||||
packets into the shared merged_queue in real-time as they're emitted. This
|
||||
enables true parallel streaming — packets from multiple models interleave
|
||||
on the wire instead of arriving in bursts after each model completes.
|
||||
"""
|
||||
|
||||
def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
|
||||
super().__init__(queue.Queue()) # bus exists for compat, unused
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
tagged_placement = Placement(
|
||||
turn_index=packet.placement.turn_index if packet.placement else 0,
|
||||
tab_index=packet.placement.tab_index if packet.placement else 0,
|
||||
sub_turn_index=(
|
||||
packet.placement.sub_turn_index if packet.placement else None
|
||||
),
|
||||
model_index=self._model_idx,
|
||||
)
|
||||
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
|
||||
self._merged_queue.put((self._model_idx, tagged_packet))
|
||||
|
||||
|
||||
def run_multi_model_stream(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
llm_overrides: list[LLMOverride],
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
) -> AnswerStream:
|
||||
# TODO: The setup logic below (session resolution through tool construction)
|
||||
# is duplicated from handle_stream_message_objects. Extract into a shared
|
||||
# _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
|
||||
# both paths call the same setup code. Tracked as follow-up refactor.
|
||||
"""Run 2-3 LLMs in parallel and yield their packets tagged with model_index.
|
||||
|
||||
Resource management:
|
||||
- Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
|
||||
- The caller's db_session is used only for setup (before threads launch) and
|
||||
completion callbacks (after threads finish)
|
||||
- ThreadPoolExecutor is bounded to len(overrides) workers
|
||||
- All threads are joined in the finally block regardless of success/failure
|
||||
- Queue-based merging avoids busy-waiting
|
||||
"""
|
||||
n_models = len(llm_overrides)
|
||||
if n_models < 2 or n_models > 3:
|
||||
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
|
||||
if new_msg_req.deep_research:
|
||||
raise ValueError("Multi-model is not supported with deep research")
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
cache: CacheBackend | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
llm_user_identifier = "anonymous_user"
|
||||
else:
|
||||
llm_user_identifier = user.email or str(user_id)
|
||||
|
||||
try:
|
||||
# ── Session setup (same as single-model path) ──────────────────
|
||||
if not new_msg_req.chat_session_id:
|
||||
if not new_msg_req.chat_session_info:
|
||||
raise RuntimeError(
|
||||
"Must specify a chat session id or chat session info"
|
||||
)
|
||||
chat_session = create_chat_session_from_request(
|
||||
chat_session_request=new_msg_req.chat_session_info,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
message_text = new_msg_req.message
|
||||
|
||||
# ── Build N LLM instances and validate costs ───────────────────
|
||||
llms: list[LLM] = []
|
||||
model_display_names: list[str] = []
|
||||
for override in llm_overrides:
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
llms.append(llm)
|
||||
model_display_names.append(_build_model_display_name(override))
|
||||
|
||||
# Use first LLM for token counting (context window is checked per-model
|
||||
# but token counting is model-agnostic enough for setup purposes)
|
||||
token_counter = get_llm_token_counter(llms[0])
|
||||
|
||||
verify_user_files(
|
||||
user_files=new_msg_req.file_descriptors,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
project_id=chat_session.project_id,
|
||||
)
|
||||
|
||||
# ── Chat history chain (shared across all models) ──────────────
|
||||
chat_history = create_chat_history_chain(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
parent_message = root_message
|
||||
chat_history = []
|
||||
else:
|
||||
parent_message = None
|
||||
for i in range(len(chat_history) - 1, -1, -1):
|
||||
if chat_history[i].id == new_msg_req.parent_message_id:
|
||||
parent_message = chat_history[i]
|
||||
chat_history = chat_history[: i + 1]
|
||||
break
|
||||
|
||||
if parent_message is None:
|
||||
raise ValueError(
|
||||
"The new message sent is not on the latest mainline of messages"
|
||||
)
|
||||
|
||||
if parent_message.message_type == MessageType.USER:
|
||||
user_message = parent_message
|
||||
else:
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=parent_message,
|
||||
message=message_text,
|
||||
token_count=token_counter(message_text),
|
||||
message_type=MessageType.USER,
|
||||
files=new_msg_req.file_descriptors,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
chat_history.append(user_message)
|
||||
|
||||
available_files = _collect_available_file_ids(
|
||||
chat_history=chat_history,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
summary_message = find_summary_for_branch(db_session, chat_history)
|
||||
summarized_file_metadata: dict[str, FileToolMetadata] = {}
|
||||
if summary_message and summary_message.last_summarized_message_id:
|
||||
cutoff_id = summary_message.last_summarized_message_id
|
||||
for msg in chat_history:
|
||||
if msg.id > cutoff_id or not msg.files:
|
||||
continue
|
||||
for fd in msg.files:
|
||||
file_id = fd.get("id")
|
||||
if not file_id:
|
||||
continue
|
||||
summarized_file_metadata[file_id] = FileToolMetadata(
|
||||
file_id=file_id,
|
||||
filename=fd.get("name") or "unknown",
|
||||
approx_char_count=0,
|
||||
)
|
||||
chat_history = [m for m in chat_history if m.id > cutoff_id]
|
||||
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if user.use_memories
|
||||
else user_memory_context.without_memories()
|
||||
)
|
||||
|
||||
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
|
||||
custom_agent_prompt or ""
|
||||
)
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=max_reserved_system_prompt_tokens_str,
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Use the smallest context window across all models for safety
|
||||
min_context_window = min(llm.config.max_input_tokens for llm in llms)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=min_context_window,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
files = load_all_chat_files(chat_history, db_session)
|
||||
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
|
||||
|
||||
# ── Reserve N assistant message IDs ────────────────────────────
|
||||
reserved_messages = reserve_multi_model_message_ids(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=user_message.id,
|
||||
model_display_names=model_display_names,
|
||||
)
|
||||
|
||||
yield MultiModelMessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
reserved_assistant_message_ids=[m.id for m in reserved_messages],
|
||||
model_names=model_display_names,
|
||||
)
|
||||
|
||||
has_file_reader_tool = any(
|
||||
tool.in_code_tool_id == "file_reader" for tool in all_tools
|
||||
)
|
||||
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=new_msg_req.additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
)
|
||||
simple_chat_history = chat_history_result.simple_messages
|
||||
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] = (
|
||||
chat_history_result.all_injected_file_metadata
|
||||
if has_file_reader_tool
|
||||
else {}
|
||||
)
|
||||
if summarized_file_metadata:
|
||||
for fid, meta in summarized_file_metadata.items():
|
||||
all_injected_file_metadata.setdefault(fid, meta)
|
||||
|
||||
if summary_message is not None:
|
||||
summary_simple = ChatMessageSimple(
|
||||
message=summary_message.message,
|
||||
token_count=summary_message.token_count,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
# ── Stop signal and processing status ──────────────────────────
|
||||
cache = get_cache_backend()
|
||||
reset_cancel_status(chat_session.id, cache)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Release the main session's read transaction before the long stream
|
||||
db_session.commit()
|
||||
|
||||
# ── Parallel model execution ───────────────────────────────────
|
||||
# Each model thread writes tagged packets to this shared queue.
|
||||
# Sentinel _MODEL_DONE signals that a thread finished.
|
||||
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
|
||||
queue.Queue()
|
||||
)
|
||||
|
||||
# Track per-model state containers for completion callbacks
|
||||
state_containers: list[ChatStateContainer] = [
|
||||
ChatStateContainer() for _ in range(n_models)
|
||||
]
|
||||
# Track which models completed successfully (for completion callbacks)
|
||||
model_succeeded: list[bool] = [False] * n_models
|
||||
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier,
|
||||
session_id=str(chat_session.id),
|
||||
)
|
||||
|
||||
def _run_model(model_idx: int) -> None:
|
||||
"""Run a single model in a worker thread.
|
||||
|
||||
Uses _ModelIndexEmitter so packets flow directly to merged_queue
|
||||
in real-time (not batched after completion). This enables true
|
||||
parallel streaming where both models' tokens interleave on the wire.
|
||||
|
||||
DB access: tools may need a session during execution (e.g., search
|
||||
tool). Each thread creates its own session via context manager.
|
||||
"""
|
||||
model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
|
||||
sc = state_containers[model_idx]
|
||||
model_llm = llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each model thread gets its own DB session for tool execution.
|
||||
# The session is scoped to the thread and closed when done.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
# Construct tools per-thread with thread-local DB session
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id_filter=search_params.project_id_filter,
|
||||
persona_id_filter=search_params.persona_id_filter,
|
||||
bypass_acl=False,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
persona, new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session.id,
|
||||
message_id=user_message.id,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=available_files.user_file_ids,
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
model_tools: list[Tool] = []
|
||||
for tool_list in thread_tool_dict.values():
|
||||
model_tools.extend(tool_list)
|
||||
|
||||
# Run the LLM loop — this blocks until the model finishes.
|
||||
# Packets flow to merged_queue in real-time via the emitter.
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
except Exception as e:
|
||||
merged_queue.put((model_idx, e))
|
||||
|
||||
finally:
|
||||
merged_queue.put((model_idx, _MODEL_DONE))
|
||||
|
||||
# Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=n_models,
|
||||
thread_name_prefix="multi-model",
|
||||
)
|
||||
futures = []
|
||||
try:
|
||||
for i in range(n_models):
|
||||
futures.append(executor.submit(_run_model, i))
|
||||
|
||||
# ── Main thread: merge and yield packets ───────────────────
|
||||
models_remaining = n_models
|
||||
while models_remaining > 0:
|
||||
try:
|
||||
model_idx, item = merged_queue.get(timeout=0.3)
|
||||
except queue.Empty:
|
||||
# Check cancellation during idle periods
|
||||
if not check_is_connected():
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
return
|
||||
continue
|
||||
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
continue
|
||||
|
||||
if isinstance(item, Exception):
|
||||
# Yield error as a tagged StreamingError packet
|
||||
error_msg = str(item)
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(type(item), item, item.__traceback__)
|
||||
)
|
||||
# Redact API keys from error messages
|
||||
model_llm = llms[model_idx]
|
||||
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
|
||||
error_msg = error_msg.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
stack_trace = stack_trace.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="MODEL_ERROR",
|
||||
is_retryable=True,
|
||||
details={
|
||||
"model": model_llm.config.model_name,
|
||||
"provider": model_llm.config.model_provider,
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
models_remaining -= 1
|
||||
continue
|
||||
|
||||
if isinstance(item, Packet):
|
||||
# Packet is already tagged with model_index by _ModelIndexEmitter
|
||||
yield item
|
||||
|
||||
# ── Completion: save each successful model's response ──────
|
||||
# Run completion callbacks on the main thread using the main
|
||||
# session. This is safe because all worker threads have exited
|
||||
# by this point (merged_queue fully drained).
|
||||
for i in range(n_models):
|
||||
if not model_succeeded[i]:
|
||||
continue
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_containers[i],
|
||||
is_connected=check_is_connected,
|
||||
db_session=db_session,
|
||||
assistant_message=reserved_messages[i],
|
||||
llm=llms[i],
|
||||
reserved_tokens=reserved_token_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed completion for model {i} "
|
||||
f"({model_display_names[i]})"
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop", stop_reason="complete"),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Ensure all threads are cleaned up regardless of how we exit
|
||||
executor.shutdown(wait=True, cancel_futures=True)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process multi-model chat message.")
|
||||
yield StreamingError(
|
||||
error=str(e),
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed multi-model chat: {e}")
|
||||
stack_trace = traceback.format_exc()
|
||||
yield StreamingError(
|
||||
error=str(e),
|
||||
stack_trace=stack_trace,
|
||||
error_code="MULTI_MODEL_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
finally:
|
||||
try:
|
||||
if cache is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error clearing processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
|
||||
@@ -332,6 +332,10 @@ OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
|
||||
else None
|
||||
)
|
||||
ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH = (
|
||||
os.environ.get("ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -24,11 +24,11 @@ CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
LLM_SOCKET_READ_TIMEOUT = int(
|
||||
os.environ.get("LLM_SOCKET_READ_TIMEOUT") or "60"
|
||||
) # 60 seconds
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
# Weighting factor between vector and keyword Search; 1 for completely vector
|
||||
# search, 0 for keyword. Enforces a valid range of [0, 1]. A supplied value from
|
||||
# the env outside of this range will be clipped to the respective end of the
|
||||
# range. Defaults to 0.5.
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
HYBRID_ALPHA_KEYWORD = max(
|
||||
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
|
||||
)
|
||||
# Weighting factor between Title and Content of documents during search, 1 for completely
|
||||
# Title based. Default heavily favors Content because Title is also included at the top of
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import copy
|
||||
import fnmatch
|
||||
import html
|
||||
import io
|
||||
import os
|
||||
@@ -84,6 +85,44 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
|
||||
|
||||
def _is_site_excluded(site_url: str, excluded_site_patterns: list[str]) -> bool:
|
||||
"""Check if a site URL matches any of the exclusion glob patterns."""
|
||||
for pattern in excluded_site_patterns:
|
||||
if fnmatch.fnmatch(site_url, pattern) or fnmatch.fnmatch(
|
||||
site_url.rstrip("/"), pattern.rstrip("/")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_path_excluded(item_path: str, excluded_path_patterns: list[str]) -> bool:
|
||||
"""Check if a drive item path matches any of the exclusion glob patterns.
|
||||
|
||||
item_path is the relative path within a drive, e.g. "Engineering/API/report.docx".
|
||||
Matches are attempted against the full path and the filename alone so that
|
||||
patterns like "*.tmp" match files at any depth.
|
||||
"""
|
||||
filename = item_path.rsplit("/", 1)[-1] if "/" in item_path else item_path
|
||||
for pattern in excluded_path_patterns:
|
||||
if fnmatch.fnmatch(item_path, pattern) or fnmatch.fnmatch(filename, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str:
|
||||
"""Build the relative path of a drive item from its parentReference.path and name.
|
||||
|
||||
Example: parentReference.path="/drives/abc/root:/Eng/API", name="report.docx"
|
||||
=> "Eng/API/report.docx"
|
||||
"""
|
||||
if parent_reference_path and "root:/" in parent_reference_path:
|
||||
folder = unquote(parent_reference_path.split("root:/", 1)[1])
|
||||
if folder:
|
||||
return f"{folder}/{item_name}"
|
||||
return item_name
|
||||
|
||||
|
||||
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
|
||||
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
|
||||
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
|
||||
@@ -478,6 +517,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
access_token: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
|
||||
if not driveitem.name or not driveitem.id:
|
||||
@@ -610,6 +650,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
drive_item=sdk_item,
|
||||
drive_name=drive_name,
|
||||
add_prefix=True,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
else:
|
||||
external_access = ExternalAccess.empty()
|
||||
@@ -644,6 +685,7 @@ def _convert_sitepage_to_document(
|
||||
graph_client: GraphClient,
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> Document:
|
||||
"""Convert a SharePoint site page to a Document object."""
|
||||
# Extract text content from the site page
|
||||
@@ -773,6 +815,7 @@ def _convert_sitepage_to_document(
|
||||
graph_client=graph_client,
|
||||
site_page=site_page,
|
||||
add_prefix=True,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
else:
|
||||
external_access = ExternalAccess.empty()
|
||||
@@ -803,6 +846,7 @@ def _convert_driveitem_to_slim_document(
|
||||
ctx: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> SlimDocument:
|
||||
if driveitem.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -813,6 +857,7 @@ def _convert_driveitem_to_slim_document(
|
||||
graph_client=graph_client,
|
||||
drive_item=sdk_item,
|
||||
drive_name=drive_name,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
return SlimDocument(
|
||||
@@ -827,6 +872,7 @@ def _convert_sitepage_to_slim_document(
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> SlimDocument:
|
||||
"""Convert a SharePoint site page to a SlimDocument object."""
|
||||
if site_page.get("id") is None:
|
||||
@@ -836,6 +882,7 @@ def _convert_sitepage_to_slim_document(
|
||||
ctx=ctx,
|
||||
graph_client=graph_client,
|
||||
site_page=site_page,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
id = site_page.get("id")
|
||||
if id is None:
|
||||
@@ -855,14 +902,20 @@ class SharepointConnector(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
sites: list[str] = [],
|
||||
excluded_sites: list[str] = [],
|
||||
excluded_paths: list[str] = [],
|
||||
include_site_pages: bool = True,
|
||||
include_site_documents: bool = True,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
authority_host: str = DEFAULT_AUTHORITY_HOST,
|
||||
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
|
||||
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sites = list(sites)
|
||||
self.excluded_sites = [s for p in excluded_sites if (s := p.strip())]
|
||||
self.excluded_paths = [s for p in excluded_paths if (s := p.strip())]
|
||||
self.treat_sharing_link_as_public = treat_sharing_link_as_public
|
||||
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
|
||||
sites
|
||||
)
|
||||
@@ -1233,6 +1286,29 @@ class SharepointConnector(
|
||||
break
|
||||
sites = sites._get_next().execute_query()
|
||||
|
||||
def _is_driveitem_excluded(self, driveitem: DriveItemData) -> bool:
|
||||
"""Check if a drive item should be excluded based on excluded_paths patterns."""
|
||||
if not self.excluded_paths:
|
||||
return False
|
||||
relative_path = _build_item_relative_path(
|
||||
driveitem.parent_reference_path, driveitem.name
|
||||
)
|
||||
return _is_path_excluded(relative_path, self.excluded_paths)
|
||||
|
||||
def _filter_excluded_sites(
|
||||
self, site_descriptors: list[SiteDescriptor]
|
||||
) -> list[SiteDescriptor]:
|
||||
"""Remove sites matching any excluded_sites glob pattern."""
|
||||
if not self.excluded_sites:
|
||||
return site_descriptors
|
||||
result = []
|
||||
for sd in site_descriptors:
|
||||
if _is_site_excluded(sd.url, self.excluded_sites):
|
||||
logger.info(f"Excluding site by denylist: {sd.url}")
|
||||
continue
|
||||
result.append(sd)
|
||||
return result
|
||||
|
||||
def fetch_sites(self) -> list[SiteDescriptor]:
|
||||
sites = self.graph_client.sites.get_all_sites().execute_query()
|
||||
|
||||
@@ -1249,7 +1325,7 @@ class SharepointConnector(
|
||||
for site in self._handle_paginated_sites(sites)
|
||||
if "-my.sharepoint" not in site.web_url
|
||||
]
|
||||
return site_descriptors
|
||||
return self._filter_excluded_sites(site_descriptors)
|
||||
|
||||
def _fetch_site_pages(
|
||||
self,
|
||||
@@ -1690,7 +1766,9 @@ class SharepointConnector(
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
|
||||
# Create a temporary checkpoint for hierarchy node tracking
|
||||
temp_checkpoint = SharepointConnectorCheckpoint(has_more=True)
|
||||
@@ -1710,6 +1788,10 @@ class SharepointConnector(
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
continue
|
||||
|
||||
if drive_web_url:
|
||||
doc_batch.extend(
|
||||
self._yield_drive_hierarchy_node(
|
||||
@@ -1747,6 +1829,7 @@ class SharepointConnector(
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1770,6 +1853,7 @@ class SharepointConnector(
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
@@ -2043,7 +2127,9 @@ class SharepointConnector(
|
||||
and not checkpoint.process_site_pages
|
||||
):
|
||||
logger.info("Initializing SharePoint sites for processing")
|
||||
site_descs = self.site_descriptors or self.fetch_sites()
|
||||
site_descs = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
checkpoint.cached_site_descriptors = deque(site_descs)
|
||||
|
||||
if not checkpoint.cached_site_descriptors:
|
||||
@@ -2264,6 +2350,10 @@ class SharepointConnector(
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
continue
|
||||
|
||||
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
|
||||
logger.debug(
|
||||
f"Skipping duplicate document {driveitem.id} ({driveitem.name})"
|
||||
@@ -2318,6 +2408,7 @@ class SharepointConnector(
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
graph_api_base=self.graph_api_base,
|
||||
access_token=access_token,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
if isinstance(doc_or_failure, Document):
|
||||
@@ -2398,6 +2489,7 @@ class SharepointConnector(
|
||||
include_permissions=include_permissions,
|
||||
# Site pages have the site as their parent
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
|
||||
@@ -17,6 +17,7 @@ def get_sharepoint_external_access(
|
||||
drive_name: str | None = None,
|
||||
site_page: dict[str, Any] | None = None,
|
||||
add_prefix: bool = False,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> ExternalAccess:
|
||||
if drive_item and drive_item.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -34,7 +35,13 @@ def get_sharepoint_external_access(
|
||||
)
|
||||
|
||||
external_access = get_external_access_func(
|
||||
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
|
||||
ctx,
|
||||
graph_client,
|
||||
drive_name,
|
||||
drive_item,
|
||||
site_page,
|
||||
add_prefix,
|
||||
treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
return external_access
|
||||
|
||||
@@ -14,6 +14,10 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces_new import DocumentIndex as NewDocumentIndex
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
@@ -49,7 +53,7 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
def _embed_and_search(
|
||||
def _embed_and_hybrid_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
@@ -81,6 +85,17 @@ def _embed_and_search(
|
||||
return top_chunks
|
||||
|
||||
|
||||
def _keyword_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: NewDocumentIndex,
|
||||
) -> list[InferenceChunk]:
|
||||
return document_index.keyword_retrieval(
|
||||
query=query_request.query,
|
||||
filters=query_request.filters,
|
||||
num_to_retrieve=query_request.limit or NUM_RETURNED_HITS,
|
||||
)
|
||||
|
||||
|
||||
def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
@@ -128,21 +143,38 @@ def search_chunks(
|
||||
)
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
if (
|
||||
query_request.hybrid_alpha is not None
|
||||
and query_request.hybrid_alpha == 0.0
|
||||
and isinstance(document_index, OpenSearchOldDocumentIndex)
|
||||
):
|
||||
# If hybrid alpha is explicitly set to keyword only, do pure keyword
|
||||
# search without generating an embedding. This is currently only
|
||||
# supported with OpenSearchDocumentIndex.
|
||||
opensearch_new_document_index: NewDocumentIndex = document_index._real_index
|
||||
run_queries.append(
|
||||
(
|
||||
lambda: _keyword_search(
|
||||
query_request, opensearch_new_document_index
|
||||
),
|
||||
(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_hybrid_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
logger.debug(
|
||||
f"Hybrid search returned no results for query: {query_request.query}with filters: {query_request.filters}"
|
||||
f"Search returned no results for query: {query_request.query} with filters: {query_request.filters}."
|
||||
)
|
||||
return []
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
@@ -617,79 +617,6 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model user message.
|
||||
|
||||
Validates that the user message is a USER type and that the preferred
|
||||
assistant message is a direct child of that user message.
|
||||
"""
|
||||
user_msg = db_session.query(ChatMessage).get(user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.query(ChatMessage).get(preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -912,8 +839,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -2651,15 +2651,6 @@ class ChatMessage(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# For multi-model turns: the user message points to which assistant response
|
||||
# was selected as the preferred one to continue the conversation with.
|
||||
preferred_response_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id"), nullable=True
|
||||
)
|
||||
|
||||
# The display name of the model that generated this assistant message
|
||||
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# What does this message contain
|
||||
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
@@ -2727,12 +2718,6 @@ class ChatMessage(Base):
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
preferred_response: Mapped["ChatMessage | None"] = relationship(
|
||||
"ChatMessage",
|
||||
foreign_keys=[preferred_response_id],
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
# Chat messages only need to know their immediate tool call children
|
||||
# If there are nested tool calls, they are stored in the tool_call_children relationship.
|
||||
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
|
||||
|
||||
@@ -381,6 +381,47 @@ class HybridCapable(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Runs keyword-only search and returns a list of inference chunks.
|
||||
|
||||
Args:
|
||||
query: User query.
|
||||
filters: Filters for things like permissions, source type, time,
|
||||
etc.
|
||||
num_to_retrieve: Number of highest matching chunks to return.
|
||||
|
||||
Returns:
|
||||
Score-ranked (highest first) list of highest matching chunks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Runs semantic-only search and returns a list of inference chunks.
|
||||
|
||||
Args:
|
||||
query_embedding: Vector representation of the query. Must be of the
|
||||
correct dimensionality for the primary index.
|
||||
filters: Filters for things like permissions, source type, time,
|
||||
etc.
|
||||
num_to_retrieve: Number of highest matching chunks to return.
|
||||
|
||||
Returns:
|
||||
Score-ranked (highest first) list of highest matching chunks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -18,10 +18,13 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
|
||||
from onyx.server.metrics.opensearch_search import observe_opensearch_search
|
||||
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
@@ -256,7 +259,6 @@ class OpenSearchClient(AbstractContextManager):
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
@@ -304,6 +306,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
emit_metrics: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
@@ -315,6 +318,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
self._emit_metrics = emit_metrics
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -834,7 +838,10 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_pipeline_id: str | None,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
|
||||
"""Searches the index.
|
||||
|
||||
@@ -852,6 +859,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
documentation for more information on search request bodies.
|
||||
search_pipeline_id: The ID of the search pipeline to use. If None,
|
||||
the default search pipeline will be used.
|
||||
search_type: Label for Prometheus metrics. Does not affect search
|
||||
behavior.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
@@ -864,21 +873,27 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
result: dict[str, Any]
|
||||
params = {"phase_took": "true"}
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
ctx = self._get_emit_metrics_context_manager(search_type)
|
||||
t0 = time.perf_counter()
|
||||
with ctx:
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
client_duration_s = time.perf_counter() - t0
|
||||
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
if self._emit_metrics:
|
||||
observe_opensearch_search(search_type, client_duration_s, time_took)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
@@ -914,7 +929,11 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
return search_hits
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
In order to take advantage of the performance benefits of only returning
|
||||
@@ -931,6 +950,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
documentation for more information on search request bodies.
|
||||
TODO(andrei): Make this a more deep interface; callers shouldn't
|
||||
need to know to set _source: False for example.
|
||||
search_type: Label for Prometheus metrics. Does not affect search
|
||||
behavior.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
@@ -948,13 +969,19 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
|
||||
params = {"phase_took": "true"}
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
ctx = self._get_emit_metrics_context_manager(search_type)
|
||||
t0 = time.perf_counter()
|
||||
with ctx:
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
client_duration_s = time.perf_counter() - t0
|
||||
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
if self._emit_metrics:
|
||||
observe_opensearch_search(search_type, client_duration_s, time_took)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
@@ -1071,6 +1098,20 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
if raise_on_timeout:
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
def _get_emit_metrics_context_manager(
|
||||
self, search_type: OpenSearchSearchType
|
||||
) -> AbstractContextManager[None]:
|
||||
"""
|
||||
Returns a context manager that tracks in-flight OpenSearch searches via
|
||||
a Gauge if emit_metrics is True, otherwise returns a null context
|
||||
manager.
|
||||
"""
|
||||
return (
|
||||
track_opensearch_search_in_progress(search_type)
|
||||
if self._emit_metrics
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
|
||||
def wait_for_opensearch_with_timeout(
|
||||
wait_interval_s: int = 5,
|
||||
|
||||
@@ -53,6 +53,18 @@ DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
|
||||
class OpenSearchSearchType(str, Enum):
|
||||
"""Search type label used for Prometheus metrics."""
|
||||
|
||||
HYBRID = "hybrid"
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HybridSearchSubqueryConfiguration(Enum):
|
||||
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
|
||||
# Current default.
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
@@ -900,6 +901,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
@@ -923,6 +925,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -948,6 +952,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
search_type=OpenSearchSearchType.HYBRID,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
@@ -970,6 +975,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -989,6 +996,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.KEYWORD,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
@@ -1009,6 +1017,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -1028,6 +1038,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.SEMANTIC,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
@@ -1059,6 +1070,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.RANDOM,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -3,6 +3,8 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
@@ -48,13 +50,21 @@ from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
# of the weights should sum to 1.
|
||||
# See https://docs.opensearch.org/latest/query-dsl/term/terms/.
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY = 65_536
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
TermsQuery: TypeAlias = dict[str, dict[str, list[_T]]]
|
||||
TermQuery: TypeAlias = dict[str, dict[str, dict[str, _T]]]
|
||||
|
||||
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
# of the weights should sum to 1.
|
||||
def _get_hybrid_search_normalization_weights() -> list[float]:
|
||||
if (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
@@ -316,6 +326,9 @@ class DocumentQuery:
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
@@ -419,6 +432,9 @@ class DocumentQuery:
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
@@ -498,6 +514,9 @@ class DocumentQuery:
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
@@ -763,8 +782,9 @@ class DocumentQuery:
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
# The title fields are strongly discounted as they are included in the content.
|
||||
# It just acts as a minor boost
|
||||
# The title fields are strongly discounted as
|
||||
# they are included in the content. This just
|
||||
# acts as a minor boost.
|
||||
"boost": 0.1,
|
||||
}
|
||||
}
|
||||
@@ -779,6 +799,9 @@ class DocumentQuery:
|
||||
}
|
||||
},
|
||||
{
|
||||
# Analyzes the query and returns results which match any
|
||||
# of the query's terms. More matches result in higher
|
||||
# scores.
|
||||
"match": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
@@ -788,18 +811,21 @@ class DocumentQuery:
|
||||
}
|
||||
},
|
||||
{
|
||||
# Matches an exact phrase in a specified order.
|
||||
"match_phrase": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# The number of words permitted between words of
|
||||
# a query phrase and still result in a match.
|
||||
"slop": 1,
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
},
|
||||
],
|
||||
# Ensure at least one term from the query is present in the
|
||||
# document. This defaults to 1, unless a filter or must clause
|
||||
# is supplied, in which case it defaults to 0.
|
||||
# Ensures at least one match subquery from the query is present
|
||||
# in the document. This defaults to 1, unless a filter or must
|
||||
# clause is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
@@ -833,7 +859,14 @@ class DocumentQuery:
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/.
|
||||
|
||||
TODO(ENG-3874): The terms queries returned by this function can be made
|
||||
more performant for large cardinality sets by sorting the values by
|
||||
their UTF-8 byte order.
|
||||
|
||||
TODO(ENG-3875): This function can take even better advantage of filter
|
||||
caching by grouping "static" filters together into one sub-clause.
|
||||
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
@@ -878,6 +911,14 @@ class DocumentQuery:
|
||||
the assistant. Matches chunks where ancestor_hierarchy_node_ids
|
||||
contains any of these values.
|
||||
|
||||
Raises:
|
||||
ValueError: document_id and attached_document_ids were supplied
|
||||
together. This is not allowed because they operate on the same
|
||||
schema field, and it does not semantically make sense to use
|
||||
them together.
|
||||
ValueError: Too many of one of the collection arguments was
|
||||
supplied.
|
||||
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
@@ -885,61 +926,156 @@ class DocumentQuery:
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, dict[str, list[TermQuery[bool] | TermsQuery[str]] | int]]:
|
||||
"""Returns a filter for the access control list.
|
||||
|
||||
Since this returns an isolated bool should clause, it can be cached
|
||||
in OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
access_control_list: The access control list to restrict
|
||||
documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of access control list entries is greater
|
||||
than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
|
||||
Returns:
|
||||
A filter for the access control list.
|
||||
"""
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
acl_visibility_filter: dict[str, dict[str, Any]] = {
|
||||
"bool": {
|
||||
"should": [{"term": {PUBLIC_FIELD_NAME: {"value": True}}}],
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
if access_control_list:
|
||||
if len(access_control_list) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many access control list entries: {len(access_control_list)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause
|
||||
# because Lucene will optimize the filtering for large sets of
|
||||
# terms. Small sets of terms are not expected to perform any
|
||||
# differently than individual term clauses.
|
||||
acl_subclause: TermsQuery[str] = {
|
||||
"terms": {ACCESS_CONTROL_LIST_FIELD_NAME: list(access_control_list)}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
) -> TermsQuery[str]:
|
||||
"""Returns a filter for the source types.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
source_types: The source types to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of source types is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the source types.
|
||||
"""
|
||||
if not source_types:
|
||||
raise ValueError(
|
||||
"source_types cannot be empty if trying to create a source type filter."
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
if len(source_types) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many source types: {len(source_types)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
return tag_filter
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {
|
||||
"terms": {
|
||||
SOURCE_TYPE_FIELD_NAME: [
|
||||
source_type.value for source_type in source_types
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
def _get_tag_filter(tags: list[Tag]) -> TermsQuery[str]:
|
||||
"""Returns a filter for the tags.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
tags: The tags to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of tags is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the tags.
|
||||
"""
|
||||
if not tags:
|
||||
raise ValueError(
|
||||
"tags cannot be empty if trying to create a tag filter."
|
||||
)
|
||||
return document_set_filter
|
||||
if len(tags) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many tags: {len(tags)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str_list = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in tags
|
||||
]
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {METADATA_LIST_FIELD_NAME: tag_str_list}}
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
def _get_document_set_filter(document_sets: list[str]) -> TermsQuery[str]:
|
||||
"""Returns a filter for the document sets.
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
document_sets: The document sets to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of document sets is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the document sets.
|
||||
"""
|
||||
if not document_sets:
|
||||
raise ValueError(
|
||||
"document_sets cannot be empty if trying to create a document set filter."
|
||||
)
|
||||
if len(document_sets) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many document sets: {len(document_sets)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {DOCUMENT_SETS_FIELD_NAME: list(document_sets)}}
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> TermQuery[int]:
|
||||
return {"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> TermQuery[int]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
@@ -947,7 +1083,9 @@ class DocumentQuery:
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
@@ -982,25 +1120,77 @@ class DocumentQuery:
|
||||
|
||||
def _get_attached_document_id_filter(
|
||||
doc_ids: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""Filter for documents explicitly attached to an assistant."""
|
||||
# Logical OR operator on its elements.
|
||||
doc_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for doc_id in doc_ids:
|
||||
doc_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": doc_id}}}
|
||||
) -> TermsQuery[str]:
|
||||
"""
|
||||
Returns a filter for documents explicitly attached to an assistant.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
doc_ids: The document IDs to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of document IDs is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the document IDs.
|
||||
"""
|
||||
if not doc_ids:
|
||||
raise ValueError(
|
||||
"doc_ids cannot be empty if trying to create a document ID filter."
|
||||
)
|
||||
return doc_id_filter
|
||||
if len(doc_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many document IDs: {len(doc_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {DOCUMENT_ID_FIELD_NAME: list(doc_ids)}}
|
||||
|
||||
def _get_hierarchy_node_filter(
|
||||
node_ids: list[int],
|
||||
) -> dict[str, Any]:
|
||||
"""Filter for chunks whose ancestors include any of the given hierarchy nodes.
|
||||
|
||||
Uses a terms query to check if ancestor_hierarchy_node_ids contains
|
||||
any of the specified node IDs.
|
||||
) -> TermsQuery[int]:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
Returns a filter for chunks whose ancestors include any of the given
|
||||
hierarchy nodes.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
node_ids: The hierarchy node IDs to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of hierarchy node IDs is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the hierarchy node IDs.
|
||||
"""
|
||||
if not node_ids:
|
||||
raise ValueError(
|
||||
"node_ids cannot be empty if trying to create a hierarchy node ID filter."
|
||||
)
|
||||
if len(node_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many hierarchy node IDs: {len(node_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: list(node_ids)}}
|
||||
|
||||
if document_id is not None and attached_document_ids is not None:
|
||||
raise ValueError(
|
||||
"document_id and attached_document_ids cannot be used together."
|
||||
)
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
@@ -1045,6 +1235,9 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
# Since this returns an isolated bool should clause, it can be
|
||||
# cached in OpenSearch independently of other clauses in
|
||||
# _get_search_filters.
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
|
||||
@@ -610,6 +610,22 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
return cleanup_content_for_chunks(query_vespa(params))
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
|
||||
@@ -11,7 +11,6 @@ class LLMOverride(BaseModel):
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
display_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@@ -690,9 +690,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@dotenvx/dotenvx/node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
@@ -9537,9 +9537,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
|
||||
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
@@ -11118,9 +11118,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/tinyglobby/node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
106
backend/onyx/server/metrics/opensearch_search.py
Normal file
106
backend/onyx/server/metrics/opensearch_search.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Prometheus metrics for OpenSearch search latency and throughput.
|
||||
|
||||
Tracks client-side round-trip latency, server-side execution time (from
|
||||
OpenSearch's ``took`` field), total search count, and in-flight concurrency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SEARCH_LATENCY_BUCKETS = (
|
||||
0.005,
|
||||
0.01,
|
||||
0.025,
|
||||
0.05,
|
||||
0.1,
|
||||
0.25,
|
||||
0.5,
|
||||
1.0,
|
||||
2.5,
|
||||
5.0,
|
||||
10.0,
|
||||
25.0,
|
||||
)
|
||||
|
||||
_client_duration = Histogram(
|
||||
"onyx_opensearch_search_client_duration_seconds",
|
||||
"Client-side end-to-end latency of OpenSearch search calls",
|
||||
["search_type"],
|
||||
buckets=_SEARCH_LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
_server_duration = Histogram(
|
||||
"onyx_opensearch_search_server_duration_seconds",
|
||||
"Server-side execution time reported by OpenSearch (took field)",
|
||||
["search_type"],
|
||||
buckets=_SEARCH_LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
_search_total = Counter(
|
||||
"onyx_opensearch_search_total",
|
||||
"Total number of search requests sent to OpenSearch",
|
||||
["search_type"],
|
||||
)
|
||||
|
||||
_searches_in_progress = Gauge(
|
||||
"onyx_opensearch_searches_in_progress",
|
||||
"Number of OpenSearch searches currently in-flight",
|
||||
["search_type"],
|
||||
)
|
||||
|
||||
|
||||
def observe_opensearch_search(
|
||||
search_type: OpenSearchSearchType,
|
||||
client_duration_s: float,
|
||||
server_took_ms: int | None,
|
||||
) -> None:
|
||||
"""Records latency and throughput metrics for a completed OpenSearch search.
|
||||
|
||||
Args:
|
||||
search_type: The type of search.
|
||||
client_duration_s: Wall-clock duration measured on the client side, in
|
||||
seconds.
|
||||
server_took_ms: The ``took`` value from the OpenSearch response, in
|
||||
milliseconds. May be ``None`` if the response did not include it.
|
||||
"""
|
||||
try:
|
||||
label = search_type.value
|
||||
_search_total.labels(search_type=label).inc()
|
||||
_client_duration.labels(search_type=label).observe(client_duration_s)
|
||||
if server_took_ms is not None:
|
||||
_server_duration.labels(search_type=label).observe(server_took_ms / 1000.0)
|
||||
except Exception:
|
||||
logger.warning("Failed to record OpenSearch search metrics.", exc_info=True)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def track_opensearch_search_in_progress(
|
||||
search_type: OpenSearchSearchType,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager that tracks in-flight OpenSearch searches via a Gauge."""
|
||||
incremented = False
|
||||
label = search_type.value
|
||||
try:
|
||||
_searches_in_progress.labels(search_type=label).inc()
|
||||
incremented = True
|
||||
except Exception:
|
||||
logger.warning("Failed to increment in-progress search gauge.", exc_info=True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if incremented:
|
||||
try:
|
||||
_searches_in_progress.labels(search_type=label).dec()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to decrement in-progress search gauge.", exc_info=True
|
||||
)
|
||||
@@ -29,7 +29,6 @@ from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -83,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -573,38 +570,6 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in run_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=chat_message_req.llm_overrides, # type: ignore[arg-type]
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -695,26 +660,6 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
_user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
try:
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -41,16 +41,6 @@ class MessageResponseIDInfo(BaseModel):
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class MultiModelMessageResponseIDInfo(BaseModel):
|
||||
"""Sent at the start of a multi-model streaming response.
|
||||
Contains the user message ID and the reserved assistant message IDs
|
||||
for each model being run in parallel."""
|
||||
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_ids: list[int]
|
||||
model_names: list[str]
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
source: DocumentSource
|
||||
|
||||
@@ -96,9 +86,6 @@ class SendMessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
llm_override: LLMOverride | None = None
|
||||
# For multi-model mode: up to 3 LLM overrides to run in parallel.
|
||||
# When provided with >1 entry, triggers multi-model streaming.
|
||||
llm_overrides: list[LLMOverride] | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
|
||||
@@ -224,8 +211,6 @@ class ChatMessageDetail(BaseModel):
|
||||
error: str | None = None
|
||||
current_feedback: str | None = None # "like" | "dislike" | null
|
||||
processing_duration_seconds: float | None = None
|
||||
preferred_response_id: int | None = None
|
||||
model_display_name: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
@@ -233,11 +218,6 @@ class ChatMessageDetail(BaseModel):
|
||||
return initial_dict
|
||||
|
||||
|
||||
class SetPreferredResponseRequest(BaseModel):
|
||||
user_message_id: int
|
||||
preferred_response_id: int
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
|
||||
@@ -8,5 +8,3 @@ class Placement(BaseModel):
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -549,7 +549,7 @@ mypy-extensions==1.0.0
|
||||
# typing-inspect
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
nltk==3.9.3
|
||||
nltk==3.9.4
|
||||
# via unstructured
|
||||
numpy==2.4.1
|
||||
# via
|
||||
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.9.1
|
||||
pypdf==6.9.2
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -861,7 +861,7 @@ regex==2025.11.3
|
||||
# dateparser
|
||||
# nltk
|
||||
# tiktoken
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# atlassian-python-api
|
||||
# braintrust
|
||||
|
||||
@@ -410,7 +410,7 @@ release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
|
||||
@@ -244,7 +244,7 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
|
||||
@@ -338,7 +338,7 @@ regex==2025.11.3
|
||||
# via
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
|
||||
@@ -10,6 +10,9 @@ from typing import Any
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -17,6 +20,12 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
ATTACHED_DOCUMENT_ID = "https://docs.google.com/document/d/test-doc-id"
|
||||
HIERARCHY_NODE_ID = 42
|
||||
PERSONA_ID = 7
|
||||
KNOWLEDGE_FILTER_SCHEMA_FIELDS = {
|
||||
DOCUMENT_ID_FIELD_NAME,
|
||||
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME,
|
||||
DOCUMENT_SETS_FIELD_NAME,
|
||||
PERSONAS_FIELD_NAME,
|
||||
}
|
||||
|
||||
|
||||
def _get_search_filters(
|
||||
@@ -62,7 +71,26 @@ class TestAssistantKnowledgeFilter:
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
if (
|
||||
clause["bool"].get("minimum_should_match") == 1
|
||||
and len(clause["bool"]["should"]) > 0
|
||||
and (
|
||||
(
|
||||
clause["bool"]["should"][0].get("term", {}).keys()
|
||||
and list(
|
||||
clause["bool"]["should"][0].get("term", {}).keys()
|
||||
)[0]
|
||||
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
|
||||
)
|
||||
or (
|
||||
clause["bool"]["should"][0].get("terms", {}).keys()
|
||||
and list(
|
||||
clause["bool"]["should"][0].get("terms", {}).keys()
|
||||
)[0]
|
||||
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
|
||||
)
|
||||
)
|
||||
):
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
@@ -96,7 +124,26 @@ class TestAssistantKnowledgeFilter:
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
if (
|
||||
clause["bool"].get("minimum_should_match") == 1
|
||||
and len(clause["bool"]["should"]) > 0
|
||||
and (
|
||||
(
|
||||
clause["bool"]["should"][0].get("term", {}).keys()
|
||||
and list(
|
||||
clause["bool"]["should"][0].get("term", {}).keys()
|
||||
)[0]
|
||||
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
|
||||
)
|
||||
or (
|
||||
clause["bool"]["should"][0].get("terms", {}).keys()
|
||||
and list(
|
||||
clause["bool"]["should"][0].get("terms", {}).keys()
|
||||
)[0]
|
||||
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
|
||||
)
|
||||
)
|
||||
):
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
@@ -127,7 +174,26 @@ class TestAssistantKnowledgeFilter:
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
if (
|
||||
clause["bool"].get("minimum_should_match") == 1
|
||||
and len(clause["bool"]["should"]) > 0
|
||||
and (
|
||||
(
|
||||
clause["bool"]["should"][0].get("term", {}).keys()
|
||||
and list(
|
||||
clause["bool"]["should"][0].get("term", {}).keys()
|
||||
)[0]
|
||||
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
|
||||
)
|
||||
or (
|
||||
clause["bool"]["should"][0].get("terms", {}).keys()
|
||||
and list(
|
||||
clause["bool"]["should"][0].get("terms", {}).keys()
|
||||
)[0]
|
||||
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
|
||||
)
|
||||
)
|
||||
):
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
|
||||
@@ -974,7 +974,7 @@ class TestOpenSearchClient:
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
# Index documents with different public/hidden, ACL, and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
@@ -997,7 +997,7 @@ class TestOpenSearchClient:
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_emails=["user-a@example.com", "user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
@@ -1044,7 +1044,10 @@ class TestOpenSearchClient:
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
access_control_list=[
|
||||
prefix_user_email("user-a@example.com"),
|
||||
prefix_user_email("user-c@example.com"),
|
||||
],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
@@ -1661,7 +1664,7 @@ class TestOpenSearchClient:
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
# Index documents with different public/hidden, ACL, and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
@@ -1684,7 +1687,7 @@ class TestOpenSearchClient:
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_emails=["user-a@example.com", "user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
@@ -1746,7 +1749,10 @@ class TestOpenSearchClient:
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
access_control_list=[
|
||||
prefix_user_email("user-a@example.com"),
|
||||
prefix_user_email("user-c@example.com"),
|
||||
],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
@@ -1805,7 +1811,7 @@ class TestOpenSearchClient:
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
# Index documents with different public/hidden, ACL, and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
@@ -1831,7 +1837,7 @@ class TestOpenSearchClient:
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_emails=["user-a@example.com", "user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
@@ -1879,7 +1885,10 @@ class TestOpenSearchClient:
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
access_control_list=[
|
||||
prefix_user_email("user-a@example.com"),
|
||||
prefix_user_email("user-c@example.com"),
|
||||
],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
|
||||
@@ -80,6 +80,7 @@ def sharepoint_test_env_setup() -> Generator[SharepointTestEnvSetupTuple]:
|
||||
source=DocumentSource.SHAREPOINT,
|
||||
connector_specific_config={
|
||||
"sites": sharepoint_sites.split(","),
|
||||
"treat_sharing_link_as_public": True,
|
||||
},
|
||||
access_type=AccessType.SYNC, # Enable permission sync
|
||||
user_performing_action=admin_user,
|
||||
|
||||
@@ -8,6 +8,9 @@ import pytest
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_enumerate_ad_groups_paginated,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_is_public_item,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_iter_graph_collection,
|
||||
)
|
||||
@@ -334,3 +337,143 @@ def test_site_page_url_not_duplicated(
|
||||
ctx.web.get_file_by_server_relative_url.assert_called_once_with(
|
||||
expected_relative_url
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_public_item – sharing link visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_permission(scope: str | None) -> MagicMock:
|
||||
perm = MagicMock()
|
||||
if scope is None:
|
||||
perm.link = None
|
||||
else:
|
||||
perm.link = MagicMock()
|
||||
perm.link.scope = scope
|
||||
return perm
|
||||
|
||||
|
||||
def _make_drive_item_with_permissions(
|
||||
permissions: list[MagicMock],
|
||||
) -> MagicMock:
|
||||
drive_item = MagicMock()
|
||||
drive_item.id = "item-123"
|
||||
drive_item.permissions.get_all.return_value = permissions
|
||||
return drive_item
|
||||
|
||||
|
||||
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
|
||||
def test_is_public_item_anonymous_link_when_enabled(
|
||||
_mock_sleep: MagicMock,
|
||||
) -> None:
|
||||
drive_item = _make_drive_item_with_permissions([_make_permission("anonymous")])
|
||||
assert _is_public_item(drive_item, treat_sharing_link_as_public=True) is True
|
||||
|
||||
|
||||
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
|
||||
def test_is_public_item_org_link_when_enabled(
|
||||
_mock_sleep: MagicMock,
|
||||
) -> None:
|
||||
drive_item = _make_drive_item_with_permissions([_make_permission("organization")])
|
||||
assert _is_public_item(drive_item, treat_sharing_link_as_public=True) is True
|
||||
|
||||
|
||||
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
|
||||
def test_is_public_item_anonymous_link_when_disabled(
|
||||
_mock_sleep: MagicMock,
|
||||
) -> None:
|
||||
"""When the flag is off, anonymous links do NOT make the item public."""
|
||||
drive_item = _make_drive_item_with_permissions([_make_permission("anonymous")])
|
||||
assert _is_public_item(drive_item, treat_sharing_link_as_public=False) is False
|
||||
|
||||
|
||||
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
|
||||
def test_is_public_item_org_link_when_disabled(
|
||||
_mock_sleep: MagicMock,
|
||||
) -> None:
|
||||
"""When the flag is off, org links do NOT make the item public."""
|
||||
drive_item = _make_drive_item_with_permissions([_make_permission("organization")])
|
||||
assert _is_public_item(drive_item, treat_sharing_link_as_public=False) is False
|
||||
|
||||
|
||||
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
|
||||
def test_is_public_item_no_sharing_links(
|
||||
_mock_sleep: MagicMock,
|
||||
) -> None:
|
||||
"""User-level permissions only — not public even when flag is on."""
|
||||
drive_item = _make_drive_item_with_permissions([_make_permission(None)])
|
||||
assert _is_public_item(drive_item, treat_sharing_link_as_public=True) is False
|
||||
|
||||
|
||||
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
|
||||
def test_is_public_item_default_is_false(
|
||||
_mock_sleep: MagicMock,
|
||||
) -> None:
|
||||
"""Default value of the flag is False, so sharing links are ignored."""
|
||||
drive_item = _make_drive_item_with_permissions([_make_permission("anonymous")])
|
||||
assert _is_public_item(drive_item) is False
|
||||
|
||||
|
||||
def test_is_public_item_skips_api_call_when_disabled() -> None:
|
||||
"""When the flag is off, the permissions API is never called."""
|
||||
drive_item = MagicMock()
|
||||
_is_public_item(drive_item, treat_sharing_link_as_public=False)
|
||||
drive_item.permissions.get_all.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_external_access_from_sharepoint – sharing link integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{MODULE}._is_public_item", return_value=True)
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_drive_item_public_when_sharing_link_enabled(
|
||||
_mock_sleep: MagicMock,
|
||||
_mock_is_public: MagicMock,
|
||||
) -> None:
|
||||
"""With treat_sharing_link_as_public=True, a public item returns is_public=True
|
||||
and skips role-assignment resolution entirely."""
|
||||
drive_item = MagicMock()
|
||||
|
||||
result = get_external_access_from_sharepoint(
|
||||
client_context=MagicMock(),
|
||||
graph_client=MagicMock(),
|
||||
drive_name="Documents",
|
||||
drive_item=drive_item,
|
||||
site_page=None,
|
||||
treat_sharing_link_as_public=True,
|
||||
)
|
||||
|
||||
assert result.is_public is True
|
||||
assert result.external_user_emails == set()
|
||||
assert result.external_user_group_ids == set()
|
||||
|
||||
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
@patch(f"{MODULE}._is_public_item", return_value=False)
|
||||
def test_drive_item_falls_through_when_sharing_link_disabled(
|
||||
_mock_is_public: MagicMock,
|
||||
mock_sleep: MagicMock, # noqa: ARG001
|
||||
mock_recursive: MagicMock,
|
||||
) -> None:
|
||||
"""With treat_sharing_link_as_public=False, the function falls through to
|
||||
role-assignment-based permission resolution."""
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={"SiteMembers_abc": {"alice@contoso.com"}},
|
||||
found_public_group=False,
|
||||
)
|
||||
|
||||
result = get_external_access_from_sharepoint(
|
||||
client_context=MagicMock(),
|
||||
graph_client=MagicMock(),
|
||||
drive_name="Documents",
|
||||
drive_item=MagicMock(),
|
||||
site_page=None,
|
||||
treat_sharing_link_as_public=False,
|
||||
)
|
||||
|
||||
assert result.is_public is False
|
||||
assert len(result.external_user_group_ids) > 0
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in run_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = run_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
try:
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.query.return_value.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.query.return_value.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.query.return_value.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
@@ -1,134 +0,0 @@
|
||||
"""Unit tests for multi-model answer generation types.
|
||||
|
||||
Tests cover:
|
||||
- Placement.model_index serialization
|
||||
- MultiModelMessageResponseIDInfo round-trip
|
||||
- SendMessageRequest.llm_overrides backward compatibility
|
||||
- ChatMessageDetail new fields
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
|
||||
class TestPlacementModelIndex:
|
||||
def test_default_none(self) -> None:
|
||||
p = Placement(turn_index=0)
|
||||
assert p.model_index is None
|
||||
|
||||
def test_set_value(self) -> None:
|
||||
p = Placement(turn_index=0, model_index=2)
|
||||
assert p.model_index == 2
|
||||
|
||||
def test_serializes(self) -> None:
|
||||
p = Placement(turn_index=0, tab_index=1, model_index=1)
|
||||
d = p.model_dump()
|
||||
assert d["model_index"] == 1
|
||||
|
||||
def test_none_excluded_when_default(self) -> None:
|
||||
p = Placement(turn_index=0)
|
||||
d = p.model_dump()
|
||||
assert d["model_index"] is None
|
||||
|
||||
|
||||
class TestMultiModelMessageResponseIDInfo:
|
||||
def test_round_trip(self) -> None:
|
||||
info = MultiModelMessageResponseIDInfo(
|
||||
user_message_id=42,
|
||||
reserved_assistant_message_ids=[43, 44, 45],
|
||||
model_names=["gpt-4", "claude-opus", "gemini-pro"],
|
||||
)
|
||||
d = info.model_dump()
|
||||
restored = MultiModelMessageResponseIDInfo(**d)
|
||||
assert restored.user_message_id == 42
|
||||
assert restored.reserved_assistant_message_ids == [43, 44, 45]
|
||||
assert restored.model_names == ["gpt-4", "claude-opus", "gemini-pro"]
|
||||
|
||||
def test_null_user_message_id(self) -> None:
|
||||
info = MultiModelMessageResponseIDInfo(
|
||||
user_message_id=None,
|
||||
reserved_assistant_message_ids=[1, 2],
|
||||
model_names=["a", "b"],
|
||||
)
|
||||
assert info.user_message_id is None
|
||||
|
||||
|
||||
class TestSendMessageRequestOverrides:
|
||||
def test_llm_overrides_default_none(self) -> None:
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
)
|
||||
assert req.llm_overrides is None
|
||||
|
||||
def test_llm_overrides_accepts_list(self) -> None:
|
||||
overrides = [
|
||||
LLMOverride(model_provider="openai", model_version="gpt-4"),
|
||||
LLMOverride(model_provider="anthropic", model_version="claude-opus"),
|
||||
]
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
llm_overrides=overrides,
|
||||
)
|
||||
assert req.llm_overrides is not None
|
||||
assert len(req.llm_overrides) == 2
|
||||
|
||||
def test_backward_compat_single_override(self) -> None:
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
|
||||
)
|
||||
assert req.llm_override is not None
|
||||
assert req.llm_overrides is None
|
||||
|
||||
|
||||
class TestChatMessageDetailMultiModel:
|
||||
def test_defaults_none(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
)
|
||||
assert detail.preferred_response_id is None
|
||||
assert detail.model_display_name is None
|
||||
|
||||
def test_set_values(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.USER,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
preferred_response_id=42,
|
||||
model_display_name="GPT-4",
|
||||
)
|
||||
assert detail.preferred_response_id == 42
|
||||
assert detail.model_display_name == "GPT-4"
|
||||
|
||||
def test_serializes(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
model_display_name="Claude Opus",
|
||||
)
|
||||
d = detail.model_dump()
|
||||
assert d["model_display_name"] == "Claude Opus"
|
||||
assert d["preferred_response_id"] is None
|
||||
@@ -145,6 +145,7 @@ def _mock_convert(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
include_permissions: bool = False, # noqa: ARG001
|
||||
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
|
||||
access_token: str | None = None, # noqa: ARG001
|
||||
treat_sharing_link_as_public: bool = False, # noqa: ARG001
|
||||
) -> Document:
|
||||
return _make_document(driveitem)
|
||||
|
||||
|
||||
215
backend/tests/unit/onyx/connectors/sharepoint/test_denylist.py
Normal file
215
backend/tests/unit/onyx/connectors/sharepoint/test_denylist.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.sharepoint.connector import _build_item_relative_path
|
||||
from onyx.connectors.sharepoint.connector import _is_path_excluded
|
||||
from onyx.connectors.sharepoint.connector import _is_site_excluded
|
||||
from onyx.connectors.sharepoint.connector import DriveItemData
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
|
||||
class TestIsSiteExcluded:
|
||||
def test_exact_match(self) -> None:
|
||||
assert _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/archive",
|
||||
["https://contoso.sharepoint.com/sites/archive"],
|
||||
)
|
||||
|
||||
def test_trailing_slash_mismatch(self) -> None:
|
||||
assert _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/archive/",
|
||||
["https://contoso.sharepoint.com/sites/archive"],
|
||||
)
|
||||
|
||||
def test_glob_wildcard(self) -> None:
|
||||
assert _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/archive-2024",
|
||||
["*/sites/archive-*"],
|
||||
)
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert not _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/engineering",
|
||||
["https://contoso.sharepoint.com/sites/archive"],
|
||||
)
|
||||
|
||||
def test_empty_patterns(self) -> None:
|
||||
assert not _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/engineering",
|
||||
[],
|
||||
)
|
||||
|
||||
def test_multiple_patterns(self) -> None:
|
||||
patterns = [
|
||||
"*/sites/archive-*",
|
||||
"*/sites/hr-confidential",
|
||||
]
|
||||
assert _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/hr-confidential",
|
||||
patterns,
|
||||
)
|
||||
assert not _is_site_excluded(
|
||||
"https://contoso.sharepoint.com/sites/engineering",
|
||||
patterns,
|
||||
)
|
||||
|
||||
|
||||
class TestIsPathExcluded:
|
||||
def test_filename_glob(self) -> None:
|
||||
assert _is_path_excluded("Engineering/report.tmp", ["*.tmp"])
|
||||
|
||||
def test_filename_only(self) -> None:
|
||||
assert _is_path_excluded("report.tmp", ["*.tmp"])
|
||||
|
||||
def test_office_lock_files(self) -> None:
|
||||
assert _is_path_excluded("Docs/~$document.docx", ["~$*"])
|
||||
|
||||
def test_folder_glob(self) -> None:
|
||||
assert _is_path_excluded("Archive/old/report.docx", ["Archive/*"])
|
||||
|
||||
def test_nested_folder_glob(self) -> None:
|
||||
assert _is_path_excluded("Projects/Archive/report.docx", ["*/Archive/*"])
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert not _is_path_excluded("Engineering/report.docx", ["*.tmp"])
|
||||
|
||||
def test_empty_patterns(self) -> None:
|
||||
assert not _is_path_excluded("anything.docx", [])
|
||||
|
||||
def test_multiple_patterns(self) -> None:
|
||||
patterns = ["*.tmp", "~$*", "Archive/*"]
|
||||
assert _is_path_excluded("test.tmp", patterns)
|
||||
assert _is_path_excluded("~$doc.docx", patterns)
|
||||
assert _is_path_excluded("Archive/old.pdf", patterns)
|
||||
assert not _is_path_excluded("Engineering/report.docx", patterns)
|
||||
|
||||
|
||||
class TestBuildItemRelativePath:
|
||||
def test_with_folder(self) -> None:
|
||||
assert (
|
||||
_build_item_relative_path(
|
||||
"/drives/abc/root:/Engineering/API", "report.docx"
|
||||
)
|
||||
== "Engineering/API/report.docx"
|
||||
)
|
||||
|
||||
def test_root_level(self) -> None:
|
||||
assert (
|
||||
_build_item_relative_path("/drives/abc/root:", "report.docx")
|
||||
== "report.docx"
|
||||
)
|
||||
|
||||
def test_none_parent(self) -> None:
|
||||
assert _build_item_relative_path(None, "report.docx") == "report.docx"
|
||||
|
||||
def test_percent_encoded_folder(self) -> None:
|
||||
assert (
|
||||
_build_item_relative_path("/drives/abc/root:/My%20Documents", "report.docx")
|
||||
== "My Documents/report.docx"
|
||||
)
|
||||
|
||||
def test_no_root_marker(self) -> None:
|
||||
assert _build_item_relative_path("/drives/abc", "report.docx") == "report.docx"
|
||||
|
||||
|
||||
class TestFilterExcludedSites:
|
||||
def test_filters_matching_sites(self) -> None:
|
||||
connector = SharepointConnector(
|
||||
excluded_sites=["*/sites/archive"],
|
||||
)
|
||||
descriptors = [
|
||||
SiteDescriptor(
|
||||
url="https://t.sharepoint.com/sites/archive",
|
||||
drive_name=None,
|
||||
folder_path=None,
|
||||
),
|
||||
SiteDescriptor(
|
||||
url="https://t.sharepoint.com/sites/engineering",
|
||||
drive_name=None,
|
||||
folder_path=None,
|
||||
),
|
||||
]
|
||||
result = connector._filter_excluded_sites(descriptors)
|
||||
assert len(result) == 1
|
||||
assert result[0].url == "https://t.sharepoint.com/sites/engineering"
|
||||
|
||||
def test_empty_excluded_returns_all(self) -> None:
|
||||
connector = SharepointConnector(excluded_sites=[])
|
||||
descriptors = [
|
||||
SiteDescriptor(
|
||||
url="https://t.sharepoint.com/sites/a",
|
||||
drive_name=None,
|
||||
folder_path=None,
|
||||
),
|
||||
SiteDescriptor(
|
||||
url="https://t.sharepoint.com/sites/b",
|
||||
drive_name=None,
|
||||
folder_path=None,
|
||||
),
|
||||
]
|
||||
result = connector._filter_excluded_sites(descriptors)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestIsDriveitemExcluded:
|
||||
def test_excluded_by_extension(self) -> None:
|
||||
connector = SharepointConnector(excluded_paths=["*.tmp"])
|
||||
item = DriveItemData(
|
||||
id="1",
|
||||
name="file.tmp",
|
||||
web_url="https://example.com/file.tmp",
|
||||
parent_reference_path="/drives/abc/root:/Docs",
|
||||
)
|
||||
assert connector._is_driveitem_excluded(item)
|
||||
|
||||
def test_not_excluded(self) -> None:
|
||||
connector = SharepointConnector(excluded_paths=["*.tmp"])
|
||||
item = DriveItemData(
|
||||
id="1",
|
||||
name="file.docx",
|
||||
web_url="https://example.com/file.docx",
|
||||
parent_reference_path="/drives/abc/root:/Docs",
|
||||
)
|
||||
assert not connector._is_driveitem_excluded(item)
|
||||
|
||||
def test_no_patterns_never_excludes(self) -> None:
|
||||
connector = SharepointConnector(excluded_paths=[])
|
||||
item = DriveItemData(
|
||||
id="1",
|
||||
name="file.tmp",
|
||||
web_url="https://example.com/file.tmp",
|
||||
parent_reference_path="/drives/abc/root:/Docs",
|
||||
)
|
||||
assert not connector._is_driveitem_excluded(item)
|
||||
|
||||
def test_folder_pattern(self) -> None:
|
||||
connector = SharepointConnector(excluded_paths=["Archive/*"])
|
||||
item = DriveItemData(
|
||||
id="1",
|
||||
name="old.pdf",
|
||||
web_url="https://example.com/old.pdf",
|
||||
parent_reference_path="/drives/abc/root:/Archive",
|
||||
)
|
||||
assert connector._is_driveitem_excluded(item)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"whitespace_pattern",
|
||||
["", " ", "\t"],
|
||||
)
|
||||
def test_whitespace_patterns_ignored(self, whitespace_pattern: str) -> None:
|
||||
connector = SharepointConnector(excluded_paths=[whitespace_pattern])
|
||||
assert connector.excluded_paths == []
|
||||
|
||||
def test_whitespace_padded_patterns_are_trimmed(self) -> None:
|
||||
connector = SharepointConnector(excluded_paths=[" *.tmp ", " Archive/* "])
|
||||
assert connector.excluded_paths == ["*.tmp", "Archive/*"]
|
||||
|
||||
item = DriveItemData(
|
||||
id="1",
|
||||
name="file.tmp",
|
||||
web_url="https://example.com/file.tmp",
|
||||
parent_reference_path="/drives/abc/root:/Docs",
|
||||
)
|
||||
assert connector._is_driveitem_excluded(item)
|
||||
@@ -211,6 +211,7 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
|
||||
include_permissions: bool, # noqa: ARG001
|
||||
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
|
||||
access_token: str | None = None, # noqa: ARG001
|
||||
treat_sharing_link_as_public: bool = False, # noqa: ARG001
|
||||
) -> Document:
|
||||
captured_drive_names.append(drive_name)
|
||||
return Document(
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Tests for OpenSearch search Prometheus metrics."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.server.metrics.opensearch_search import _client_duration
|
||||
from onyx.server.metrics.opensearch_search import _search_total
|
||||
from onyx.server.metrics.opensearch_search import _searches_in_progress
|
||||
from onyx.server.metrics.opensearch_search import _server_duration
|
||||
from onyx.server.metrics.opensearch_search import observe_opensearch_search
|
||||
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
|
||||
|
||||
|
||||
class TestObserveOpenSearchSearch:
|
||||
def test_increments_counter(self) -> None:
|
||||
search_type = OpenSearchSearchType.HYBRID
|
||||
before = _search_total.labels(search_type=search_type.value)._value.get()
|
||||
observe_opensearch_search(search_type, 0.1, 50)
|
||||
after = _search_total.labels(search_type=search_type.value)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_observes_client_duration(self) -> None:
|
||||
search_type = OpenSearchSearchType.KEYWORD
|
||||
before_sum = _client_duration.labels(search_type=search_type.value)._sum.get()
|
||||
observe_opensearch_search(search_type, 0.25, 100)
|
||||
after_sum = _client_duration.labels(search_type=search_type.value)._sum.get()
|
||||
assert after_sum == before_sum + 0.25
|
||||
|
||||
def test_observes_server_duration(self) -> None:
|
||||
search_type = OpenSearchSearchType.SEMANTIC
|
||||
before_sum = _server_duration.labels(search_type=search_type.value)._sum.get()
|
||||
observe_opensearch_search(search_type, 0.3, 200)
|
||||
after_sum = _server_duration.labels(search_type=search_type.value)._sum.get()
|
||||
# 200ms should be recorded as 0.2s.
|
||||
assert after_sum == before_sum + 0.2
|
||||
|
||||
def test_server_took_none_skips_server_histogram(self) -> None:
|
||||
search_type = OpenSearchSearchType.UNKNOWN
|
||||
before_server = _server_duration.labels(
|
||||
search_type=search_type.value
|
||||
)._sum.get()
|
||||
before_client = _client_duration.labels(
|
||||
search_type=search_type.value
|
||||
)._sum.get()
|
||||
before_total = _search_total.labels(search_type=search_type.value)._value.get()
|
||||
|
||||
observe_opensearch_search(search_type, 0.1, None)
|
||||
|
||||
# Server histogram should NOT be observed.
|
||||
after_server = _server_duration.labels(search_type=search_type.value)._sum.get()
|
||||
assert after_server == before_server
|
||||
|
||||
# Client histogram and counter should still work.
|
||||
after_client = _client_duration.labels(search_type=search_type.value)._sum.get()
|
||||
after_total = _search_total.labels(search_type=search_type.value)._value.get()
|
||||
assert after_client == before_client + 0.1
|
||||
assert after_total == before_total + 1
|
||||
|
||||
def test_exceptions_do_not_propagate(self) -> None:
|
||||
search_type = OpenSearchSearchType.RANDOM
|
||||
with patch.object(
|
||||
_search_total.labels(search_type=search_type.value),
|
||||
"inc",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
# Should not raise.
|
||||
observe_opensearch_search(search_type, 0.1, 50)
|
||||
|
||||
|
||||
class TestTrackOpenSearchSearchInProgress:
|
||||
def test_gauge_increments_and_decrements(self) -> None:
|
||||
search_type = OpenSearchSearchType.HYBRID
|
||||
before = _searches_in_progress.labels(
|
||||
search_type=search_type.value
|
||||
)._value.get()
|
||||
|
||||
with track_opensearch_search_in_progress(search_type):
|
||||
during = _searches_in_progress.labels(
|
||||
search_type=search_type.value
|
||||
)._value.get()
|
||||
assert during == before + 1
|
||||
|
||||
after = _searches_in_progress.labels(search_type=search_type.value)._value.get()
|
||||
assert after == before
|
||||
|
||||
def test_gauge_decrements_on_exception(self) -> None:
|
||||
search_type = OpenSearchSearchType.SEMANTIC
|
||||
before = _searches_in_progress.labels(
|
||||
search_type=search_type.value
|
||||
)._value.get()
|
||||
|
||||
raised = False
|
||||
try:
|
||||
with track_opensearch_search_in_progress(search_type):
|
||||
raise ValueError("simulated search failure")
|
||||
except ValueError:
|
||||
raised = True
|
||||
assert raised
|
||||
|
||||
after = _searches_in_progress.labels(search_type=search_type.value)._value.get()
|
||||
assert after == before
|
||||
|
||||
def test_inc_exception_does_not_break_search(self) -> None:
|
||||
search_type = OpenSearchSearchType.KEYWORD
|
||||
before = _searches_in_progress.labels(
|
||||
search_type=search_type.value
|
||||
)._value.get()
|
||||
|
||||
with patch.object(
|
||||
_searches_in_progress.labels(search_type=search_type.value),
|
||||
"inc",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
# Context manager should still yield without decrementing.
|
||||
with track_opensearch_search_in_progress(search_type):
|
||||
# Search logic would execute here.
|
||||
during = _searches_in_progress.labels(
|
||||
search_type=search_type.value
|
||||
)._value.get()
|
||||
assert during == before
|
||||
|
||||
after = _searches_in_progress.labels(search_type=search_type.value)._value.get()
|
||||
assert after == before
|
||||
@@ -169,6 +169,21 @@ Engine label values: `sync` (main read-write), `async` (async sessions), `readon
|
||||
|
||||
Connections from background tasks (Celery) or boot-time warmup appear as `handler="unknown"`.
|
||||
|
||||
## OpenSearch Search Metrics
|
||||
|
||||
These metrics track OpenSearch search latency and throughput. Collected via `onyx.server.metrics.opensearch_search`.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `onyx_opensearch_search_client_duration_seconds` | Histogram | `search_type` | Client-side end-to-end latency (network + serialization + server execution) |
|
||||
| `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field |
|
||||
| `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch |
|
||||
| `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches |
|
||||
|
||||
Search type label values: See `OpenSearchSearchType`.
|
||||
|
||||
---
|
||||
|
||||
## Example PromQL Queries
|
||||
|
||||
### Which endpoints are saturated right now?
|
||||
@@ -258,3 +273,33 @@ histogram_quantile(0.99, sum by (handler, le) (rate(onyx_db_connection_hold_seco
|
||||
# Checkouts per second by engine
|
||||
sum by (engine) (rate(onyx_db_pool_checkout_total[5m]))
|
||||
```
|
||||
|
||||
### OpenSearch P99 search latency by type
|
||||
|
||||
```promql
|
||||
# P99 client-side latency by search type
|
||||
histogram_quantile(0.99, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))
|
||||
```
|
||||
|
||||
### OpenSearch search throughput
|
||||
|
||||
```promql
|
||||
# Searches per second by type
|
||||
sum by (search_type) (rate(onyx_opensearch_search_total[5m]))
|
||||
```
|
||||
|
||||
### OpenSearch concurrent searches
|
||||
|
||||
```promql
|
||||
# Total in-flight searches across all instances
|
||||
sum(onyx_opensearch_searches_in_progress)
|
||||
```
|
||||
|
||||
### OpenSearch network overhead
|
||||
|
||||
```promql
|
||||
# Difference between client and server P50 reveals network/serialization cost.
|
||||
histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))
|
||||
-
|
||||
histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))
|
||||
```
|
||||
|
||||
12
examples/widget/package-lock.json
generated
12
examples/widget/package-lock.json
generated
@@ -6271,9 +6271,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
|
||||
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
@@ -7179,9 +7179,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/tinyglobby/node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
@@ -92,7 +92,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.9.1",
|
||||
"pypdf==6.9.2",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
@@ -100,7 +100,7 @@ backend = [
|
||||
"python-multipart==0.0.22",
|
||||
"pywikibot==9.0.0",
|
||||
"redis==5.0.8",
|
||||
"requests==2.32.5",
|
||||
"requests==2.33.0",
|
||||
"requests-oauthlib==1.3.1",
|
||||
"rfc3986==1.5.0",
|
||||
"simple-salesforce==1.12.6",
|
||||
|
||||
22
uv.lock
generated
22
uv.lock
generated
@@ -3909,7 +3909,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "nltk"
|
||||
version = "3.9.3"
|
||||
version = "3.9.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
@@ -3917,9 +3917,9 @@ dependencies = [
|
||||
{ name = "regex" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4481,7 +4481,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.1" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.2" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -4502,7 +4502,7 @@ requires-dist = [
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.33.0" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
{ name = "retry", specifier = "==0.9.2" },
|
||||
{ name = "rfc3986", marker = "extra == 'backend'", specifier = "==1.5.0" },
|
||||
@@ -5727,11 +5727,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.9.1"
|
||||
version = "6.9.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6378,7 +6378,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.5"
|
||||
version = "2.33.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
@@ -6386,9 +6386,9 @@ dependencies = [
|
||||
{ name = "idna" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -281,35 +281,90 @@ If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
## 3. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
**Use the Opal `Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It uses
|
||||
string-enum props (`font` and `color`) for font preset and color selection. Inline markdown is
|
||||
opt-in via the `markdown()` function from `@opal/types`.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
// ✅ Good — Opal Text with string-enum props
|
||||
import { Text } from "@opal/components";
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
<Text font="main-ui-action" color="text-03">
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
// ✅ Good — inline markdown via markdown()
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
<Text font="main-ui-body" color="text-05">
|
||||
{markdown("*Hello*, **world**! Visit [Onyx](https://onyx.app) and run `onyx start`.")}
|
||||
</Text>
|
||||
|
||||
// ✅ Good — plain strings are never parsed as markdown
|
||||
<Text font="main-ui-body" color="text-03">
|
||||
{userProvidedString}
|
||||
</Text>
|
||||
|
||||
// ✅ Good — component props that support optional markdown use `string | RichStr`
|
||||
import type { RichStr } from "@opal/types";
|
||||
|
||||
interface MyCardProps {
|
||||
title: string | RichStr;
|
||||
}
|
||||
|
||||
// ❌ Bad — legacy boolean-flag API (still works but deprecated)
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
<Text text03 mainUiAction>{name}</Text>
|
||||
|
||||
// ❌ Bad — naked text nodes
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
```
|
||||
|
||||
Key props:
|
||||
- `font`: `TextFont` — font preset (e.g., `"main-ui-body"`, `"heading-h2"`, `"secondary-action"`)
|
||||
- `color`: `TextColor` — text color (e.g., `"text-03"`, `"text-inverted-05"`)
|
||||
- `as`: `"p" | "span" | "li" | "h1" | "h2" | "h3"` — HTML tag (default: `"span"`)
|
||||
- `nowrap`: `boolean` — prevent text wrapping
|
||||
|
||||
**`RichStr` convention:** When creating new components, any string prop that will be rendered as
|
||||
visible text in the DOM (e.g., `title`, `description`, `label`) should be typed as
|
||||
`string | RichStr` instead of plain `string`. This gives callers opt-in markdown support via
|
||||
`markdown()` without requiring any additional props or API surface on the component.
|
||||
|
||||
```typescript
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
|
||||
// ✅ Good — new components accept string | RichStr
|
||||
interface InfoCardProps {
|
||||
title: string | RichStr;
|
||||
description?: string | RichStr;
|
||||
}
|
||||
|
||||
function InfoCard({ title, description }: InfoCardProps) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
<Text font="main-ui-action">{resolveStr(title)}</Text>
|
||||
{description && (
|
||||
<Text font="secondary-body" color="text-03">{resolveStr(description)}</Text>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// ❌ Bad — plain string props block markdown support for callers
|
||||
interface InfoCardProps {
|
||||
title: string;
|
||||
description?: string;
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -33,6 +33,14 @@ export {
|
||||
type LineItemButtonProps,
|
||||
} from "@opal/components/buttons/line-item-button/components";
|
||||
|
||||
/* Text */
|
||||
export {
|
||||
Text,
|
||||
type TextProps,
|
||||
type TextFont,
|
||||
type TextColor,
|
||||
} from "@opal/components/text/components";
|
||||
|
||||
/* Tag */
|
||||
export {
|
||||
Tag,
|
||||
|
||||
76
web/lib/opal/src/components/text/InlineMarkdown.tsx
Normal file
76
web/lib/opal/src/components/text/InlineMarkdown.tsx
Normal file
@@ -0,0 +1,76 @@
|
||||
import type { ReactNode } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
|
||||
import type { RichStr } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// InlineMarkdown
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const SAFE_PROTOCOL = /^https?:|^mailto:|^tel:/i;
|
||||
|
||||
const ALLOWED_ELEMENTS = ["p", "a", "strong", "em", "code", "del"];
|
||||
|
||||
const INLINE_COMPONENTS = {
|
||||
p: ({ children }: { children?: ReactNode }) => <>{children}</>,
|
||||
a: ({ children, href }: { children?: ReactNode; href?: string }) => {
|
||||
if (!href || !SAFE_PROTOCOL.test(href)) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
const isHttp = /^https?:/i.test(href);
|
||||
return (
|
||||
<a
|
||||
href={href}
|
||||
className="underline underline-offset-2"
|
||||
{...(isHttp ? { target: "_blank", rel: "noopener noreferrer" } : {})}
|
||||
>
|
||||
{children}
|
||||
</a>
|
||||
);
|
||||
},
|
||||
code: ({ children }: { children?: ReactNode }) => (
|
||||
<code className="[font-family:var(--font-dm-mono)] bg-background-tint-02 rounded px-1 py-0.5">
|
||||
{children}
|
||||
</code>
|
||||
),
|
||||
};
|
||||
|
||||
interface InlineMarkdownProps {
|
||||
content: string;
|
||||
}
|
||||
|
||||
export default function InlineMarkdown({ content }: InlineMarkdownProps) {
|
||||
return (
|
||||
<ReactMarkdown
|
||||
components={INLINE_COMPONENTS}
|
||||
allowedElements={ALLOWED_ELEMENTS}
|
||||
unwrapDisallowed
|
||||
remarkPlugins={[remarkGfm]}
|
||||
>
|
||||
{content}
|
||||
</ReactMarkdown>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RichStr helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function isRichStr(value: unknown): value is RichStr {
|
||||
return (
|
||||
typeof value === "object" &&
|
||||
value !== null &&
|
||||
(value as RichStr).__brand === "RichStr"
|
||||
);
|
||||
}
|
||||
|
||||
/** Resolves `string | RichStr` to a `ReactNode`. */
|
||||
export function resolveStr(value: string | RichStr): ReactNode {
|
||||
return isRichStr(value) ? <InlineMarkdown content={value.raw} /> : value;
|
||||
}
|
||||
|
||||
/** Extracts the plain string from `string | RichStr`. */
|
||||
export function toPlainString(value: string | RichStr): string {
|
||||
return isRichStr(value) ? value.raw : value;
|
||||
}
|
||||
124
web/lib/opal/src/components/text/README.md
Normal file
124
web/lib/opal/src/components/text/README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# Text
|
||||
|
||||
**Import:** `import { Text, type TextProps, type TextFont, type TextColor } from "@opal/components";`
|
||||
|
||||
A styled text component with string-enum props for font preset and color selection. Supports
|
||||
inline markdown rendering via `RichStr` — pass `markdown("*bold* text")` as children to enable.
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `font` | `TextFont` | `"main-ui-body"` | Font preset (size, weight, line-height) |
|
||||
| `color` | `TextColor` | `"text-04"` | Text color |
|
||||
| `as` | `"p" \| "span" \| "li" \| "h1" \| "h2" \| "h3"` | `"span"` | HTML tag to render |
|
||||
| `nowrap` | `boolean` | `false` | Prevent text wrapping |
|
||||
| `children` | `string \| RichStr` | — | Plain string or `markdown()` for inline markdown |
|
||||
|
||||
### `TextFont`
|
||||
|
||||
| Value | Size | Weight | Line-height |
|
||||
|---|---|---|---|
|
||||
| `"heading-h1"` | 48px | 600 | 64px |
|
||||
| `"heading-h2"` | 24px | 600 | 36px |
|
||||
| `"heading-h3"` | 18px | 600 | 28px |
|
||||
| `"heading-h3-muted"` | 18px | 500 | 28px |
|
||||
| `"main-content-body"` | 16px | 450 | 24px |
|
||||
| `"main-content-muted"` | 16px | 400 | 24px |
|
||||
| `"main-content-emphasis"` | 16px | 700 | 24px |
|
||||
| `"main-content-mono"` | 16px | 400 | 23px |
|
||||
| `"main-ui-body"` | 14px | 500 | 20px |
|
||||
| `"main-ui-muted"` | 14px | 400 | 20px |
|
||||
| `"main-ui-action"` | 14px | 600 | 20px |
|
||||
| `"main-ui-mono"` | 14px | 400 | 20px |
|
||||
| `"secondary-body"` | 12px | 400 | 18px |
|
||||
| `"secondary-action"` | 12px | 600 | 18px |
|
||||
| `"secondary-mono"` | 12px | 400 | 18px |
|
||||
| `"figure-small-label"` | 10px | 600 | 14px |
|
||||
| `"figure-small-value"` | 10px | 400 | 14px |
|
||||
| `"figure-keystroke"` | 11px | 400 | 16px |
|
||||
|
||||
### `TextColor`
|
||||
|
||||
`"text-01" | "text-02" | "text-03" | "text-04" | "text-05" | "text-inverted-01" | "text-inverted-02" | "text-inverted-03" | "text-inverted-04" | "text-inverted-05" | "text-light-03" | "text-light-05" | "text-dark-03" | "text-dark-05"`
|
||||
|
||||
## Usage Examples
|
||||
|
||||
```tsx
|
||||
import { Text } from "@opal/components";
|
||||
|
||||
// Basic
|
||||
<Text font="main-ui-body" color="text-03">
|
||||
Hello world
|
||||
</Text>
|
||||
|
||||
// Heading
|
||||
<Text font="heading-h2" color="text-05" as="h2">
|
||||
Page Title
|
||||
</Text>
|
||||
|
||||
// Inverted (for dark backgrounds)
|
||||
<Text font="main-ui-body" color="text-inverted-05">
|
||||
Light text on dark
|
||||
</Text>
|
||||
|
||||
// As paragraph
|
||||
<Text font="main-content-body" color="text-03" as="p">
|
||||
A full paragraph of text.
|
||||
</Text>
|
||||
```
|
||||
|
||||
## Inline Markdown via `RichStr`
|
||||
|
||||
Inline markdown is opt-in via the `markdown()` function, which returns a `RichStr`. When `Text`
|
||||
receives a `RichStr` as children, it parses the inner string as inline markdown. Plain strings
|
||||
are rendered as-is — no parsing, no surprises. `Text` does not accept arbitrary JSX as children;
|
||||
use `string | RichStr` only.
|
||||
|
||||
```tsx
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
// Inline markdown — bold, italic, links, code, strikethrough
|
||||
<Text font="main-ui-body" color="text-05">
|
||||
{markdown("*Hello*, **world**! Visit [Onyx](https://onyx.app) and run `onyx start`.")}
|
||||
</Text>
|
||||
|
||||
// Plain string — no markdown parsing
|
||||
<Text font="main-ui-body" color="text-03">
|
||||
This *stays* as-is, no formatting applied.
|
||||
</Text>
|
||||
```
|
||||
|
||||
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`.
|
||||
|
||||
Markdown rendering uses `react-markdown` internally, restricted to inline elements only.
|
||||
`http(s)` links open in a new tab; `mailto:` and `tel:` links open natively. Inline code
|
||||
inherits the parent font size and switches to the monospace family.
|
||||
|
||||
**Note:** This is inline-only markdown. Multi-paragraph content (`"Hello\n\nWorld"`) will
|
||||
collapse into a single run of text since paragraph wrappers are stripped. For block-level
|
||||
markdown, use `MinimalMarkdown` instead.
|
||||
|
||||
### Using `RichStr` in component props
|
||||
|
||||
Components that want to support optional markdown in their text props should accept
|
||||
`string | RichStr`:
|
||||
|
||||
```tsx
|
||||
import type { RichStr } from "@opal/types";
|
||||
|
||||
interface MyComponentProps {
|
||||
title: string | RichStr;
|
||||
description?: string | RichStr;
|
||||
}
|
||||
```
|
||||
|
||||
This avoids API coloring — no `markdown` boolean needs to be threaded through intermediate
|
||||
components. The decision to use markdown lives at the call site.
|
||||
|
||||
## Compatibility
|
||||
|
||||
`@/refresh-components/texts/Text` is an independent legacy component that implements the same
|
||||
font/color presets via a boolean-flag API. It is **not** a wrapper around this component. New
|
||||
code should import directly from `@opal/components`.
|
||||
257
web/lib/opal/src/components/text/Text.stories.tsx
Normal file
257
web/lib/opal/src/components/text/Text.stories.tsx
Normal file
@@ -0,0 +1,257 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { Text } from "@opal/components";
|
||||
import type { TextFont, TextColor } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
const meta: Meta<typeof Text> = {
|
||||
title: "opal/components/Text",
|
||||
component: Text,
|
||||
tags: ["autodocs"],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof Text>;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Basic
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
children: "The quick brown fox jumps over the lazy dog",
|
||||
},
|
||||
};
|
||||
|
||||
export const AsHeading: Story = {
|
||||
args: {
|
||||
font: "heading-h2",
|
||||
color: "text-05",
|
||||
as: "h2",
|
||||
children: "Page Title",
|
||||
},
|
||||
};
|
||||
|
||||
export const AsParagraph: Story = {
|
||||
args: {
|
||||
font: "main-content-body",
|
||||
color: "text-03",
|
||||
as: "p",
|
||||
children: "A full paragraph of body text rendered as a p element.",
|
||||
},
|
||||
};
|
||||
|
||||
export const Nowrap: Story = {
|
||||
render: () => (
|
||||
<div className="w-48 border border-border-02 rounded p-2">
|
||||
<Text font="main-ui-body" color="text-05" nowrap>
|
||||
This text will not wrap even though the container is narrow
|
||||
</Text>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fonts
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const ALL_FONTS: TextFont[] = [
|
||||
"heading-h1",
|
||||
"heading-h2",
|
||||
"heading-h3",
|
||||
"heading-h3-muted",
|
||||
"main-content-body",
|
||||
"main-content-muted",
|
||||
"main-content-emphasis",
|
||||
"main-content-mono",
|
||||
"main-ui-body",
|
||||
"main-ui-muted",
|
||||
"main-ui-action",
|
||||
"main-ui-mono",
|
||||
"secondary-body",
|
||||
"secondary-action",
|
||||
"secondary-mono",
|
||||
"figure-small-label",
|
||||
"figure-small-value",
|
||||
"figure-keystroke",
|
||||
];
|
||||
|
||||
export const AllFonts: Story = {
|
||||
render: () => (
|
||||
<div className="space-y-2">
|
||||
{ALL_FONTS.map((font) => (
|
||||
<div key={font} className="flex items-baseline gap-4">
|
||||
<span className="w-56 shrink-0 font-secondary-body text-text-03">
|
||||
{font}
|
||||
</span>
|
||||
<Text font={font} color="text-05">
|
||||
The quick brown fox
|
||||
</Text>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Colors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const STANDARD_COLORS: TextColor[] = [
|
||||
"text-01",
|
||||
"text-02",
|
||||
"text-03",
|
||||
"text-04",
|
||||
"text-05",
|
||||
];
|
||||
|
||||
const INVERTED_COLORS: TextColor[] = [
|
||||
"text-inverted-01",
|
||||
"text-inverted-02",
|
||||
"text-inverted-03",
|
||||
"text-inverted-04",
|
||||
"text-inverted-05",
|
||||
];
|
||||
|
||||
export const AllColors: Story = {
|
||||
render: () => (
|
||||
<div className="space-y-2">
|
||||
{STANDARD_COLORS.map((color) => (
|
||||
<div key={color} className="flex items-baseline gap-4">
|
||||
<span className="w-56 shrink-0 font-secondary-body text-text-03">
|
||||
{color}
|
||||
</span>
|
||||
<Text font="main-ui-body" color={color}>
|
||||
The quick brown fox
|
||||
</Text>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
export const InvertedColors: Story = {
|
||||
render: () => (
|
||||
<div className="bg-background-inverted-01 rounded-lg p-6 space-y-2">
|
||||
{INVERTED_COLORS.map((color) => (
|
||||
<div key={color} className="flex items-baseline gap-4">
|
||||
<span
|
||||
className="w-56 shrink-0 font-secondary-body"
|
||||
style={{ color: "rgba(255,255,255,0.5)" }}
|
||||
>
|
||||
{color}
|
||||
</span>
|
||||
<Text font="main-ui-body" color={color}>
|
||||
The quick brown fox
|
||||
</Text>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Markdown via RichStr
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const MarkdownBold: Story = {
|
||||
args: {
|
||||
font: "main-ui-body",
|
||||
color: "text-05",
|
||||
children: markdown("This is **bold** text"),
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownItalic: Story = {
|
||||
args: {
|
||||
font: "main-ui-body",
|
||||
color: "text-05",
|
||||
children: markdown("This is *italic* text"),
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownCode: Story = {
|
||||
args: {
|
||||
font: "main-ui-body",
|
||||
color: "text-05",
|
||||
children: markdown("Run `npm install` to get started"),
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownLink: Story = {
|
||||
args: {
|
||||
font: "main-ui-body",
|
||||
color: "text-05",
|
||||
children: markdown("Visit [Onyx](https://www.onyx.app/) for more info"),
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownStrikethrough: Story = {
|
||||
args: {
|
||||
font: "main-ui-body",
|
||||
color: "text-05",
|
||||
children: markdown("This is ~~deleted~~ text"),
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownCombined: Story = {
|
||||
args: {
|
||||
font: "main-ui-body",
|
||||
color: "text-05",
|
||||
children: markdown(
|
||||
"*Hello*, **world**! Check out [Onyx](https://www.onyx.app/) and run `onyx start` to begin."
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownAtDifferentSizes: Story = {
|
||||
render: () => (
|
||||
<div className="space-y-3">
|
||||
<Text font="heading-h2" color="text-05" as="h2">
|
||||
{markdown("**Heading** with *emphasis* and `code`")}
|
||||
</Text>
|
||||
<Text font="main-content-body" color="text-03" as="p">
|
||||
{markdown("**Main content** with *emphasis* and `code`")}
|
||||
</Text>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{markdown("**Secondary** with *emphasis* and `code`")}
|
||||
</Text>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
export const PlainStringNotParsed: Story = {
|
||||
render: () => (
|
||||
<div className="space-y-2">
|
||||
<Text font="main-ui-body" color="text-05">
|
||||
{
|
||||
"This has *asterisks* and **double asterisks** but they are NOT parsed."
|
||||
}
|
||||
</Text>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tag Variants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const TagVariants: Story = {
|
||||
render: () => (
|
||||
<div className="space-y-2">
|
||||
<Text font="main-ui-body" color="text-05">
|
||||
Default (span): inline text
|
||||
</Text>
|
||||
<Text font="main-ui-body" color="text-05" as="p">
|
||||
Paragraph (p): block text
|
||||
</Text>
|
||||
<Text font="heading-h2" color="text-05" as="h2">
|
||||
Heading (h2): semantic heading
|
||||
</Text>
|
||||
<ul className="list-disc pl-6">
|
||||
<Text font="main-ui-body" color="text-05" as="li">
|
||||
List item (li): inside a list
|
||||
</Text>
|
||||
</ul>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
134
web/lib/opal/src/components/text/components.tsx
Normal file
134
web/lib/opal/src/components/text/components.tsx
Normal file
@@ -0,0 +1,134 @@
|
||||
import type { HTMLAttributes } from "react";
|
||||
|
||||
import type { RichStr, WithoutStyles } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type TextFont =
|
||||
| "heading-h1"
|
||||
| "heading-h2"
|
||||
| "heading-h3"
|
||||
| "heading-h3-muted"
|
||||
| "main-content-body"
|
||||
| "main-content-muted"
|
||||
| "main-content-emphasis"
|
||||
| "main-content-mono"
|
||||
| "main-ui-body"
|
||||
| "main-ui-muted"
|
||||
| "main-ui-action"
|
||||
| "main-ui-mono"
|
||||
| "secondary-body"
|
||||
| "secondary-action"
|
||||
| "secondary-mono"
|
||||
| "figure-small-label"
|
||||
| "figure-small-value"
|
||||
| "figure-keystroke";
|
||||
|
||||
type TextColor =
|
||||
| "text-01"
|
||||
| "text-02"
|
||||
| "text-03"
|
||||
| "text-04"
|
||||
| "text-05"
|
||||
| "text-inverted-01"
|
||||
| "text-inverted-02"
|
||||
| "text-inverted-03"
|
||||
| "text-inverted-04"
|
||||
| "text-inverted-05"
|
||||
| "text-light-03"
|
||||
| "text-light-05"
|
||||
| "text-dark-03"
|
||||
| "text-dark-05";
|
||||
|
||||
interface TextProps
|
||||
extends WithoutStyles<
|
||||
Omit<HTMLAttributes<HTMLElement>, "color" | "children">
|
||||
> {
|
||||
/** Font preset. Default: `"main-ui-body"`. */
|
||||
font?: TextFont;
|
||||
|
||||
/** Color variant. Default: `"text-04"`. */
|
||||
color?: TextColor;
|
||||
|
||||
/** HTML tag to render. Default: `"span"`. */
|
||||
as?: "p" | "span" | "li" | "h1" | "h2" | "h3";
|
||||
|
||||
/** Prevent text wrapping. */
|
||||
nowrap?: boolean;
|
||||
|
||||
/** Plain string or `markdown()` for inline markdown. */
|
||||
children?: string | RichStr;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const FONT_CONFIG: Record<TextFont, string> = {
|
||||
"heading-h1": "font-heading-h1",
|
||||
"heading-h2": "font-heading-h2",
|
||||
"heading-h3": "font-heading-h3",
|
||||
"heading-h3-muted": "font-heading-h3-muted",
|
||||
"main-content-body": "font-main-content-body",
|
||||
"main-content-muted": "font-main-content-muted",
|
||||
"main-content-emphasis": "font-main-content-emphasis",
|
||||
"main-content-mono": "font-main-content-mono",
|
||||
"main-ui-body": "font-main-ui-body",
|
||||
"main-ui-muted": "font-main-ui-muted",
|
||||
"main-ui-action": "font-main-ui-action",
|
||||
"main-ui-mono": "font-main-ui-mono",
|
||||
"secondary-body": "font-secondary-body",
|
||||
"secondary-action": "font-secondary-action",
|
||||
"secondary-mono": "font-secondary-mono",
|
||||
"figure-small-label": "font-figure-small-label",
|
||||
"figure-small-value": "font-figure-small-value",
|
||||
"figure-keystroke": "font-figure-keystroke",
|
||||
};
|
||||
|
||||
const COLOR_CONFIG: Record<TextColor, string> = {
|
||||
"text-01": "text-text-01",
|
||||
"text-02": "text-text-02",
|
||||
"text-03": "text-text-03",
|
||||
"text-04": "text-text-04",
|
||||
"text-05": "text-text-05",
|
||||
"text-inverted-01": "text-text-inverted-01",
|
||||
"text-inverted-02": "text-text-inverted-02",
|
||||
"text-inverted-03": "text-text-inverted-03",
|
||||
"text-inverted-04": "text-text-inverted-04",
|
||||
"text-inverted-05": "text-text-inverted-05",
|
||||
"text-light-03": "text-text-light-03",
|
||||
"text-light-05": "text-text-light-05",
|
||||
"text-dark-03": "text-text-dark-03",
|
||||
"text-dark-05": "text-text-dark-05",
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Text
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function Text({
|
||||
font = "main-ui-body",
|
||||
color = "text-04",
|
||||
as: Tag = "span",
|
||||
nowrap,
|
||||
children,
|
||||
...rest
|
||||
}: TextProps) {
|
||||
const resolvedClassName = cn(
|
||||
FONT_CONFIG[font],
|
||||
COLOR_CONFIG[color],
|
||||
nowrap && "whitespace-nowrap"
|
||||
);
|
||||
|
||||
return (
|
||||
<Tag {...rest} className={resolvedClassName}>
|
||||
{children && resolveStr(children)}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
export { Text, type TextProps, type TextFont, type TextColor };
|
||||
@@ -55,6 +55,9 @@ interface ContentMdProps {
|
||||
/** When `true`, renders "(Optional)" beside the title. */
|
||||
optional?: boolean;
|
||||
|
||||
/** Custom muted suffix rendered beside the title. */
|
||||
titleSuffix?: string;
|
||||
|
||||
/** Auxiliary status icon rendered beside the title. */
|
||||
auxIcon?: ContentMdAuxIcon;
|
||||
|
||||
@@ -138,6 +141,7 @@ function ContentMd({
|
||||
editable,
|
||||
onTitleChange,
|
||||
optional,
|
||||
titleSuffix,
|
||||
auxIcon,
|
||||
tag,
|
||||
sizePreset = "main-ui",
|
||||
@@ -234,12 +238,12 @@ function ContentMd({
|
||||
</span>
|
||||
)}
|
||||
|
||||
{optional && (
|
||||
{(optional || titleSuffix) && (
|
||||
<span
|
||||
className={cn(config.optionalFont, "text-text-03 shrink-0")}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
(Optional)
|
||||
{titleSuffix ?? "(Optional)"}
|
||||
</span>
|
||||
)}
|
||||
|
||||
|
||||
@@ -96,6 +96,8 @@ type MdContentProps = ContentBaseProps & {
|
||||
variant?: "section";
|
||||
/** When `true`, renders "(Optional)" beside the title in the muted font variant. */
|
||||
optional?: boolean;
|
||||
/** Custom muted suffix rendered beside the title. */
|
||||
titleSuffix?: string;
|
||||
/** Auxiliary status icon rendered beside the title. */
|
||||
auxIcon?: "info-gray" | "info-blue" | "warning" | "error";
|
||||
/** Tag rendered beside the title. */
|
||||
|
||||
@@ -86,6 +86,26 @@ export interface IconProps extends SVGProps<SVGSVGElement> {
|
||||
/** Strips `className` and `style` from a props type to enforce design-system styling. */
|
||||
export type WithoutStyles<T> = Omit<T, "className" | "style">;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Rich Strings
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* A branded string wrapper that signals inline markdown should be parsed.
|
||||
*
|
||||
* Created via the `markdown()` function. Components that accept `string | RichStr`
|
||||
* will parse the inner `raw` string as inline markdown when a `RichStr` is passed,
|
||||
* and render plain text when a regular `string` is passed.
|
||||
*
|
||||
* This avoids "API coloring" — components don't need a `markdown` boolean prop,
|
||||
* and intermediate wrappers don't need to thread it through. The decision to
|
||||
* use markdown lives at the call site via `markdown("*bold* text")`.
|
||||
*/
|
||||
export interface RichStr {
|
||||
readonly __brand: "RichStr";
|
||||
readonly raw: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* HTML button `type` attribute values.
|
||||
*
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import { clsx, type ClassValue } from "clsx";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
import type { RichStr } from "@opal/types";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
|
||||
/** Wraps a string for inline markdown parsing by `Text` and other Opal components. */
|
||||
export function markdown(content: string): RichStr {
|
||||
return { __brand: "RichStr", raw: content };
|
||||
}
|
||||
|
||||
4
web/package-lock.json
generated
4
web/package-lock.json
generated
@@ -8871,7 +8871,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/cosmiconfig/node_modules/yaml": {
|
||||
"version": "1.10.2",
|
||||
"version": "1.10.3",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.3.tgz",
|
||||
"integrity": "sha512-vIYeF1u3CjlhAFekPPAk2h/Kv4T3mAkMox5OymRiJQB0spDP10LHvt+K7G9Ny6NuuMAb25/6n1qyUjAcGNf/AA==",
|
||||
"license": "ISC",
|
||||
"engines": {
|
||||
"node": ">= 6"
|
||||
|
||||
@@ -30,7 +30,7 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
@@ -87,10 +87,19 @@ function Main() {
|
||||
</CreateButton>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
|
||||
<Text as="p" text04>
|
||||
This feature requires an active paid subscription.
|
||||
</Text>
|
||||
<Button href="/admin/billing">Upgrade Plan</Button>
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Text as="p" text04>
|
||||
Upgrade to a paid plan to create API keys.
|
||||
</Text>
|
||||
<Button
|
||||
variant="none"
|
||||
prominence="tertiary"
|
||||
size="2xs"
|
||||
icon={SvgInfo}
|
||||
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
</div>
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -159,10 +159,6 @@ export interface Message {
|
||||
overridden_model?: string;
|
||||
stopReason?: StreamStopReason | null;
|
||||
|
||||
// Multi-model answer generation
|
||||
preferredResponseId?: number | null;
|
||||
modelDisplayName?: string | null;
|
||||
|
||||
// new gen
|
||||
packets: Packet[];
|
||||
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
|
||||
@@ -235,9 +231,6 @@ export interface BackendMessage {
|
||||
parentMessageId: number | null;
|
||||
refined_answer_improvement: boolean | null;
|
||||
is_agentic: boolean | null;
|
||||
// Multi-model answer generation
|
||||
preferred_response_id: number | null;
|
||||
model_display_name: string | null;
|
||||
}
|
||||
|
||||
export interface MessageResponseIDInfo {
|
||||
@@ -245,12 +238,6 @@ export interface MessageResponseIDInfo {
|
||||
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
|
||||
}
|
||||
|
||||
export interface MultiModelMessageResponseIDInfo {
|
||||
user_message_id: number | null;
|
||||
reserved_assistant_message_ids: number[];
|
||||
model_names: string[];
|
||||
}
|
||||
|
||||
export interface UserKnowledgeFilePacket {
|
||||
user_files: FileDescriptor[];
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import {
|
||||
FileChatDisplay,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
MultiModelMessageResponseIDInfo,
|
||||
ResearchType,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
@@ -97,7 +96,6 @@ export type PacketType =
|
||||
| FileChatDisplay
|
||||
| StreamingError
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| UserKnowledgeFilePacket
|
||||
| Packet;
|
||||
@@ -111,13 +109,6 @@ export type MessageOrigin =
|
||||
| "slackbot"
|
||||
| "unknown";
|
||||
|
||||
export interface LLMOverride {
|
||||
model_provider: string;
|
||||
model_version: string;
|
||||
temperature?: number;
|
||||
display_name?: string;
|
||||
}
|
||||
|
||||
export interface SendMessageParams {
|
||||
message: string;
|
||||
fileDescriptors?: FileDescriptor[];
|
||||
@@ -133,8 +124,6 @@ export interface SendMessageParams {
|
||||
modelProvider?: string;
|
||||
modelVersion?: string;
|
||||
temperature?: number;
|
||||
// Multi-model: send multiple LLM overrides for parallel generation
|
||||
llmOverrides?: LLMOverride[];
|
||||
// Origin of the message for telemetry tracking
|
||||
origin?: MessageOrigin;
|
||||
// Additional context injected into the LLM call but not stored/shown in chat.
|
||||
@@ -155,7 +144,6 @@ export async function* sendMessage({
|
||||
modelProvider,
|
||||
modelVersion,
|
||||
temperature,
|
||||
llmOverrides,
|
||||
origin,
|
||||
additionalContext,
|
||||
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
|
||||
@@ -177,8 +165,6 @@ export async function* sendMessage({
|
||||
model_version: modelVersion,
|
||||
}
|
||||
: null,
|
||||
// Multi-model: list of LLM overrides for parallel generation
|
||||
llm_overrides: llmOverrides ?? null,
|
||||
// Default to "unknown" for consistency with backend; callers should set explicitly
|
||||
origin: origin ?? "unknown",
|
||||
additional_context: additionalContext ?? null,
|
||||
@@ -202,20 +188,6 @@ export async function* sendMessage({
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
}
|
||||
|
||||
export async function setPreferredResponse(
|
||||
userMessageId: number,
|
||||
preferredResponseId: number
|
||||
): Promise<Response> {
|
||||
return fetch("/api/chat/set-preferred-response", {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
user_message_id: userMessageId,
|
||||
preferred_response_id: preferredResponseId,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
export async function nameChatSession(chatSessionId: string) {
|
||||
const response = await fetch("/api/chat/rename-chat-session", {
|
||||
method: "PUT",
|
||||
@@ -385,9 +357,6 @@ export function processRawChatHistory(
|
||||
overridden_model: messageInfo.overridden_model,
|
||||
packets: packetsForMessage || [],
|
||||
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
|
||||
// Multi-model answer generation
|
||||
preferredResponseId: messageInfo.preferred_response_id ?? null,
|
||||
modelDisplayName: messageInfo.model_display_name ?? null,
|
||||
};
|
||||
|
||||
messages.set(messageInfo.message_id, message);
|
||||
|
||||
@@ -403,7 +403,6 @@ export interface Placement {
|
||||
turn_index: number;
|
||||
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
|
||||
sub_turn_index?: number | null;
|
||||
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
|
||||
}
|
||||
|
||||
// Packet wrapper for streaming objects
|
||||
|
||||
@@ -13,6 +13,7 @@ interface OrientationLayoutProps {
|
||||
nonInteractive?: boolean;
|
||||
children?: React.ReactNode;
|
||||
title: string;
|
||||
titleSuffix?: string;
|
||||
description?: string;
|
||||
optional?: boolean;
|
||||
sizePreset?: "main-content" | "main-ui";
|
||||
@@ -50,6 +51,7 @@ function VerticalInputLayout({
|
||||
children,
|
||||
subDescription,
|
||||
title,
|
||||
titleSuffix,
|
||||
description,
|
||||
optional,
|
||||
sizePreset = "main-content",
|
||||
@@ -58,6 +60,7 @@ function VerticalInputLayout({
|
||||
<Section gap={0.25} alignItems="start">
|
||||
<Content
|
||||
title={title}
|
||||
titleSuffix={titleSuffix}
|
||||
description={description}
|
||||
optional={optional}
|
||||
sizePreset={sizePreset}
|
||||
@@ -125,6 +128,7 @@ function HorizontalInputLayout({
|
||||
children,
|
||||
center,
|
||||
title,
|
||||
titleSuffix,
|
||||
description,
|
||||
optional,
|
||||
sizePreset = "main-content",
|
||||
@@ -139,6 +143,7 @@ function HorizontalInputLayout({
|
||||
<div className="flex flex-col flex-1 min-w-0 self-stretch">
|
||||
<Content
|
||||
title={title}
|
||||
titleSuffix={titleSuffix}
|
||||
description={description}
|
||||
optional={optional}
|
||||
sizePreset={sizePreset}
|
||||
|
||||
@@ -839,6 +839,40 @@ export const connectorConfigs: Record<
|
||||
description:
|
||||
"Index aspx-pages of all SharePoint sites defined above, even if a library or folder is specified.",
|
||||
},
|
||||
{
|
||||
type: "checkbox",
|
||||
label: "Treat sharing links as public?",
|
||||
description:
|
||||
"When enabled, documents with a sharing link (anonymous or organization-wide) " +
|
||||
"are treated as public (visible to all Onyx users). " +
|
||||
"When disabled, only users and groups with explicit role assignments can see the document.",
|
||||
name: "treat_sharing_link_as_public",
|
||||
optional: true,
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
type: "list",
|
||||
query: "Enter site URLs to exclude:",
|
||||
label: "Excluded Sites",
|
||||
name: "excluded_sites",
|
||||
optional: true,
|
||||
description:
|
||||
"Site URLs or glob patterns to exclude from indexing. " +
|
||||
"Matched sites will never be indexed, even if they appear in the sites list above. " +
|
||||
"Examples: 'https://contoso.sharepoint.com/sites/archive' (exact), " +
|
||||
"'*://*/sites/archive-*' (glob pattern).",
|
||||
},
|
||||
{
|
||||
type: "list",
|
||||
query: "Enter file path patterns to exclude:",
|
||||
label: "Excluded Paths",
|
||||
name: "excluded_paths",
|
||||
optional: true,
|
||||
description:
|
||||
"Glob patterns for file paths to exclude from indexing within document libraries. " +
|
||||
"Patterns are matched against both the full relative path and the filename. " +
|
||||
"Examples: '*.tmp' (temp files), '~$*' (Office lock files), 'Archive/*' (folder).",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
query: "Microsoft Authority Host:",
|
||||
@@ -1941,6 +1975,7 @@ export interface SalesforceConfig {
|
||||
export interface SharepointConfig {
|
||||
sites?: string[];
|
||||
include_site_pages?: boolean;
|
||||
treat_sharing_link_as_public?: boolean;
|
||||
include_site_documents?: boolean;
|
||||
authority_host?: string;
|
||||
graph_api_host?: string;
|
||||
|
||||
@@ -35,6 +35,20 @@ export function getXYearsAgo(yearsAgo: number) {
|
||||
return yearsAgoDate;
|
||||
}
|
||||
|
||||
export function normalizeDate(date: Date): Date {
|
||||
const normalizedDate = new Date(date);
|
||||
normalizedDate.setHours(0, 0, 0, 0);
|
||||
return normalizedDate;
|
||||
}
|
||||
|
||||
export function isAfterDate(date: Date, maxDate: Date): boolean {
|
||||
return normalizeDate(date).getTime() > normalizeDate(maxDate).getTime();
|
||||
}
|
||||
|
||||
export function isDateInFuture(date: Date): boolean {
|
||||
return isAfterDate(date, new Date());
|
||||
}
|
||||
|
||||
export const timestampToDateString = (timestamp: string) => {
|
||||
const date = new Date(timestamp);
|
||||
const year = date.getFullYear();
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { isAfterDate, normalizeDate } from "@/lib/dateUtils";
|
||||
import Calendar from "@/refresh-components/Calendar";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { useState } from "react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { SvgCalendar } from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
@@ -15,29 +16,59 @@ export interface InputDatePickerProps {
|
||||
setSelectedDate: (date: Date | null) => void;
|
||||
startYear?: number;
|
||||
disabled?: boolean;
|
||||
maxDate?: Date;
|
||||
}
|
||||
|
||||
function extractYear(date: Date | null): number {
|
||||
return (date ?? new Date()).getFullYear();
|
||||
}
|
||||
|
||||
function clampToMaxDate(date: Date, maxDate?: Date): Date {
|
||||
if (!maxDate || !isAfterDate(date, maxDate)) {
|
||||
return date;
|
||||
}
|
||||
|
||||
return normalizeDate(maxDate);
|
||||
}
|
||||
|
||||
export default function InputDatePicker({
|
||||
name,
|
||||
selectedDate,
|
||||
setSelectedDate,
|
||||
startYear = 1970,
|
||||
disabled = false,
|
||||
maxDate,
|
||||
}: InputDatePickerProps) {
|
||||
const validStartYear = Math.max(startYear, 1970);
|
||||
const currYear = extractYear(new Date());
|
||||
const years = Array(currYear - validStartYear + 1)
|
||||
.fill(currYear)
|
||||
.map((currYear, index) => currYear - index);
|
||||
const normalizedMaxDate = useMemo(
|
||||
() => (maxDate ? normalizeDate(maxDate) : undefined),
|
||||
[maxDate]
|
||||
);
|
||||
const currYear = Math.max(
|
||||
validStartYear,
|
||||
extractYear(normalizedMaxDate ?? new Date())
|
||||
);
|
||||
const years = useMemo(
|
||||
() =>
|
||||
Array(currYear - validStartYear + 1)
|
||||
.fill(currYear)
|
||||
.map((year, index) => year - index),
|
||||
[currYear, validStartYear]
|
||||
);
|
||||
const [open, setOpen] = useState(false);
|
||||
const [displayedMonth, setDisplayedMonth] = useState<Date>(
|
||||
selectedDate ?? new Date()
|
||||
clampToMaxDate(
|
||||
selectedDate ?? normalizedMaxDate ?? new Date(),
|
||||
normalizedMaxDate
|
||||
)
|
||||
);
|
||||
|
||||
function handleDateSelection(date: Date) {
|
||||
setSelectedDate(date);
|
||||
setDisplayedMonth(date);
|
||||
setOpen(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<Popover.Trigger asChild id={name} name={name}>
|
||||
@@ -68,7 +99,7 @@ export default function InputDatePicker({
|
||||
</InputSelect>
|
||||
<Button
|
||||
onClick={() => {
|
||||
const now = new Date();
|
||||
const now = normalizedMaxDate ?? new Date();
|
||||
setSelectedDate(now);
|
||||
setDisplayedMonth(now);
|
||||
setOpen(false);
|
||||
@@ -82,14 +113,16 @@ export default function InputDatePicker({
|
||||
selected={selectedDate ?? undefined}
|
||||
onSelect={(date) => {
|
||||
if (date) {
|
||||
setSelectedDate(date);
|
||||
setOpen(false);
|
||||
handleDateSelection(date);
|
||||
}
|
||||
}}
|
||||
month={displayedMonth}
|
||||
onMonthChange={setDisplayedMonth}
|
||||
fromDate={new Date(validStartYear, 0)}
|
||||
toDate={new Date()}
|
||||
disabled={
|
||||
normalizedMaxDate ? [{ after: normalizedMaxDate }] : undefined
|
||||
}
|
||||
startMonth={new Date(validStartYear, 0)}
|
||||
endMonth={normalizedMaxDate ?? new Date()}
|
||||
showOutsideDays={false}
|
||||
/>
|
||||
</Section>
|
||||
|
||||
@@ -83,6 +83,7 @@ import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import EnabledCount from "@/refresh-components/EnabledCount";
|
||||
import { useAppRouter } from "@/hooks/appNavigation";
|
||||
import { isDateInFuture } from "@/lib/dateUtils";
|
||||
import {
|
||||
deleteAgent,
|
||||
updateAgentFeaturedStatus,
|
||||
@@ -699,7 +700,14 @@ export default function AgentEditorPage({
|
||||
// Advanced
|
||||
llm_model_provider_override: Yup.string().nullable().optional(),
|
||||
llm_model_version_override: Yup.string().nullable().optional(),
|
||||
knowledge_cutoff_date: Yup.date().nullable().optional(),
|
||||
knowledge_cutoff_date: Yup.date()
|
||||
.nullable()
|
||||
.optional()
|
||||
.test(
|
||||
"knowledge-cutoff-date-not-in-future",
|
||||
"Knowledge cutoff date must be today or earlier.",
|
||||
(value) => !value || !isDateInFuture(value)
|
||||
),
|
||||
replace_base_system_prompt: Yup.boolean(),
|
||||
reminders: Yup.string().optional(),
|
||||
|
||||
@@ -1521,7 +1529,7 @@ export default function AgentEditorPage({
|
||||
<InputLayouts.Horizontal
|
||||
name="llm_model"
|
||||
title="Default Model"
|
||||
description="Select the LLM model to use for this agent. If not set, the user's default model will be used."
|
||||
description="This model will be used by Onyx by default in your chats."
|
||||
>
|
||||
<LLMSelector
|
||||
name="llm_model"
|
||||
@@ -1538,14 +1546,19 @@ export default function AgentEditorPage({
|
||||
<InputLayouts.Horizontal
|
||||
name="knowledge_cutoff_date"
|
||||
title="Knowledge Cutoff Date"
|
||||
description="Set the knowledge cutoff date for this agent. The agent will only use information up to this date."
|
||||
optional
|
||||
description="Documents with a last-updated date prior to this will be ignored."
|
||||
>
|
||||
<InputDatePickerField name="knowledge_cutoff_date" />
|
||||
<InputDatePickerField
|
||||
name="knowledge_cutoff_date"
|
||||
maxDate={new Date()}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
<InputLayouts.Horizontal
|
||||
name="replace_base_system_prompt"
|
||||
title="Overwrite System Prompt"
|
||||
description='Completely replace the base system prompt. This might affect response quality since it will also overwrite useful system instructions (e.g. "You (the LLM) can provide markdown and it will be rendered").'
|
||||
titleSuffix="(Not Recommended)"
|
||||
description='Remove the base system prompt which includes useful instructions (e.g. "You can use Markdown tables"). This may affect response quality.'
|
||||
>
|
||||
<SwitchField name="replace_base_system_prompt" />
|
||||
</InputLayouts.Horizontal>
|
||||
@@ -1555,6 +1568,7 @@ export default function AgentEditorPage({
|
||||
<InputLayouts.Vertical
|
||||
name="reminders"
|
||||
title="Reminders"
|
||||
optional
|
||||
>
|
||||
<InputTextAreaField
|
||||
name="reminders"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import type { Route } from "next";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import { SvgChevronRight, SvgUserManage, SvgUsers } from "@opal/icons";
|
||||
@@ -65,7 +66,7 @@ function GroupCard({ group }: GroupCardProps) {
|
||||
prominence="tertiary"
|
||||
tooltip="View group"
|
||||
aria-label="View group"
|
||||
onClick={() => router.push(`/admin/groups/${group.id}`)}
|
||||
onClick={() => router.push(`/admin/groups/${group.id}` as Route)}
|
||||
/>
|
||||
</Section>
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import type { Route } from "next";
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR from "swr";
|
||||
@@ -47,7 +48,7 @@ function GroupsPage() {
|
||||
/>
|
||||
<Button
|
||||
icon={SvgPlusCircle}
|
||||
onClick={() => router.push("/admin/groups/create")}
|
||||
onClick={() => router.push("/admin/groups/create" as Route)}
|
||||
>
|
||||
New Group
|
||||
</Button>
|
||||
|
||||
@@ -6,188 +6,64 @@ import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
test.use({ storageState: "admin_auth.json" });
|
||||
test.describe.configure({ mode: "parallel" });
|
||||
|
||||
interface AdminPageSnapshot {
|
||||
name: string;
|
||||
path: string;
|
||||
pageTitle: string;
|
||||
options?: {
|
||||
paragraphText?: string | RegExp;
|
||||
buttonName?: string;
|
||||
subHeaderText?: string;
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Discover all navigable admin pages by collecting links from the sidebar.
|
||||
* The sidebar is rendered on every `/admin/*` page, so we visit one admin
|
||||
* route and scrape the `<a>` elements that are present for the current
|
||||
* user / feature-flag configuration.
|
||||
*/
|
||||
async function discoverAdminPages(page: Page): Promise<string[]> {
|
||||
await page.goto("/admin/configuration/llm");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
{
|
||||
name: "Document Management - Explorer",
|
||||
path: "documents/explorer",
|
||||
pageTitle: "Document Explorer",
|
||||
},
|
||||
{
|
||||
name: "Connectors - Add Connector",
|
||||
path: "add-connector",
|
||||
pageTitle: "Add Connector",
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - Agents",
|
||||
path: "agents",
|
||||
pageTitle: "Agents",
|
||||
},
|
||||
{
|
||||
name: "Configuration - Document Processing",
|
||||
path: "configuration/document-processing",
|
||||
pageTitle: "Document Processing",
|
||||
},
|
||||
{
|
||||
name: "Document Management - Document Sets",
|
||||
path: "documents/sets",
|
||||
pageTitle: "Document Sets",
|
||||
options: {
|
||||
paragraphText:
|
||||
"Document Sets allow you to group logically connected documents into a single bundle. These can then be used as a filter when performing searches to control the scope of information Onyx searches over.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Integrations - Slack Integration",
|
||||
path: "bots",
|
||||
pageTitle: "Slack Integration",
|
||||
options: {
|
||||
paragraphText:
|
||||
"Setup Slack bots that connect to Onyx. Once setup, you will be able to ask questions to Onyx directly from Slack. Additionally, you can:",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - Standard Answers",
|
||||
path: "standard-answer",
|
||||
pageTitle: "Standard Answers",
|
||||
},
|
||||
{
|
||||
name: "Performance - Usage Statistics",
|
||||
path: "performance/usage",
|
||||
pageTitle: "Usage Statistics",
|
||||
},
|
||||
{
|
||||
name: "Document Management - Feedback",
|
||||
path: "documents/feedback",
|
||||
pageTitle: "Document Feedback",
|
||||
},
|
||||
{
|
||||
name: "Configuration - LLM",
|
||||
path: "configuration/llm",
|
||||
pageTitle: "Language Models",
|
||||
},
|
||||
{
|
||||
name: "Connectors - Existing Connectors",
|
||||
path: "indexing/status",
|
||||
pageTitle: "Existing Connectors",
|
||||
},
|
||||
{
|
||||
name: "User Management - Groups",
|
||||
path: "groups",
|
||||
pageTitle: "Groups",
|
||||
},
|
||||
{
|
||||
name: "Appearance & Theming",
|
||||
path: "theme",
|
||||
pageTitle: "Appearance & Theming",
|
||||
},
|
||||
{
|
||||
name: "Documents & Knowledge - Index Settings",
|
||||
path: "configuration/search",
|
||||
pageTitle: "Index Settings",
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - MCP Actions",
|
||||
path: "actions/mcp",
|
||||
pageTitle: "MCP Actions",
|
||||
},
|
||||
{
|
||||
name: "Custom Agents - OpenAPI Actions",
|
||||
path: "actions/open-api",
|
||||
pageTitle: "OpenAPI Actions",
|
||||
},
|
||||
{
|
||||
name: "Organization - Spending Limits",
|
||||
path: "token-rate-limits",
|
||||
pageTitle: "Spending Limits",
|
||||
options: {
|
||||
paragraphText:
|
||||
"Token rate limits enable you control how many tokens can be spent in a given time period. With token rate limits, you can:",
|
||||
buttonName: "Create a Token Rate Limit",
|
||||
},
|
||||
},
|
||||
];
|
||||
return page.evaluate(() => {
|
||||
const sidebar = document.querySelector('[class*="group/SidebarWrapper"]');
|
||||
if (!sidebar) return [];
|
||||
|
||||
async function verifyAdminPageNavigation(
|
||||
page: Page,
|
||||
path: string,
|
||||
pageTitle: string,
|
||||
options?: {
|
||||
paragraphText?: string | RegExp;
|
||||
buttonName?: string;
|
||||
subHeaderText?: string;
|
||||
}
|
||||
) {
|
||||
await page.goto(`/admin/${path}`);
|
||||
|
||||
try {
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
new RegExp(`^${pageTitle}`),
|
||||
{
|
||||
timeout: 10000,
|
||||
}
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to find admin-page title with text "${pageTitle}" for path "${path}"`
|
||||
);
|
||||
// NOTE: This is a temporary measure for debugging the issue
|
||||
console.error(await page.content());
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (options?.paragraphText) {
|
||||
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
|
||||
options.paragraphText
|
||||
);
|
||||
}
|
||||
|
||||
if (options?.buttonName) {
|
||||
await expect(
|
||||
page.getByRole("button", { name: options.buttonName })
|
||||
).toHaveCount(1);
|
||||
}
|
||||
const hrefs = new Set<string>();
|
||||
sidebar
|
||||
.querySelectorAll<HTMLAnchorElement>('a[href^="/admin/"]')
|
||||
.forEach((a) => hrefs.add(a.getAttribute("href")!));
|
||||
return Array.from(hrefs);
|
||||
});
|
||||
}
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Admin pages (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await setThemeBeforeNavigation(page, theme);
|
||||
});
|
||||
test(`Admin pages – ${theme} mode`, async ({ page }) => {
|
||||
await setThemeBeforeNavigation(page, theme);
|
||||
|
||||
for (const snapshot of ADMIN_PAGES) {
|
||||
test(`Admin - ${snapshot.name}`, async ({ page }) => {
|
||||
await verifyAdminPageNavigation(
|
||||
page,
|
||||
snapshot.path,
|
||||
snapshot.pageTitle,
|
||||
snapshot.options
|
||||
);
|
||||
const adminHrefs = await discoverAdminPages(page);
|
||||
expect(
|
||||
adminHrefs.length,
|
||||
"Expected to discover at least one admin page from the sidebar"
|
||||
).toBeGreaterThan(0);
|
||||
|
||||
// Wait for all network requests to settle before capturing the screenshot.
|
||||
await page.waitForLoadState("networkidle");
|
||||
for (const href of adminHrefs) {
|
||||
const slug = href.replace(/^\/admin\//, "").replace(/\//g, "--");
|
||||
|
||||
// Capture a screenshot for visual regression review.
|
||||
// The screenshot name includes the theme to keep light/dark baselines separate.
|
||||
const screenshotName = `admin-${theme}-${snapshot.path.replace(
|
||||
/\//g,
|
||||
"-"
|
||||
)}`;
|
||||
await expectScreenshot(page, {
|
||||
name: screenshotName,
|
||||
mask: ['[data-testid="admin-date-range-selector-button"]'],
|
||||
});
|
||||
});
|
||||
await test.step(
|
||||
slug,
|
||||
async () => {
|
||||
await page.goto(href);
|
||||
|
||||
try {
|
||||
await expect(
|
||||
page.locator('[aria-label="admin-page-title"]')
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
} catch (error) {
|
||||
console.error(`Failed to find admin-page-title for "${href}"`);
|
||||
throw error;
|
||||
}
|
||||
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await expectScreenshot(page, {
|
||||
name: `admin-${theme}-${slug}`,
|
||||
mask: ['[data-testid="admin-date-range-selector-button"]'],
|
||||
});
|
||||
},
|
||||
{ box: true }
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
6
widget/package-lock.json
generated
6
widget/package-lock.json
generated
@@ -1016,9 +1016,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
Reference in New Issue
Block a user