mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-30 03:52:42 +00:00
Compare commits
12 Commits
ref-file-t
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2ef7f2f83 | ||
|
|
cd561b21a2 | ||
|
|
7ca31368bb | ||
|
|
6dcbde2c03 | ||
|
|
9df6b6183a | ||
|
|
8fb7cd6189 | ||
|
|
2724c61c95 | ||
|
|
a23ee85039 | ||
|
|
d4d0f3c612 | ||
|
|
8e1ad517e9 | ||
|
|
4a9c8b6fbf | ||
|
|
a49edf3e18 |
@@ -1,64 +0,0 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
[
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/AGENTS.md",
|
||||
"description": "Frontend coding standards for the web directory"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/tests/README.md",
|
||||
"description": "Frontend testing guide and conventions"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/CLAUDE.md",
|
||||
"description": "Single source of truth for frontend coding standards"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/lib/opal/README.md",
|
||||
"description": "Opal component library usage guide"
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**"],
|
||||
"path": "backend/tests/README.md",
|
||||
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
|
||||
},
|
||||
{
|
||||
"scope": ["backend/onyx/connectors/**"],
|
||||
"path": "backend/onyx/connectors/README.md",
|
||||
"description": "Connector development guide covering design, interfaces, and required changes"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "backend/alembic/README.md",
|
||||
"description": "Migration guidance, including multi-tenant migration behavior"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/helm/charts/onyx/values-lite.yaml",
|
||||
"description": "Lite deployment Helm values and service assumptions"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
|
||||
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
|
||||
}
|
||||
]
|
||||
@@ -1,39 +0,0 @@
|
||||
# Greptile Review Rules
|
||||
|
||||
## Type Annotations
|
||||
|
||||
Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.
|
||||
|
||||
## Best Practices
|
||||
|
||||
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
|
||||
|
||||
## TODOs
|
||||
|
||||
Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`
|
||||
|
||||
## Debugging Code
|
||||
|
||||
Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.
|
||||
|
||||
## Hardcoded Booleans
|
||||
|
||||
When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.
|
||||
|
||||
## Multi-tenant vs Single-tenant
|
||||
|
||||
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
|
||||
|
||||
## Nginx Routing — New Backend Routes
|
||||
|
||||
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
|
||||
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
|
||||
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
|
||||
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
|
||||
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
|
||||
|
||||
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
|
||||
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
@@ -1,35 +0,0 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -473,8 +473,6 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
@@ -106,11 +105,9 @@ def _get_slack_document_access(
|
||||
slack_connector: SlackConnector,
|
||||
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
indexing_start: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback,
|
||||
start=indexing_start,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
@@ -183,15 +180,9 @@ def slack_doc_sync(
|
||||
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.set_credentials_provider(provider)
|
||||
indexing_start_ts: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
yield from _get_slack_document_access(
|
||||
slack_connector=slack_connector,
|
||||
slack_connector,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
indexing_start=indexing_start_ts,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -41,19 +40,10 @@ def generic_doc_sync(
|
||||
|
||||
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
|
||||
|
||||
indexing_start: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=indexing_start,
|
||||
callback=callback,
|
||||
):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -159,114 +148,3 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,84 @@
|
||||
import queue
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets produced during tool and LLM execution to the right destination.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Operates in one of two modes determined by whether ``merged_queue`` is supplied:
|
||||
|
||||
**Standalone** (no ``merged_queue``): packets land on ``self.bus``. Used by tests,
|
||||
custom tools, and any caller that reads the emitter directly after execution.
|
||||
|
||||
**Streaming** (``merged_queue`` provided): packets are tagged with ``model_index``
|
||||
and placed as ``(key, packet)`` tuples on the shared queue for the
|
||||
``_run_models`` drain loop to consume and yield downstream.
|
||||
|
||||
Attributes:
|
||||
bus: Fallback queue for standalone mode. Always created so existing callers
|
||||
(tests, eval harnesses, custom-tool scripts) work without modification.
|
||||
|
||||
Args:
|
||||
model_idx: Index embedded in packet placements. Pass ``None`` for single-model
|
||||
runs to preserve the backwards-compatible wire format (``model_index=None``
|
||||
in the packet); pass an integer for each model in a multi-model run.
|
||||
merged_queue: Shared queue owned by the ``_run_models`` drain loop. When set,
|
||||
all ``emit()`` calls route here instead of ``self.bus``.
|
||||
|
||||
Example::
|
||||
|
||||
# Standalone — read from bus after the fact (tests, evals)
|
||||
emitter = Emitter()
|
||||
emitter.emit(packet)
|
||||
result = emitter.bus.get()
|
||||
|
||||
# Streaming — wired into _run_models (production path)
|
||||
emitter = Emitter(model_idx=0, merged_queue=merged_queue)
|
||||
emitter.emit(packet) # places (0, tagged_packet) on merged_queue
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_idx: int | None = None,
|
||||
merged_queue: "queue.Queue | None" = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
# Always created for backwards compatibility (tests, custom_tool, customer scripts, etc.)
|
||||
self.bus: Queue[Packet] = Queue()
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
"""Emit a packet, routing it to the merged queue or the local bus.
|
||||
|
||||
In streaming mode, stamps the packet's placement with ``model_index`` before
|
||||
forwarding so the drain loop can attribute it to the correct model. In
|
||||
standalone mode, places the packet on ``self.bus`` unchanged.
|
||||
|
||||
Args:
|
||||
packet: The packet to emit.
|
||||
"""
|
||||
if self._merged_queue is not None:
|
||||
tagged_placement = Placement(
|
||||
turn_index=packet.placement.turn_index if packet.placement else 0,
|
||||
tab_index=packet.placement.tab_index if packet.placement else 0,
|
||||
sub_turn_index=(
|
||||
packet.placement.sub_turn_index if packet.placement else None
|
||||
),
|
||||
model_index=self._model_idx,
|
||||
)
|
||||
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
|
||||
key = self._model_idx if self._model_idx is not None else 0
|
||||
try:
|
||||
self._merged_queue.put((key, tagged_packet), timeout=1.0)
|
||||
except queue.Full:
|
||||
# Drain loop is gone (e.g. GeneratorExit on disconnect); discard packet.
|
||||
pass
|
||||
else:
|
||||
self.bus.put(packet)
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
return Emitter()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,31 +44,6 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
|
||||
# Hard ceiling for the admin-configurable file upload size (in MB).
|
||||
# Self-hosted customers can raise or lower this via the environment variable.
|
||||
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
|
||||
if _raw_max_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
|
||||
_raw_max_upload_size_mb,
|
||||
)
|
||||
_raw_max_upload_size_mb = 250
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
|
||||
|
||||
# Default fallback for the per-user file upload size limit (in MB) when no
|
||||
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
|
||||
# runtime so this never silently exceeds the hard ceiling.
|
||||
_raw_default_upload_size_mb = int(
|
||||
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
|
||||
)
|
||||
if _raw_default_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
|
||||
_raw_default_upload_size_mb,
|
||||
)
|
||||
_raw_default_upload_size_mb = 100
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
@@ -86,6 +61,17 @@ CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Requests timeout in seconds.
|
||||
_CANVAS_CALL_TIMEOUT: int = 30
|
||||
_CANVAS_API_VERSION: str = "/api/v1"
|
||||
# Matches the "next" URL in a Canvas Link header, e.g.:
|
||||
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
|
||||
# Captures the URL inside the angle brackets.
|
||||
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
|
||||
|
||||
|
||||
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
return OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
class CanvasApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
bearer_token: str,
|
||||
canvas_base_url: str,
|
||||
) -> None:
|
||||
parsed_base = urlparse(canvas_base_url)
|
||||
if not parsed_base.hostname:
|
||||
raise ValueError("canvas_base_url must include a valid host")
|
||||
if parsed_base.scheme != "https":
|
||||
raise ValueError("canvas_base_url must use https")
|
||||
|
||||
self._bearer_token = bearer_token
|
||||
self.base_url = (
|
||||
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
|
||||
+ _CANVAS_API_VERSION
|
||||
)
|
||||
# Hostname is already validated above; reuse parsed_base instead
|
||||
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
|
||||
self._expected_host: str = parsed_base.hostname
|
||||
|
||||
def get(
|
||||
self,
|
||||
endpoint: str = "",
|
||||
params: dict[str, Any] | None = None,
|
||||
full_url: str | None = None,
|
||||
) -> tuple[Any, str | None]:
|
||||
"""Make a GET request to the Canvas API.
|
||||
|
||||
Returns a tuple of (json_body, next_url).
|
||||
next_url is parsed from the Link header and is None if there are no more pages.
|
||||
If full_url is provided, it is used directly (for following pagination links).
|
||||
|
||||
Security note: full_url must only be set to values returned by
|
||||
``_parse_next_link``, which validates the host against the configured
|
||||
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
|
||||
"""
|
||||
# full_url is used when following pagination (Canvas returns the
|
||||
# next-page URL in the Link header). For the first request we build
|
||||
# the URL from the endpoint name instead.
|
||||
url = full_url if full_url else self._build_url(endpoint)
|
||||
headers = self._build_headers()
|
||||
|
||||
response = rl_requests.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params=params if not full_url else None,
|
||||
timeout=_CANVAS_CALL_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = response.json()
|
||||
except ValueError as e:
|
||||
if response.status_code < 300:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=f"Invalid JSON in Canvas response: {e}",
|
||||
)
|
||||
logger.warning(
|
||||
"Failed to parse JSON from Canvas error response (status=%d): %s",
|
||||
response.status_code,
|
||||
e,
|
||||
)
|
||||
response_json = {}
|
||||
|
||||
if response.status_code >= 400:
|
||||
# Try to extract the most specific error message from the
|
||||
# Canvas response body. Canvas uses three different shapes
|
||||
# depending on the endpoint and error type:
|
||||
default_error: str = response.reason or f"HTTP {response.status_code}"
|
||||
error = default_error
|
||||
if isinstance(response_json, dict):
|
||||
# Shape 1: {"error": {"message": "Not authorized"}}
|
||||
error_field = response_json.get("error")
|
||||
if isinstance(error_field, dict):
|
||||
response_error = error_field.get("message", "")
|
||||
if response_error:
|
||||
error = response_error
|
||||
# Shape 2: {"error": "Invalid access token"}
|
||||
elif isinstance(error_field, str):
|
||||
error = error_field
|
||||
# Shape 3: {"errors": [{"message": "..."}]}
|
||||
# Used for validation errors. Only use as fallback if
|
||||
# we didn't already find a more specific message above.
|
||||
if error == default_error:
|
||||
errors_list = response_json.get("errors")
|
||||
if isinstance(errors_list, list) and errors_list:
|
||||
first_error = errors_list[0]
|
||||
if isinstance(first_error, dict):
|
||||
msg = first_error.get("message", "")
|
||||
if msg:
|
||||
error = msg
|
||||
raise OnyxError(
|
||||
_error_code_for_status(response.status_code),
|
||||
detail=error,
|
||||
status_code_override=response.status_code,
|
||||
)
|
||||
|
||||
next_url = self._parse_next_link(response.headers.get("Link", ""))
|
||||
return response_json, next_url
|
||||
|
||||
def _parse_next_link(self, link_header: str) -> str | None:
|
||||
"""Extract the 'next' URL from a Canvas Link header.
|
||||
|
||||
Only returns URLs whose host matches the configured Canvas base URL
|
||||
to prevent leaking the bearer token to arbitrary hosts.
|
||||
"""
|
||||
expected_host = self._expected_host
|
||||
for match in _NEXT_LINK_PATTERN.finditer(link_header):
|
||||
url = match.group(1)
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.hostname != expected_host:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=(
|
||||
"Canvas pagination returned an unexpected host "
|
||||
f"({parsed_url.hostname}); expected {expected_host}"
|
||||
),
|
||||
)
|
||||
if parsed_url.scheme != "https":
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=(
|
||||
"Canvas pagination link must use https, "
|
||||
f"got {parsed_url.scheme!r}"
|
||||
),
|
||||
)
|
||||
return url
|
||||
return None
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Return the Authorization header with the bearer token."""
|
||||
return {"Authorization": f"Bearer {self._bearer_token}"}
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
"""Build a full Canvas API URL from an endpoint path.
|
||||
|
||||
Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
|
||||
Only called on a first request, endpoint must be set for first request.
|
||||
Verify endpoint exists in case of future changes where endpoint might be optional.
|
||||
Leading slashes are stripped to avoid double-slash in the result.
|
||||
self.base_url is already normalized with no trailing slash.
|
||||
"""
|
||||
final_url = self.base_url
|
||||
clean_endpoint = endpoint.lstrip("/")
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
@@ -1,74 +0,0 @@
|
||||
from typing import Literal
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
course_code: str
|
||||
created_at: str
|
||||
workflow_state: str
|
||||
|
||||
|
||||
class CanvasPage(BaseModel):
|
||||
page_id: int
|
||||
url: str
|
||||
title: str
|
||||
body: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
course_id: int
|
||||
|
||||
|
||||
class CanvasAssignment(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
html_url: str
|
||||
course_id: int
|
||||
created_at: str
|
||||
updated_at: str
|
||||
due_at: str | None = None
|
||||
|
||||
|
||||
class CanvasAnnouncement(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
message: str | None = None
|
||||
html_url: str
|
||||
posted_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
|
||||
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint state for resumable Canvas indexing.
|
||||
|
||||
Fields:
|
||||
course_ids: Materialized list of course IDs to process.
|
||||
current_course_index: Index into course_ids for current course.
|
||||
stage: Which item type we're processing for the current course.
|
||||
next_url: Pagination cursor within the current stage. None means
|
||||
start from the first page; a URL means resume from that page.
|
||||
|
||||
Invariant:
|
||||
If current_course_index is incremented, stage must be reset to
|
||||
"pages" and next_url must be reset to None.
|
||||
"""
|
||||
|
||||
course_ids: list[int] = []
|
||||
current_course_index: int = 0
|
||||
stage: CanvasStage = "pages"
|
||||
next_url: str | None = None
|
||||
|
||||
def advance_course(self) -> None:
|
||||
"""Move to the next course and reset within-course state."""
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
@@ -890,8 +890,8 @@ class ConfluenceConnector(
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
@@ -915,8 +915,8 @@ class ConfluenceConnector(
|
||||
self.confluence_client, doc_id, restrictions, ancestors
|
||||
) or space_level_access_info.get(page_space_key)
|
||||
|
||||
# Query pages (with optional time filtering for indexing_start)
|
||||
page_query = self._construct_page_cql_query(start, end)
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -950,9 +950,7 @@ class ConfluenceConnector(
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -10,7 +10,6 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -240,53 +239,29 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
raise e
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in raw_issues
|
||||
for issue in response["issues"]
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1765,11 +1765,7 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
@@ -1790,9 +1786,7 @@ class SharepointConnector(
|
||||
# Process site documents if flag is True
|
||||
if self.include_site_documents:
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor,
|
||||
start=start,
|
||||
end=end,
|
||||
site_descriptor=site_descriptor
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
@@ -1847,9 +1841,7 @@ class SharepointConnector(
|
||||
|
||||
# Process site pages if flag is True
|
||||
if self.include_site_pages:
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
site_pages = self._fetch_site_pages(site_descriptor)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
@@ -2573,22 +2565,12 @@ class SharepointConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
start_dt = (
|
||||
datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
if start is not None
|
||||
else None
|
||||
)
|
||||
end_dt = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
|
||||
)
|
||||
yield from self._fetch_slim_documents_from_sharepoint(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
yield from self._fetch_slim_documents_from_sharepoint()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -516,8 +516,6 @@ def _get_all_doc_ids(
|
||||
] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
workspace_url: str | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -548,8 +546,6 @@ def _get_all_doc_ids(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
oldest=str(start) if start else None, # 0.0 -> None intentionally
|
||||
latest=str(end) if end is not None else None,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
@@ -851,8 +847,8 @@ class SlackConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
@@ -865,8 +861,6 @@ class SlackConnector(
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(
|
||||
|
||||
@@ -617,6 +617,92 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -3135,6 +3135,8 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
@@ -35,19 +34,9 @@ class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
rejected_files: list[RejectedFile]
|
||||
id_to_temp_id: dict[str, str]
|
||||
# Filenames that should be stored but not indexed.
|
||||
skip_indexing_filenames: set[str] = Field(default_factory=set)
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@property
|
||||
def indexable_files(self) -> list[UserFile]:
|
||||
return [
|
||||
uf
|
||||
for uf in self.user_files
|
||||
if (uf.name or "") not in self.skip_indexing_filenames
|
||||
]
|
||||
|
||||
|
||||
def build_hashed_file_key(file: UploadFile) -> str:
|
||||
name_prefix = (file.filename or "")[:50]
|
||||
@@ -109,7 +98,6 @@ def create_user_files(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files.skip_indexing,
|
||||
)
|
||||
|
||||
|
||||
@@ -135,7 +123,6 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files = categorized_files_result.user_files
|
||||
rejected_files = categorized_files_result.rejected_files
|
||||
id_to_temp_id = categorized_files_result.id_to_temp_id
|
||||
indexable_files = categorized_files_result.indexable_files
|
||||
# Trigger per-file processing immediately for the current tenant
|
||||
tenant_id = get_current_tenant_id()
|
||||
for rejected_file in rejected_files:
|
||||
@@ -147,12 +134,12 @@ def upload_files_to_user_files_with_indexing(
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in indexable_files:
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in indexable_files:
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
@@ -168,7 +155,6 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,30 +17,39 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +58,9 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
@@ -108,10 +119,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Delete a voice provider by ID."""
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
db_session.delete(provider)
|
||||
provider.deleted = True
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
|
||||
class OnyxMimeTypes:
|
||||
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
|
||||
CSV_MIME_TYPES = {"text/csv"}
|
||||
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
|
||||
TEXT_MIME_TYPES = {
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
@@ -35,12 +34,13 @@ class OnyxMimeTypes:
|
||||
PDF_MIME_TYPE,
|
||||
WORD_PROCESSING_MIME_TYPE,
|
||||
PRESENTATION_MIME_TYPE,
|
||||
SPREADSHEET_MIME_TYPE,
|
||||
"message/rfc822",
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
|
||||
)
|
||||
|
||||
EXCLUDED_IMAGE_TYPES = {
|
||||
|
||||
@@ -13,14 +13,13 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
# Tabular data files (CSV, TSV, XLSX) — metadata-only injection
|
||||
TABULAR = "tabular"
|
||||
CSV = "csv"
|
||||
|
||||
def is_text_file(self) -> bool:
|
||||
return self in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.TABULAR,
|
||||
ChatFileType.CSV,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -174,10 +173,8 @@ class UserFileIndexingAdapter:
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count: int = (
|
||||
count_tokens(combined_content, llm_tokenizer)
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
)
|
||||
user_file_id_to_token_count[str(user_file_id)] = token_count
|
||||
else:
|
||||
|
||||
@@ -25,7 +25,6 @@ class LlmProviderNames(str, Enum):
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -45,7 +44,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
]
|
||||
|
||||
|
||||
@@ -63,7 +61,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -115,7 +112,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -290,17 +290,6 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
self._custom_llm_provider = "openai"
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# This is needed for Ollama to do proper function calling
|
||||
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
|
||||
model_kwargs["api_base"] = api_base
|
||||
@@ -412,20 +401,14 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
model = (
|
||||
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
)
|
||||
|
||||
# Tool choice
|
||||
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
@@ -500,11 +483,10 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
if not (is_claude_model or is_ollama or is_mistral):
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included unless the request is
|
||||
# routed through Bifrost's OpenAI-compatible endpoint.
|
||||
# so it must be conditionally included.
|
||||
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
|
||||
# See also, https://github.com/ollama/ollama/issues/11171
|
||||
optional_kwargs["allowed_openai_params"] = ["tool_choice"]
|
||||
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -13,8 +13,6 @@ LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
@@ -50,7 +49,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -175,32 +175,6 @@ def get_tokenizer(
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
# Max characters per encode() call.
|
||||
_ENCODE_CHUNK_SIZE = 500_000
|
||||
|
||||
|
||||
def count_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
token_limit: int | None = None,
|
||||
) -> int:
|
||||
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
|
||||
|
||||
If token_limit is provided and the text is large enough to require
|
||||
multiple chunks (> 500k chars), stops early once the count exceeds it.
|
||||
When early-exiting, the returned value exceeds token_limit but may be
|
||||
less than the true full token count.
|
||||
"""
|
||||
if len(text) <= _ENCODE_CHUNK_SIZE:
|
||||
return len(tokenizer.encode(text))
|
||||
total = 0
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
if token_limit is not None and total > token_limit:
|
||||
return total # Already over — skip remaining chunks
|
||||
return total
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -44,12 +44,11 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -142,11 +141,19 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
@@ -9,15 +9,20 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -76,26 +81,11 @@ class CategorizedFiles(BaseModel):
|
||||
acceptable: list[UploadFile] = Field(default_factory=list)
|
||||
rejected: list[RejectedFile] = Field(default_factory=list)
|
||||
acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
|
||||
# Filenames within `acceptable` that should be stored but not indexed.
|
||||
skip_indexing: set[str] = Field(default_factory=set)
|
||||
|
||||
# Allow FastAPI UploadFile instances
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
# Extensions that bypass the token-count threshold on upload.
|
||||
_TOKEN_THRESHOLD_EXEMPT_EXTENSIONS: set[str] = {
|
||||
".csv",
|
||||
".tsv",
|
||||
".xlsx",
|
||||
}
|
||||
|
||||
|
||||
def _skip_token_threshold(extension: str) -> bool:
|
||||
"""Return True if this file extension should bypass the token limit."""
|
||||
return extension.lower() in _TOKEN_THRESHOLD_EXEMPT_EXTENSIONS
|
||||
|
||||
|
||||
def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
|
||||
if max(width, height) <= cap:
|
||||
return width, height
|
||||
@@ -171,8 +161,8 @@ def categorize_uploaded_files(
|
||||
document formats (.pdf, .docx, …) and falls back to a text-detection
|
||||
heuristic for unknown extensions (.py, .js, .rs, …).
|
||||
- Uses default tokenizer to compute token length.
|
||||
- If token length exceeds the admin-configured threshold, reject file.
|
||||
- If extension unsupported or text cannot be extracted, reject file.
|
||||
- If token length > threshold, reject file (unless threshold skip is enabled).
|
||||
- If text cannot be extracted, reject file.
|
||||
- Otherwise marked as acceptable.
|
||||
"""
|
||||
|
||||
@@ -183,33 +173,36 @@ def categorize_uploaded_files(
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Derive limits from admin-configurable settings.
|
||||
# For upload size: load_settings() resolves 0/None to a positive default.
|
||||
# For token threshold: 0 means "no limit" (converted to None below).
|
||||
settings = load_settings()
|
||||
max_upload_size_mb = (
|
||||
settings.user_file_max_upload_size_mb
|
||||
) # always positive after load_settings()
|
||||
max_upload_size_bytes = (
|
||||
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
|
||||
)
|
||||
token_threshold_k = settings.file_token_count_threshold_k
|
||||
token_threshold = (
|
||||
token_threshold_k * 1000 if token_threshold_k else None
|
||||
) # 0 → None = no limit
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
# Check global skip flag (works for both single-tenant and multi-tenant)
|
||||
if SKIP_USERFILE_THRESHOLD:
|
||||
skip_threshold = True
|
||||
logger.info("Skipping userfile threshold check (global setting)")
|
||||
# Check tenant-specific skip list (only applicable in multi-tenant)
|
||||
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
|
||||
try:
|
||||
current_tenant_id = get_current_tenant_id()
|
||||
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
if skip_threshold:
|
||||
logger.info(
|
||||
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap.
|
||||
if max_upload_size_bytes is not None and is_upload_too_large(
|
||||
upload, max_upload_size_bytes
|
||||
):
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -231,11 +224,11 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -276,24 +269,12 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
exceeds_threshold = (
|
||||
token_threshold is not None and token_count > token_threshold
|
||||
)
|
||||
if exceeds_threshold and _skip_token_threshold(extension):
|
||||
# Exempt extensions (e.g. spreadsheets) are accepted
|
||||
# but flagged to skip indexing — only metadata is
|
||||
# injected into the LLM context.
|
||||
results.acceptable.append(upload)
|
||||
results.acceptable_file_to_token_count[filename] = token_count
|
||||
results.skip_indexing.add(filename)
|
||||
elif exceeds_threshold:
|
||||
token_count = len(tokenizer.encode(text_content))
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -57,8 +57,6 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
@@ -1424,26 +1422,11 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="LiteLLM proxy",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_openai_compatible_models_response(
|
||||
url: str,
|
||||
source_name: str,
|
||||
api_key: str | None = None,
|
||||
) -> dict:
|
||||
"""Fetch model metadata from an OpenAI-compatible `/models` endpoint."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if not api_key:
|
||||
headers.pop("Authorization")
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
@@ -1453,125 +1436,20 @@ def _get_openai_compatible_models_response(
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Authentication failed: invalid or missing API key for {source_name}.",
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"{source_name} models endpoint not found at {url}. Please verify the API base URL.",
|
||||
f"LiteLLM models endpoint not found at {url}. Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
logger.warning(
|
||||
"Failed to fetch models from OpenAI-compatible endpoint",
|
||||
extra={"source": source_name, "url": url, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Received invalid model response from OpenAI-compatible endpoint",
|
||||
extra={"source": source_name, "url": url, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/bifrost/available-models")
|
||||
def get_bifrost_available_models(
|
||||
request: BifrostModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Bifrost endpoint",
|
||||
)
|
||||
|
||||
results: list[BifrostFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
BifrostFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Bifrost model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Bifrost",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Bifrost",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> dict:
|
||||
"""Perform GET to Bifrost /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -449,18 +449,3 @@ class LitellmModelDetails(BaseModel):
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
|
||||
# Bifrost dynamic models fetch
|
||||
class BifrostModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class BifrostFinalModelResponse(BaseModel):
|
||||
name: str # Model ID in provider/model format (e.g. "anthropic/claude-sonnet-4-6")
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -25,7 +25,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -51,25 +50,6 @@ BEDROCK_VISION_MODELS = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
# Known Bifrost/OpenAI-compatible vision-capable model families where the
|
||||
# source API does not expose this metadata directly.
|
||||
BIFROST_VISION_MODEL_FAMILIES = frozenset(
|
||||
{
|
||||
"anthropic/claude-3",
|
||||
"anthropic/claude-4",
|
||||
"amazon/nova-pro",
|
||||
"amazon/nova-lite",
|
||||
"amazon/nova-premier",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4.1",
|
||||
"google/gemini",
|
||||
"meta-llama/llama-3.2",
|
||||
"mistral/pixtral",
|
||||
"qwen/qwen2.5-vl",
|
||||
"qwen/qwen-vl",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_valid_bedrock_model(
|
||||
model_id: str,
|
||||
@@ -96,18 +76,11 @@ def is_valid_bedrock_model(
|
||||
def infer_vision_support(model_id: str) -> bool:
|
||||
"""Infer vision support from model ID when base model metadata unavailable.
|
||||
|
||||
Used for providers like Bedrock and Bifrost where vision support may
|
||||
need to be inferred from vendor/model naming conventions.
|
||||
Used for cross-region inference profiles when the base model isn't
|
||||
available in the user's region.
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
if any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS):
|
||||
return True
|
||||
|
||||
normalized_model_id = model_id_lower.replace(".", "/")
|
||||
return any(
|
||||
vision_model in normalized_model_id
|
||||
for vision_model in BIFROST_VISION_MODEL_FAMILIES
|
||||
)
|
||||
return any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS)
|
||||
|
||||
|
||||
def generate_bedrock_display_name(model_id: str) -> str:
|
||||
@@ -349,7 +322,7 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
- Ollama: "llama3:70b" → "Meta"
|
||||
- Ollama: "qwen2.5:7b" → "Alibaba"
|
||||
"""
|
||||
if provider in (LlmProviderNames.OPENROUTER, LlmProviderNames.BIFROST):
|
||||
if provider == LlmProviderNames.OPENROUTER:
|
||||
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
|
||||
if "/" in model_name:
|
||||
vendor_key = model_name.split("/")[0].lower()
|
||||
|
||||
@@ -449,128 +449,40 @@ class RedisHealthCollector(_CachedCollector):
|
||||
return [memory_used, memory_peak, memory_frag, connected_clients]
|
||||
|
||||
|
||||
class WorkerHeartbeatMonitor:
|
||||
"""Monitors Celery worker health via the event stream.
|
||||
|
||||
Subscribes to ``worker-heartbeat``, ``worker-online``, and
|
||||
``worker-offline`` events via a single persistent connection.
|
||||
Runs in a daemon thread started once during worker setup.
|
||||
"""
|
||||
|
||||
# Consider a worker down if no heartbeat received for this long.
|
||||
_HEARTBEAT_TIMEOUT_SECONDS = 120.0
|
||||
|
||||
def __init__(self, celery_app: Any) -> None:
|
||||
self._app = celery_app
|
||||
self._worker_last_seen: dict[str, float] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background event listener thread.
|
||||
|
||||
Safe to call multiple times — only starts one thread.
|
||||
"""
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
return
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._listen, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("WorkerHeartbeatMonitor started")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
def _listen(self) -> None:
|
||||
"""Background loop: connect to event stream and process heartbeats."""
|
||||
while self._running:
|
||||
try:
|
||||
with self._app.connection() as conn:
|
||||
recv = self._app.events.Receiver(
|
||||
conn,
|
||||
handlers={
|
||||
"worker-heartbeat": self._on_heartbeat,
|
||||
"worker-online": self._on_heartbeat,
|
||||
"worker-offline": self._on_offline,
|
||||
},
|
||||
)
|
||||
recv.capture(
|
||||
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
|
||||
)
|
||||
except Exception:
|
||||
if self._running:
|
||||
logger.debug(
|
||||
"Heartbeat listener disconnected, reconnecting in 5s",
|
||||
exc_info=True,
|
||||
)
|
||||
time.sleep(5.0)
|
||||
else:
|
||||
# capture() returned normally (timeout with no events); reconnect
|
||||
if self._running:
|
||||
logger.debug("Heartbeat capture timed out, reconnecting")
|
||||
time.sleep(5.0)
|
||||
|
||||
def _on_heartbeat(self, event: dict[str, Any]) -> None:
|
||||
hostname = event.get("hostname")
|
||||
if hostname:
|
||||
with self._lock:
|
||||
self._worker_last_seen[hostname] = time.monotonic()
|
||||
|
||||
def _on_offline(self, event: dict[str, Any]) -> None:
|
||||
hostname = event.get("hostname")
|
||||
if hostname:
|
||||
with self._lock:
|
||||
self._worker_last_seen.pop(hostname, None)
|
||||
|
||||
def get_worker_status(self) -> dict[str, bool]:
|
||||
"""Return {hostname: is_alive} for all known workers.
|
||||
|
||||
Thread-safe. Called by WorkerHealthCollector on each scrape.
|
||||
Also prunes workers that have been dead longer than 2x the
|
||||
heartbeat timeout to prevent unbounded growth.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
prune_threshold = self._HEARTBEAT_TIMEOUT_SECONDS * 2
|
||||
with self._lock:
|
||||
# Prune workers that have been gone for 2x the timeout
|
||||
stale = [
|
||||
h
|
||||
for h, ts in self._worker_last_seen.items()
|
||||
if (now - ts) > prune_threshold
|
||||
]
|
||||
for h in stale:
|
||||
del self._worker_last_seen[h]
|
||||
|
||||
result: dict[str, bool] = {}
|
||||
for hostname, last_seen in self._worker_last_seen.items():
|
||||
alive = (now - last_seen) < self._HEARTBEAT_TIMEOUT_SECONDS
|
||||
result[hostname] = alive
|
||||
return result
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
|
||||
"""Collects Celery worker health from the heartbeat monitor.
|
||||
"""Collects Celery worker count and process count via inspect ping.
|
||||
|
||||
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
|
||||
to the Celery event stream via a single persistent connection.
|
||||
Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
|
||||
command that takes a couple seconds to complete.
|
||||
|
||||
Maintains a set of known worker short-names so that when a worker
|
||||
stops responding, we emit ``up=0`` instead of silently dropping the
|
||||
metric (which would make ``absent()``-style alerts impossible).
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = 30.0) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._monitor: WorkerHeartbeatMonitor | None = None
|
||||
# Remove a worker from _known_workers after this many consecutive
|
||||
# missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
|
||||
_MAX_CONSECUTIVE_MISSES = 10
|
||||
|
||||
def set_monitor(self, monitor: WorkerHeartbeatMonitor) -> None:
|
||||
"""Set the heartbeat monitor instance."""
|
||||
self._monitor = monitor
|
||||
def __init__(self, cache_ttl: float = 60.0) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
# worker short-name → consecutive miss count.
|
||||
# Workers start at 0 and reset to 0 each time they respond.
|
||||
# Removed after _MAX_CONSECUTIVE_MISSES missed collects.
|
||||
self._known_workers: dict[str, int] = {}
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app instance for inspect commands."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._monitor is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
active_workers = GaugeMetricFamily(
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers with recent heartbeats",
|
||||
"Number of active Celery workers responding to ping",
|
||||
)
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
@@ -579,15 +491,37 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
)
|
||||
|
||||
try:
|
||||
status = self._monitor.get_worker_status()
|
||||
alive_count = sum(1 for alive in status.values() if alive)
|
||||
active_workers.add_metric([], alive_count)
|
||||
inspector = self._celery_app.control.inspect(timeout=3.0)
|
||||
ping_result = inspector.ping()
|
||||
|
||||
for hostname in sorted(status):
|
||||
# Use short name (before @) for single-host deployments,
|
||||
# full hostname when multiple hosts share a worker type.
|
||||
label = hostname.split("@")[0]
|
||||
worker_up.add_metric([label], 1 if status[hostname] else 0)
|
||||
responding: set[str] = set()
|
||||
if ping_result:
|
||||
active_workers.add_metric([], len(ping_result))
|
||||
for worker_name in ping_result:
|
||||
# Strip hostname suffix for cleaner labels
|
||||
short_name = worker_name.split("@")[0]
|
||||
responding.add(short_name)
|
||||
else:
|
||||
active_workers.add_metric([], 0)
|
||||
|
||||
# Register newly-seen workers and reset miss count for
|
||||
# workers that responded.
|
||||
for short_name in responding:
|
||||
self._known_workers[short_name] = 0
|
||||
|
||||
# Increment miss count for non-responding workers and evict
|
||||
# those that have been missing too long.
|
||||
stale = []
|
||||
for short_name in list(self._known_workers):
|
||||
if short_name not in responding:
|
||||
self._known_workers[short_name] += 1
|
||||
if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
|
||||
stale.append(short_name)
|
||||
for short_name in stale:
|
||||
del self._known_workers[short_name]
|
||||
|
||||
for short_name in sorted(self._known_workers):
|
||||
worker_up.add_metric([short_name], 1 if short_name in responding else 0)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,7 +28,6 @@ _attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
@@ -98,16 +96,7 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
# Module-level singleton prevents duplicate threads on re-entry.
|
||||
global _heartbeat_monitor
|
||||
if _heartbeat_monitor is None:
|
||||
_heartbeat_monitor = WorkerHeartbeatMonitor(celery_app)
|
||||
_heartbeat_monitor.start()
|
||||
_worker_health_collector.set_monitor(_heartbeat_monitor)
|
||||
|
||||
_worker_health_collector.set_celery_app(celery_app)
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
|
||||
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
|
||||
return ChatFileType.IMAGE
|
||||
|
||||
if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
|
||||
return ChatFileType.TABULAR
|
||||
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
|
||||
return ChatFileType.CSV
|
||||
|
||||
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
|
||||
return ChatFileType.DOC
|
||||
|
||||
@@ -2,11 +2,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to in a multi-model comparison
|
||||
(0, 1, or 2). ``None`` for single-model responses, preserving the
|
||||
backwards-compatible wire format for existing API consumers.
|
||||
"""
|
||||
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -9,9 +9,7 @@ from onyx import __version__ as onyx_version
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
@@ -19,16 +17,10 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Notification
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.models import UserSettings
|
||||
@@ -49,15 +41,6 @@ basic_router = APIRouter(prefix="/settings")
|
||||
def admin_put_settings(
|
||||
settings: Settings, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb is not None
|
||||
and settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
|
||||
)
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
@@ -100,16 +83,6 @@ def fetch_settings(
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
),
|
||||
default_file_token_count_threshold_k=(
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,19 +2,12 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -85,12 +78,7 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = Field(
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
@@ -120,14 +108,3 @@ class UserSettings(Settings):
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
# Factory defaults so the frontend can show a "restore default" button.
|
||||
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
default_file_token_count_threshold_k: int = Field(
|
||||
default_factory=lambda: (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -57,36 +51,9 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
|
||||
# Resolve context-aware defaults for token threshold.
|
||||
# None = admin hasn't set a value yet → use context-aware default.
|
||||
# 0 = admin explicitly chose "no limit" → preserve as-is.
|
||||
if settings.file_token_count_threshold_k is None:
|
||||
settings.file_token_count_threshold_k = (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
# Upload size: 0 and None are treated as "unset" (not "no limit") →
|
||||
# fall back to min(configured default, hard ceiling).
|
||||
if not settings.user_file_max_upload_size_mb:
|
||||
settings.user_file_max_upload_size_mb = min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
)
|
||||
|
||||
# Clamp to env ceiling so stale KV values are capped even if the
|
||||
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
|
||||
# was already saved (api.py only guards new writes).
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -708,7 +708,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +743,7 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter = Emitter()
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +790,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter.bus.qsize()}")
|
||||
|
||||
@@ -169,10 +169,10 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
|
||||
|
||||
chat_file = self._load_file(file_id)
|
||||
|
||||
# Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
|
||||
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
|
||||
# DOC type in a loaded file means plaintext extraction failed and the
|
||||
# content is the original binary (e.g. raw PDF/DOCX bytes).
|
||||
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.TABULAR):
|
||||
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
|
||||
raise ToolCallException(
|
||||
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
|
||||
llm_facing_message=(
|
||||
|
||||
@@ -191,6 +191,25 @@ IGNORED_SYNCING_TENANT_LIST = (
|
||||
else None
|
||||
)
|
||||
|
||||
# Global flag to skip userfile threshold for all users/tenants
|
||||
SKIP_USERFILE_THRESHOLD = (
|
||||
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
|
||||
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
|
||||
)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
|
||||
[
|
||||
tenant.strip()
|
||||
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
|
||||
if tenant.strip()
|
||||
]
|
||||
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -19,10 +17,6 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
# Predates any test workspace messages, so the result set should match
|
||||
# the "no start time" case while exercising the oldest= parameter.
|
||||
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@@ -111,17 +105,15 @@ def test_load_from_checkpoint_access__private_channel(
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
|
||||
def test_slim_documents_access__public_channel(
|
||||
slack_connector: SlackConnector,
|
||||
start_ts: float | None,
|
||||
) -> None:
|
||||
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
|
||||
if not slack_connector.client:
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_ts,
|
||||
start=0.0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
@@ -157,7 +149,7 @@ def test_slim_documents_access__private_channel(
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=None,
|
||||
start=0.0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
@@ -103,11 +103,6 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="no yuhong allowed",
|
||||
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1175,7 +1175,7 @@ def test_code_interpreter_receives_chat_files(
|
||||
|
||||
file_descriptor: FileDescriptor = {
|
||||
"id": user_file.file_id,
|
||||
"type": ChatFileType.TABULAR,
|
||||
"type": ChatFileType.CSV,
|
||||
"name": "data.csv",
|
||||
"user_file_id": str(user_file.id),
|
||||
}
|
||||
|
||||
253
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
253
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
Covers both modes (standalone and streaming) without any real database,
|
||||
LLM, or queue infrastructure beyond the stdlib Queue.
|
||||
"""
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standalone mode (no merged_queue)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterStandaloneMode:
|
||||
def test_emitted_packet_arrives_on_bus(self) -> None:
|
||||
emitter = Emitter()
|
||||
pkt = _packet()
|
||||
emitter.emit(pkt)
|
||||
assert emitter.bus.get_nowait() is pkt
|
||||
|
||||
def test_bus_is_empty_before_emit(self) -> None:
|
||||
emitter = Emitter()
|
||||
assert emitter.bus.empty()
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter = Emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
assert emitter.bus.get_nowait() is p1
|
||||
assert emitter.bus.get_nowait() is p2
|
||||
|
||||
def test_packet_not_modified(self) -> None:
|
||||
"""Standalone mode must not wrap or mutate the packet."""
|
||||
emitter = Emitter()
|
||||
pkt = _packet(turn_index=7, tab_index=3)
|
||||
emitter.emit(pkt)
|
||||
retrieved = emitter.bus.get_nowait()
|
||||
assert retrieved.placement.turn_index == 7
|
||||
assert retrieved.placement.tab_index == 3
|
||||
|
||||
def test_get_default_emitter_is_standalone(self) -> None:
|
||||
emitter = get_default_emitter()
|
||||
pkt = _packet()
|
||||
emitter.emit(pkt)
|
||||
# Packet lands on the bus, not a shared queue
|
||||
assert emitter.bus.get_nowait() is pkt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming mode (merged_queue provided)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterStreamingMode:
|
||||
# --- Queue routing ---
|
||||
|
||||
def test_packet_goes_to_merged_queue_not_bus(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
assert emitter.bus.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=1, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
# --- model_index tagging ---
|
||||
|
||||
def test_model_idx_none_preserves_model_index_none(self) -> None:
|
||||
"""N=1 backwards-compat: model_index must stay None in the packet."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=None, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index is None
|
||||
|
||||
def test_model_idx_zero_tags_packet(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=1, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=2, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
# --- Queue key ---
|
||||
|
||||
def test_key_equals_model_idx_when_set(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=2, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_key_is_zero_when_model_idx_none(self) -> None:
|
||||
"""N=1: key defaults to 0 (single slot in the drain loop)."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=None, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
# --- Placement field preservation ---
|
||||
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
|
||||
# --- bus is always created ---
|
||||
|
||||
def test_bus_exists_in_streaming_mode(self) -> None:
|
||||
"""bus must always be present for backwards-compat with existing callers."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
assert hasattr(emitter, "bus")
|
||||
assert isinstance(emitter.bus, queue.Queue)
|
||||
|
||||
def test_bus_stays_empty_in_streaming_mode(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
assert emitter.bus.empty()
|
||||
640
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
640
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,640 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
try:
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
_RUN_MODELS_PATCHES = [
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
|
||||
]
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_none(self) -> None:
|
||||
"""Single-model path: model_index stays None for wire backwards-compat."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index is None
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
emitter = kwargs["emitter"]
|
||||
# _model_idx is None for N=1, int for N>1
|
||||
if emitter._model_idx == 0:
|
||||
raise RuntimeError("model 0 failed")
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_completion_handle_called_on_disconnect(self) -> None:
|
||||
"""llm_loop_completion_handle must still be called even when user disconnects.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers wait+completion in finally.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator, which propagates into _run_models. The finally
|
||||
block must call executor.shutdown(wait=True) to wait for LLM threads to
|
||||
finish, then persist their results via llm_loop_completion_handle.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
thread_completed = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give generator a yield point), then finish."""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
# Small sleep so executor.shutdown(wait=True) in finally actually waits.
|
||||
time.sleep(0.05)
|
||||
thread_completed.set()
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
# cast to Generator so .close() is available; _run_models returns
|
||||
# AnswerStream (= Iterator) but the actual object is always a generator.
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
|
||||
# Finally block must have waited for the thread and saved completion.
|
||||
assert (
|
||||
thread_completed.is_set()
|
||||
), "LLM thread must complete before gen.close() returns"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called for the successful model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -139,7 +139,7 @@ def test_csv_file_type() -> None:
|
||||
result = _extract_referenced_file_descriptors([tool_call], message)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == ChatFileType.TABULAR
|
||||
assert result[0]["type"] == ChatFileType.CSV
|
||||
|
||||
|
||||
def test_unknown_extension_defaults_to_plain_text() -> None:
|
||||
|
||||
@@ -1,381 +0,0 @@
|
||||
"""Tests for Canvas connector — client (PR1)."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FAKE_BASE_URL = "https://myschool.instructure.com"
|
||||
FAKE_TOKEN = "fake-canvas-token"
|
||||
|
||||
|
||||
def _mock_response(
|
||||
status_code: int = 200,
|
||||
json_data: Any = None,
|
||||
link_header: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock HTTP response with status, json, and Link header."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.reason = "OK" if status_code < 300 else "Error"
|
||||
resp.json.return_value = json_data if json_data is not None else []
|
||||
resp.headers = {"Link": link_header}
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.__init__ tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCanvasApiClientInit:
|
||||
def test_success(self) -> None:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
|
||||
expected_host = "myschool.instructure.com"
|
||||
|
||||
assert client.base_url == expected_base_url
|
||||
assert client._expected_host == expected_host
|
||||
|
||||
def test_normalizes_trailing_slash(self) -> None:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=f"{FAKE_BASE_URL}/",
|
||||
)
|
||||
|
||||
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
|
||||
|
||||
assert client.base_url == expected_base_url
|
||||
|
||||
def test_normalizes_existing_api_v1(self) -> None:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=f"{FAKE_BASE_URL}/api/v1",
|
||||
)
|
||||
|
||||
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
|
||||
|
||||
assert client.base_url == expected_base_url
|
||||
|
||||
def test_rejects_non_https_scheme(self) -> None:
|
||||
with pytest.raises(ValueError, match="must use https"):
|
||||
CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="ftp://myschool.instructure.com",
|
||||
)
|
||||
|
||||
def test_rejects_http(self) -> None:
|
||||
with pytest.raises(ValueError, match="must use https"):
|
||||
CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="http://myschool.instructure.com",
|
||||
)
|
||||
|
||||
def test_rejects_missing_host(self) -> None:
|
||||
with pytest.raises(ValueError, match="must include a valid host"):
|
||||
CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="https://",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._build_url tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildUrl:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
def test_appends_endpoint(self) -> None:
|
||||
result = self.client._build_url("courses")
|
||||
expected = f"{FAKE_BASE_URL}/api/v1/courses"
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_strips_leading_slash_from_endpoint(self) -> None:
|
||||
result = self.client._build_url("/courses")
|
||||
expected = f"{FAKE_BASE_URL}/api/v1/courses"
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._build_headers tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildHeaders:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
def test_returns_bearer_auth(self) -> None:
|
||||
result = self.client._build_headers()
|
||||
expected = {"Authorization": f"Bearer {FAKE_TOKEN}"}
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.get tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGet:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_success_returns_json_and_next_url(self, mock_requests: MagicMock) -> None:
|
||||
next_link = f"<{FAKE_BASE_URL}/api/v1/courses?page=2>; " 'rel="next"'
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[{"id": 1}], link_header=next_link
|
||||
)
|
||||
|
||||
data, next_url = self.client.get("courses")
|
||||
|
||||
expected_data = [{"id": 1}]
|
||||
expected_next = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
|
||||
|
||||
assert data == expected_data
|
||||
assert next_url == expected_next
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_success_no_next_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[{"id": 1}])
|
||||
|
||||
data, next_url = self.client.get("courses")
|
||||
|
||||
assert data == [{"id": 1}]
|
||||
assert next_url is None
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_error_status(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(403, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_404(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(404, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_429(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(429, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_skips_params_when_using_full_url(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
full = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
|
||||
|
||||
self.client.get(params={"per_page": "100"}, full_url=full)
|
||||
|
||||
_, kwargs = mock_requests.get.call_args
|
||||
assert kwargs["params"] is None
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_error_dict(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 1: {"error": {"message": "Not authorized"}}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
403, {"error": {"message": "Not authorized"}}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Not authorized"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_error_string(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 2: {"error": "Invalid access token"}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
401, {"error": "Invalid access token"}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Invalid access token"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_errors_list(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 3: {"errors": [{"message": "Invalid query"}]}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
400, {"errors": [{"message": "Invalid query"}]}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Invalid query"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_dict_takes_priority_over_errors_list(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""When both error shapes are present, error dict wins."""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
403, {"error": "Specific error", "errors": [{"message": "Generic"}]}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Specific error"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_falls_back_to_reason_when_no_json_message(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Empty error body falls back to response.reason."""
|
||||
mock_requests.get.return_value = _mock_response(500, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Error" # from _mock_response's reason for >= 300
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_invalid_json_on_success_raises(self, mock_requests: MagicMock) -> None:
|
||||
"""Invalid JSON on a 2xx response raises OnyxError."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
resp.headers = {"Link": ""}
|
||||
mock_requests.get.return_value = resp
|
||||
|
||||
with pytest.raises(OnyxError, match="Invalid JSON"):
|
||||
self.client.get("courses")
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_invalid_json_on_error_falls_back_to_reason(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Invalid JSON on a 4xx response falls back to response.reason."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 500
|
||||
resp.reason = "Internal Server Error"
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
resp.headers = {"Link": ""}
|
||||
mock_requests.get.return_value = resp
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Internal Server Error"
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._parse_next_link tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseNextLink:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="https://canvas.example.com",
|
||||
)
|
||||
|
||||
def test_found(self) -> None:
|
||||
header = '<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
|
||||
result = self.client._parse_next_link(header)
|
||||
expected = "https://canvas.example.com/api/v1/courses?page=2"
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_not_found(self) -> None:
|
||||
header = '<https://canvas.example.com/api/v1/courses?page=1>; rel="current"'
|
||||
|
||||
result = self.client._parse_next_link(header)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_empty(self) -> None:
|
||||
result = self.client._parse_next_link("")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_multiple_rels(self) -> None:
|
||||
header = (
|
||||
'<https://canvas.example.com/api/v1/courses?page=1>; rel="current", '
|
||||
'<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
)
|
||||
|
||||
result = self.client._parse_next_link(header)
|
||||
expected = "https://canvas.example.com/api/v1/courses?page=2"
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_rejects_host_mismatch(self) -> None:
|
||||
header = '<https://evil.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
|
||||
with pytest.raises(OnyxError, match="unexpected host"):
|
||||
self.client._parse_next_link(header)
|
||||
|
||||
def test_rejects_non_https_link(self) -> None:
|
||||
header = '<http://canvas.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
|
||||
with pytest.raises(OnyxError, match="must use https"):
|
||||
self.client._parse_next_link(header)
|
||||
@@ -1,147 +0,0 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": issue_id,
|
||||
"key": f"TEST-{issue_id}",
|
||||
"fields": {"summary": f"Issue {issue_id}"},
|
||||
}
|
||||
|
||||
|
||||
def _mock_jira_client() -> MagicMock:
|
||||
mock = MagicMock(spec=JIRA)
|
||||
mock._options = {"server": "https://jira.example.com"}
|
||||
mock._session = MagicMock()
|
||||
mock._get_url = MagicMock(
|
||||
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_bulk_fetch_success() -> None:
|
||||
"""Happy path: all issues fetched in one request."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3"])
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, Issue) for r in result)
|
||||
client._session.post.assert_called_once()
|
||||
|
||||
|
||||
def test_bulk_fetch_splits_on_json_error() -> None:
|
||||
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if len(ids) > 2:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 2294125
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
|
||||
assert len(result) == 4
|
||||
returned_ids = {r.raw["id"] for r in result}
|
||||
assert returned_ids == {"1", "2", "3", "4"}
|
||||
assert call_count > 1
|
||||
|
||||
|
||||
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
|
||||
"""A single issue that always fails JSON decode raises after splitting."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if "bad" in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 100
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "bad", "2"])
|
||||
|
||||
|
||||
def test_bulk_fetch_non_json_error_propagates() -> None:
|
||||
"""Non-JSONDecodeError exceptions still propagate."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = ValueError("something else broke")
|
||||
client._session.post.return_value = resp
|
||||
|
||||
try:
|
||||
bulk_fetch_issues(client, ["1"])
|
||||
assert False, "Expected ValueError to propagate"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bulk_fetch_with_fields() -> None:
|
||||
"""Fields parameter is forwarded correctly."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
bulk_fetch_issues(client, ["1"], fields="summary,description")
|
||||
|
||||
call_payload = client._session.post.call_args[1]["json"]
|
||||
assert call_payload["fields"] == ["summary", "description"]
|
||||
|
||||
|
||||
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
|
||||
client = _mock_jira_client()
|
||||
bad_id = "BAD"
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if bad_id in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"truncated", "doc", 999
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
@@ -1,5 +1,3 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -33,7 +31,6 @@ def mock_jira_cc_pair(
|
||||
"jira_base_url": jira_base_url,
|
||||
"project_key": project_key,
|
||||
}
|
||||
mock_cc_pair.connector.indexing_start = None
|
||||
|
||||
return mock_cc_pair
|
||||
|
||||
@@ -68,75 +65,3 @@ def test_jira_permission_sync(
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
):
|
||||
print(doc)
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that generic_doc_sync derives indexing_start from cc_pair
|
||||
and forwards it to retrieve_all_slim_docs_perm_sync."""
|
||||
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
|
||||
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_none_when_no_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that indexing_start is None when the connector has no indexing_start set."""
|
||||
mock_jira_cc_pair.connector.indexing_start = None
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] is None
|
||||
|
||||
@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
mock_db_session.delete.assert_called_once_with(provider)
|
||||
assert provider.deleted is True
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
|
||||
@@ -1462,69 +1462,3 @@ def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> Non
|
||||
assert (
|
||||
"tool_choice" not in kwargs
|
||||
), "tool_choice must not be sent to providers when no tools are provided"
|
||||
|
||||
|
||||
def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
api_base="https://bifrost.example.com/",
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BIFROST,
|
||||
model_name="anthropic/claude-sonnet-4-6",
|
||||
max_input_tokens=32000,
|
||||
)
|
||||
|
||||
assert llm._custom_llm_provider == "openai"
|
||||
assert llm._api_base == "https://bifrost.example.com/v1"
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
api_base="https://bifrost.example.com",
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BIFROST,
|
||||
model_name="anthropic/claude-sonnet-4-6",
|
||||
max_input_tokens=32000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Use a tool if needed")]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"description": "Look up data",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Done"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="anthropic/claude-sonnet-4-6",
|
||||
),
|
||||
]
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
llm.invoke(messages, tools=tools)
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert kwargs["model"] == "anthropic/claude-sonnet-4-6"
|
||||
assert kwargs["base_url"] == "https://bifrost.example.com/v1"
|
||||
assert kwargs["custom_llm_provider"] == "openai"
|
||||
assert kwargs["allowed_openai_params"] == ["tool_choice"]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,11 +158,13 @@ class TestValidateEndpoint:
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -12,8 +12,6 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
@@ -852,15 +850,13 @@ class TestGetLitellmAvailableModels:
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = httpx.ConnectError(
|
||||
"Connection refused", request=MagicMock()
|
||||
)
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM proxy models"):
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
@@ -902,113 +898,3 @@ class TestGetLitellmAvailableModels:
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetBifrostAvailableModels:
|
||||
"""Tests for the Bifrost model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bifrost_response(self) -> dict:
|
||||
"""Mock response from Bifrost /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "anthropic/claude-3-5-sonnet",
|
||||
"name": "Claude 3.5 Sonnet",
|
||||
"context_length": 200000,
|
||||
},
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"name": "GPT-4o",
|
||||
"context_length": 128000,
|
||||
},
|
||||
{
|
||||
"id": "deepseek/deepseek-r1",
|
||||
"name": "DeepSeek R1",
|
||||
"context_length": 64000,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_bifrost_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted non-embedding models."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_bifrost_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
|
||||
results = get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, BifrostFinalModelResponse) for r in results)
|
||||
assert [r.name for r in results] == sorted(
|
||||
[r.name for r in results], key=str.lower
|
||||
)
|
||||
|
||||
def test_infers_vision_support(self, mock_bifrost_response: dict) -> None:
|
||||
"""Test that vision support is inferred from provider/model IDs."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_bifrost_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
|
||||
results = get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
claude = next(r for r in results if r.name == "anthropic/claude-3-5-sonnet")
|
||||
gpt4o = next(r for r in results if r.name == "openai/gpt-4o")
|
||||
deepseek = next(r for r in results if r.name == "deepseek/deepseek-r1")
|
||||
|
||||
assert claude.supports_image_input is True
|
||||
assert gpt4o.supports_image_input is True
|
||||
assert deepseek.supports_image_input is False
|
||||
|
||||
def test_existing_v1_suffix_is_not_duplicated(self) -> None:
|
||||
"""Test that an existing /v1 suffix still hits a single /v1/models endpoint."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response = {"data": [{"id": "openai/gpt-4o", "name": "GPT-4o"}]}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com/v1")
|
||||
get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert called_url == "https://bifrost.example.com/v1/models"
|
||||
|
||||
def test_request_failure_is_logged_and_wrapped(self) -> None:
|
||||
"""Test that request-layer failures are logged before raising OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.llm.api.httpx.get") as mock_get,
|
||||
patch("onyx.server.manage.llm.api.logger.warning") as mock_warning,
|
||||
):
|
||||
mock_get.side_effect = httpx.ConnectError(
|
||||
"Connection refused", request=MagicMock()
|
||||
)
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
|
||||
with pytest.raises(OnyxError, match="Failed to fetch Bifrost models"):
|
||||
get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
|
||||
@@ -176,14 +176,6 @@ class TestInferVisionSupport:
|
||||
"""Test Nova Pro has vision."""
|
||||
assert infer_vision_support("amazon.nova-pro-v1") is True
|
||||
|
||||
def test_bifrost_claude_has_vision(self) -> None:
|
||||
"""Test Bifrost Claude models are recognized as vision-capable."""
|
||||
assert infer_vision_support("anthropic/claude-3-5-sonnet") is True
|
||||
|
||||
def test_bifrost_gpt4o_has_vision(self) -> None:
|
||||
"""Test Bifrost GPT-4o models are recognized as vision-capable."""
|
||||
assert infer_vision_support("openai/gpt-4o") is True
|
||||
|
||||
def test_mistral_no_vision(self) -> None:
|
||||
"""Test Mistral doesn't have vision (not in known list)."""
|
||||
assert infer_vision_support("mistral.mistral-large") is False
|
||||
|
||||
@@ -4,23 +4,13 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.natural_language_processing import utils as nlp_utils
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _Tokenizer(BaseTokenizer):
|
||||
class _Tokenizer:
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
@@ -39,26 +29,10 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
|
||||
return Settings(
|
||||
user_file_max_upload_size_mb=upload_size_mb,
|
||||
file_token_count_threshold_k=token_threshold_k,
|
||||
)
|
||||
|
||||
|
||||
def _patch_common_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
upload_size_mb: int = 1,
|
||||
token_threshold_k: int = 100,
|
||||
) -> None:
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
monkeypatch.setattr(
|
||||
utils,
|
||||
"load_settings",
|
||||
lambda: _make_settings(upload_size_mb, token_threshold_k),
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
@@ -102,8 +76,9 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
@@ -116,7 +91,9 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
@@ -129,11 +106,12 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
|
||||
upload = _make_upload("edge.png", size=1048576)
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
@@ -143,10 +121,12 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
|
||||
upload = _make_upload("large.png", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -157,11 +137,13 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=1048577)
|
||||
large = _make_upload("large.png", size=101)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
@@ -171,12 +153,15 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_always(
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=1048577)
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -187,12 +172,14 @@ def test_categorize_uploaded_files_enforces_size_limit_always(
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=1048577)
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
@@ -201,219 +188,40 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
|
||||
def test_categorize_uploaded_files_accepts_python_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A positive upload_size_mb is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
upload = _make_upload("huge.png", size=1048577, content=b"x")
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
|
||||
|
||||
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A positive token_threshold_k is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("big_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
|
||||
|
||||
def test_categorize_no_token_limit_when_threshold_k_is_zero(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
|
||||
py_source = b'def hello():\n print("world")\n'
|
||||
monkeypatch.setattr(
|
||||
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
|
||||
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
|
||||
)
|
||||
|
||||
upload = _make_upload("huge_image.png", size=100)
|
||||
upload = _make_upload("script.py", size=len(py_source), content=py_source)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 0
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable[0].filename == "script.py"
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_both_limits_enforced(
|
||||
def test_categorize_uploaded_files_rejects_binary_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Both positive limits are enforced; file exceeding token limit is rejected."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
upload = _make_upload("over_tokens.png", size=100)
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
|
||||
|
||||
binary_content = bytes(range(256)) * 4
|
||||
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 5K token limit"
|
||||
|
||||
|
||||
def test_categorize_rejection_reason_contains_dynamic_values(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
|
||||
|
||||
# File within size limit but over token limit
|
||||
upload = _make_upload("tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert result.rejected[0].reason == "Exceeds 7K token limit"
|
||||
|
||||
# File over size limit
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
|
||||
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
|
||||
|
||||
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
|
||||
|
||||
|
||||
# --- count_tokens tests ---
|
||||
|
||||
|
||||
def test_count_tokens_small_text() -> None:
|
||||
"""Small text should be encoded in a single call and return correct count."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "hello world"
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_chunked_matches_single_call() -> None:
|
||||
"""Chunked encoding should produce the same result as single-call for small text."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 1000
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
# _Tokenizer returns 1 token per char, so total should be 250
|
||||
assert count_tokens(text, tokenizer) == 250
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_exits_early(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set and exceeded, count_tokens should stop early."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
tokenizer = _CountingTokenizer()
|
||||
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer, token_limit=150)
|
||||
|
||||
assert result == 200 # 2 chunks × 100 tokens each
|
||||
assert encode_call_count == 2, "Should have stopped after 2 chunks"
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_not_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set but not exceeded, all chunks are encoded."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
result = count_tokens(text, tokenizer, token_limit=1000)
|
||||
assert result == 250
|
||||
|
||||
|
||||
def test_count_tokens_no_limit_encodes_all_chunks(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Without token_limit, all chunks are encoded regardless of count."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer)
|
||||
assert result == 500
|
||||
|
||||
|
||||
# --- early exit via token_limit in categorize tests ---
|
||||
|
||||
|
||||
def test_categorize_early_exits_tokenization_for_large_text(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Large text files should be rejected via early-exit tokenization
|
||||
without encoding all chunks."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
|
||||
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
|
||||
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
|
||||
# Let's use a bigger text
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
large_text = "x" * 5000 # 5000 tokens, threshold 1000
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
|
||||
|
||||
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 1
|
||||
assert "token limit" in result.rejected[0].reason
|
||||
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
|
||||
assert (
|
||||
encode_call_count < 50
|
||||
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
|
||||
|
||||
|
||||
def test_categorize_text_under_token_limit_accepted(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Text files under the token threshold should be accepted with exact count."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
small_text = "x" * 500 # 500 tokens < 1000 threshold
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
|
||||
|
||||
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable_file_to_token_count["ok.txt"] == 500
|
||||
assert result.rejected[0].filename == "data.bin"
|
||||
assert "Unsupported file type" in result.rejected[0].reason
|
||||
|
||||
@@ -1,23 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def __init__(self, data: dict | None = None) -> None:
|
||||
self._data = data
|
||||
|
||||
def load(self, _key: str) -> dict:
|
||||
if self._data is None:
|
||||
raise KvKeyNotFoundError()
|
||||
return self._data
|
||||
raise KvKeyNotFoundError()
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
@@ -31,140 +20,13 @@ class _FakeCache:
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_uses_model_defaults_when_no_stored_value(
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When no settings are stored (vector DB enabled), load_settings() should
|
||||
resolve the default token threshold to 200."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled and no settings are stored, the token
|
||||
threshold should default to 10000 (10M tokens)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled but admin explicitly set a token threshold,
|
||||
that value should be preserved (not overridden by the 10000 default)."""
|
||||
stored = Settings(file_token_count_threshold_k=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 500
|
||||
|
||||
|
||||
def test_load_settings_preserves_zero_token_threshold(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 means 'no limit' and should be preserved."""
|
||||
stored = Settings(file_token_count_threshold_k=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 0
|
||||
|
||||
|
||||
def test_load_settings_resolves_zero_upload_size_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
|
||||
|
||||
def test_load_settings_clamps_upload_size_to_env_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be clamped to the env-configured maximum."""
|
||||
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 250
|
||||
|
||||
|
||||
def test_load_settings_preserves_upload_size_within_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be preserved unchanged."""
|
||||
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 150
|
||||
|
||||
|
||||
def test_load_settings_zero_upload_size_resolves_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default,
|
||||
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 100
|
||||
|
||||
|
||||
def test_load_settings_default_clamped_to_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
the effective default should be min(DEFAULT, MAX)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 50
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
|
||||
@@ -82,7 +82,7 @@ class TestChatFileConversion:
|
||||
ChatLoadedFile(
|
||||
file_id="file-2",
|
||||
content=b"csv,data\n1,2",
|
||||
file_type=ChatFileType.TABULAR,
|
||||
file_type=ChatFileType.CSV,
|
||||
filename="data.csv",
|
||||
content_text="csv,data\n1,2",
|
||||
token_count=5,
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
"""Tests for WorkerHeartbeatMonitor and WorkerHealthCollector."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
|
||||
|
||||
|
||||
class TestWorkerHeartbeatMonitor:
|
||||
def test_heartbeat_registers_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "primary@host1" in status
|
||||
assert status["primary@host1"] is True
|
||||
|
||||
def test_multiple_workers(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
monitor._on_heartbeat({"hostname": "docfetching@host1"})
|
||||
monitor._on_heartbeat({"hostname": "monitoring@host1"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert len(status) == 3
|
||||
assert all(alive for alive in status.values())
|
||||
|
||||
def test_offline_removes_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
monitor._on_offline({"hostname": "primary@host1"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "primary@host1" not in status
|
||||
|
||||
def test_stale_heartbeat_marks_worker_down(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["primary@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
)
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert status["primary@host1"] is False
|
||||
|
||||
def test_very_stale_worker_is_pruned(self) -> None:
|
||||
"""Workers dead for 2x the timeout are pruned from the dict."""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["gone@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS * 2 - 10
|
||||
)
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "gone@host1" not in status
|
||||
assert monitor.get_worker_status() == {}
|
||||
|
||||
def test_heartbeat_refreshes_stale_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["primary@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
)
|
||||
assert monitor.get_worker_status()["primary@host1"] is False
|
||||
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
assert monitor.get_worker_status()["primary@host1"] is True
|
||||
|
||||
def test_ignores_empty_hostname(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({})
|
||||
monitor._on_heartbeat({"hostname": ""})
|
||||
monitor._on_offline({})
|
||||
|
||||
assert monitor.get_worker_status() == {}
|
||||
|
||||
def test_returns_full_hostname_as_key(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@my-long-host.local"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "docprocessing@my-long-host.local" in status
|
||||
|
||||
def test_start_is_idempotent(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
# Mock the thread so we don't actually start one
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.is_alive.return_value = True
|
||||
monitor._thread = mock_thread
|
||||
monitor._running = True
|
||||
|
||||
# Second start should be a no-op
|
||||
monitor.start()
|
||||
# Thread constructor should not have been called again
|
||||
assert monitor._thread is mock_thread
|
||||
|
||||
def test_thread_safety(self) -> None:
|
||||
"""get_worker_status should not raise even if heartbeats arrive concurrently."""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
status = monitor.get_worker_status()
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
status2 = monitor.get_worker_status()
|
||||
assert status == status2
|
||||
|
||||
|
||||
class TestWorkerHealthCollector:
|
||||
def test_returns_empty_when_no_monitor(self) -> None:
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_collects_active_workers(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
monitor._on_heartbeat({"hostname": "docfetching@host1"})
|
||||
monitor._on_heartbeat({"hostname": "monitoring@host1"})
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 2
|
||||
|
||||
active = families[0]
|
||||
assert active.name == "onyx_celery_active_worker_count"
|
||||
assert active.samples[0].value == 3
|
||||
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 3
|
||||
# Labels use short names (before @)
|
||||
labels = {s.labels["worker"] for s in up.samples}
|
||||
assert labels == {"primary", "docfetching", "monitoring"}
|
||||
for sample in up.samples:
|
||||
assert sample.value == 1
|
||||
|
||||
def test_reports_dead_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["monitoring@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
)
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
families = collector.collect()
|
||||
active = families[0]
|
||||
assert active.samples[0].value == 1
|
||||
|
||||
up = families[1]
|
||||
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
|
||||
assert samples_by_name["primary"] == 1
|
||||
assert samples_by_name["monitoring"] == 0
|
||||
|
||||
def test_empty_monitor_returns_zero(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 2
|
||||
active = families[0]
|
||||
assert active.samples[0].value == 0
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 0
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -19,8 +18,7 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
return Emitter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -39,22 +39,6 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -39,20 +39,6 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -39,23 +39,6 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -66,3 +66,10 @@ DB_READONLY_PASSWORD=password
|
||||
# Show extra/uncommon connectors
|
||||
# See https://docs.onyx.app/admins/connectors/overview for a full list of connectors
|
||||
SHOW_EXTRA_CONNECTORS=False
|
||||
|
||||
# User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
#SKIP_USERFILE_THRESHOLD=false
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
#SKIP_USERFILE_THRESHOLD_TENANT_IDS=
|
||||
|
||||
@@ -35,10 +35,6 @@ USER_AUTH_SECRET=""
|
||||
|
||||
## Chat Configuration
|
||||
# HARD_DELETE_CHATS=
|
||||
# MAX_ALLOWED_UPLOAD_SIZE_MB=250
|
||||
# Default per-user upload size limit (MB) when no admin value is set.
|
||||
# Automatically clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at runtime.
|
||||
# DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=100
|
||||
|
||||
## Base URL for redirects
|
||||
# WEB_DOMAIN=
|
||||
@@ -46,6 +42,13 @@ USER_AUTH_SECRET=""
|
||||
## Enterprise Features, requires a paid plan and licenses
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
|
||||
|
||||
## User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
# SKIP_USERFILE_THRESHOLD=false
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
# SKIP_USERFILE_THRESHOLD_TENANT_IDS=
|
||||
|
||||
|
||||
################################################################################
|
||||
## SERVICES CONFIGURATIONS
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.37
|
||||
version: 0.4.36
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
@@ -63,22 +63,6 @@ data:
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
# timeout settings
|
||||
proxy_connect_timeout {{ .Values.nginx.timeouts.connect }}s;
|
||||
proxy_send_timeout {{ .Values.nginx.timeouts.send }}s;
|
||||
proxy_read_timeout {{ .Values.nginx.timeouts.read }}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
location ~ ^/(api|openapi\.json)(/.*)?$ {
|
||||
rewrite ^/api(/.*)$ $1 break;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
|
||||
@@ -282,7 +282,7 @@ nginx:
|
||||
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
|
||||
# Workaround: Helm upgrade will restart if the following annotation value changes.
|
||||
podAnnotations:
|
||||
onyx.app/nginx-config-version: "3"
|
||||
onyx.app/nginx-config-version: "2"
|
||||
|
||||
# Propagate DOMAIN into nginx so server_name continues to use the same env var
|
||||
extraEnvs:
|
||||
@@ -1285,5 +1285,11 @@ configMap:
|
||||
DOMAIN: "localhost"
|
||||
# Chat Configs
|
||||
HARD_DELETE_CHATS: ""
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB: ""
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
# User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
SKIP_USERFILE_THRESHOLD: ""
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
|
||||
# Maximum user upload file size in MB for chat/projects uploads
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
|
||||
128
greptile.json
Normal file
128
greptile.json
Normal file
@@ -0,0 +1,128 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 2,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "greptile.json\n",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"customContext": {
|
||||
"other": [
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly."
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of TODO(name): ... or TODO(1234): ..."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Remove temporary debugging code before merging to production, especially tenant-specific debugging logs."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "backend/alembic/README.md",
|
||||
"description": "Migration guidance, including multi-tenant migration behavior"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/helm/charts/onyx/values-lite.yaml",
|
||||
"description": "Lite deployment Helm values and service assumptions"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
|
||||
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ Some commands require external tools to be installed and configured:
|
||||
- **uv** - Required for `backend` commands
|
||||
- Install from [docs.astral.sh/uv](https://docs.astral.sh/uv/)
|
||||
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci`, `cherry-pick`, and `trace` commands
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci` and `cherry-pick` commands
|
||||
- Install from [cli.github.com](https://cli.github.com/)
|
||||
- Authenticate with `gh auth login`
|
||||
|
||||
@@ -412,62 +412,6 @@ The `compare` subcommand writes a `summary.json` alongside the report with aggre
|
||||
counts (changed, added, removed, unchanged). The HTML report is only generated when
|
||||
visual differences are detected.
|
||||
|
||||
### `trace` - View Playwright Traces from CI
|
||||
|
||||
Download Playwright trace artifacts from a GitHub Actions run and open them
|
||||
with `playwright show-trace`. Traces are only generated for failing tests
|
||||
(`retain-on-failure`).
|
||||
|
||||
```shell
|
||||
ods trace [run-id-or-url]
|
||||
```
|
||||
|
||||
The run can be specified as a numeric run ID, a full GitHub Actions URL, or
|
||||
omitted to find the latest Playwright run for the current branch.
|
||||
|
||||
**Flags:**
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `--branch`, `-b` | | Find latest run for this branch |
|
||||
| `--pr` | | Find latest run for this PR number |
|
||||
| `--project`, `-p` | | Filter to a specific project (`admin`, `exclusive`, `lite`) |
|
||||
| `--list`, `-l` | `false` | List available traces without opening |
|
||||
| `--no-open` | `false` | Download traces but don't open them |
|
||||
|
||||
When multiple traces are found, an interactive picker lets you select which
|
||||
traces to open. Use arrow keys or `j`/`k` to navigate, `space` to toggle,
|
||||
`a` to select all, `n` to deselect all, and `enter` to open. Falls back to a
|
||||
plain-text prompt when no TTY is available.
|
||||
|
||||
Downloaded artifacts are cached in `/tmp/ods-traces/<run-id>/` so repeated
|
||||
invocations for the same run are instant.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```shell
|
||||
# Latest run for the current branch
|
||||
ods trace
|
||||
|
||||
# Specific run ID
|
||||
ods trace 12345678
|
||||
|
||||
# Full GitHub Actions URL
|
||||
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
|
||||
|
||||
# Latest run for a PR
|
||||
ods trace --pr 9500
|
||||
|
||||
# Latest run for a specific branch
|
||||
ods trace --branch main
|
||||
|
||||
# Only download admin project traces
|
||||
ods trace --project admin
|
||||
|
||||
# List traces without opening
|
||||
ods trace --list
|
||||
```
|
||||
|
||||
### Testing Changes Locally (Dry Run)
|
||||
|
||||
Both `run-ci` and `cherry-pick` support `--dry-run` to test without making remote changes:
|
||||
|
||||
@@ -55,7 +55,6 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewLatestStableTagCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
cmd.AddCommand(NewTraceCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -1,556 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/git"
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/tui"
|
||||
)
|
||||
|
||||
const playwrightWorkflow = "Run Playwright Tests"
|
||||
|
||||
// TraceOptions holds options for the trace command
|
||||
type TraceOptions struct {
|
||||
Branch string
|
||||
PR string
|
||||
Project string
|
||||
List bool
|
||||
NoOpen bool
|
||||
}
|
||||
|
||||
// traceInfo describes a single trace.zip found in the downloaded artifacts.
|
||||
type traceInfo struct {
|
||||
Path string // absolute path to trace.zip
|
||||
Project string // project group extracted from artifact dir (e.g. "admin", "admin-shard-1")
|
||||
TestDir string // test directory name (human-readable-ish)
|
||||
}
|
||||
|
||||
// NewTraceCommand creates a new trace command
|
||||
func NewTraceCommand() *cobra.Command {
|
||||
opts := &TraceOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "trace [run-id-or-url]",
|
||||
Short: "Download and view Playwright traces from GitHub Actions",
|
||||
Long: `Download Playwright trace artifacts from a GitHub Actions run and open them
|
||||
with 'playwright show-trace'.
|
||||
|
||||
The run can be specified as:
|
||||
- A GitHub Actions run ID (numeric)
|
||||
- A full GitHub Actions run URL
|
||||
- Omitted, to find the latest Playwright run for the current branch
|
||||
|
||||
You can also look up the latest run by branch name or PR number.
|
||||
|
||||
Examples:
|
||||
ods trace # latest run for current branch
|
||||
ods trace 12345678 # specific run ID
|
||||
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
|
||||
ods trace --pr 9500 # latest run for PR #9500
|
||||
ods trace --branch main # latest run for main branch
|
||||
ods trace --project admin # only download admin project traces
|
||||
ods trace --list # list available traces without opening`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runTrace(args, opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Find latest run for this branch")
|
||||
cmd.Flags().StringVar(&opts.PR, "pr", "", "Find latest run for this PR number")
|
||||
cmd.Flags().StringVarP(&opts.Project, "project", "p", "", "Filter to a specific project (admin, exclusive, lite)")
|
||||
cmd.Flags().BoolVarP(&opts.List, "list", "l", false, "List available traces without opening")
|
||||
cmd.Flags().BoolVar(&opts.NoOpen, "no-open", false, "Download traces but don't open them")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ghRun represents a GitHub Actions workflow run from `gh run list`
|
||||
type ghRun struct {
|
||||
DatabaseID int64 `json:"databaseId"`
|
||||
Status string `json:"status"`
|
||||
Conclusion string `json:"conclusion"`
|
||||
HeadBranch string `json:"headBranch"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func runTrace(args []string, opts *TraceOptions) {
|
||||
git.CheckGitHubCLI()
|
||||
|
||||
runID, err := resolveRunID(args, opts)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve run: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("Using run ID: %s", runID)
|
||||
|
||||
destDir, err := downloadTraceArtifacts(runID, opts.Project)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to download artifacts: %v", err)
|
||||
}
|
||||
|
||||
traces, err := findTraceInfos(destDir, runID)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find traces: %v", err)
|
||||
}
|
||||
|
||||
if len(traces) == 0 {
|
||||
log.Info("No trace files found in the downloaded artifacts.")
|
||||
log.Info("Traces are only generated for failing tests (retain-on-failure).")
|
||||
return
|
||||
}
|
||||
|
||||
projects := groupByProject(traces)
|
||||
|
||||
if opts.List || opts.NoOpen {
|
||||
printTraceList(traces, projects)
|
||||
fmt.Printf("\nTraces downloaded to: %s\n", destDir)
|
||||
return
|
||||
}
|
||||
|
||||
if len(traces) == 1 {
|
||||
openTraces(traces)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
selected := selectTraces(traces, projects)
|
||||
if len(selected) == 0 {
|
||||
return
|
||||
}
|
||||
openTraces(selected)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRunID determines the run ID from the provided arguments and options.
|
||||
func resolveRunID(args []string, opts *TraceOptions) (string, error) {
|
||||
if len(args) == 1 {
|
||||
return parseRunIDFromArg(args[0])
|
||||
}
|
||||
|
||||
if opts.PR != "" {
|
||||
return findLatestRunForPR(opts.PR)
|
||||
}
|
||||
|
||||
branch := opts.Branch
|
||||
if branch == "" {
|
||||
var err error
|
||||
branch, err = git.GetCurrentBranch()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current branch: %w", err)
|
||||
}
|
||||
if branch == "" {
|
||||
return "", fmt.Errorf("detached HEAD; specify a --branch, --pr, or run ID")
|
||||
}
|
||||
log.Infof("Using current branch: %s", branch)
|
||||
}
|
||||
|
||||
return findLatestRunForBranch(branch)
|
||||
}
|
||||
|
||||
var runURLPattern = regexp.MustCompile(`/actions/runs/(\d+)`)
|
||||
|
||||
// parseRunIDFromArg extracts a run ID from either a numeric string or a full URL.
|
||||
func parseRunIDFromArg(arg string) (string, error) {
|
||||
if matched, _ := regexp.MatchString(`^\d+$`, arg); matched {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
matches := runURLPattern.FindStringSubmatch(arg)
|
||||
if matches != nil {
|
||||
return matches[1], nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("could not parse run ID from %q; expected a numeric ID or GitHub Actions URL", arg)
|
||||
}
|
||||
|
||||
// findLatestRunForBranch finds the most recent Playwright workflow run for a branch.
|
||||
func findLatestRunForBranch(branch string) (string, error) {
|
||||
log.Infof("Looking up latest Playwright run for branch: %s", branch)
|
||||
|
||||
cmd := exec.Command("gh", "run", "list",
|
||||
"--workflow", playwrightWorkflow,
|
||||
"--branch", branch,
|
||||
"--limit", "1",
|
||||
"--json", "databaseId,status,conclusion,headBranch,url",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", ghError(err, "gh run list failed")
|
||||
}
|
||||
|
||||
var runs []ghRun
|
||||
if err := json.Unmarshal(output, &runs); err != nil {
|
||||
return "", fmt.Errorf("failed to parse run list: %w", err)
|
||||
}
|
||||
|
||||
if len(runs) == 0 {
|
||||
return "", fmt.Errorf("no Playwright runs found for branch %q", branch)
|
||||
}
|
||||
|
||||
run := runs[0]
|
||||
log.Infof("Found run: %s (status: %s, conclusion: %s)", run.URL, run.Status, run.Conclusion)
|
||||
return fmt.Sprintf("%d", run.DatabaseID), nil
|
||||
}
|
||||
|
||||
// findLatestRunForPR finds the most recent Playwright workflow run for a PR.
|
||||
func findLatestRunForPR(prNumber string) (string, error) {
|
||||
log.Infof("Looking up branch for PR #%s", prNumber)
|
||||
|
||||
cmd := exec.Command("gh", "pr", "view", prNumber,
|
||||
"--json", "headRefName",
|
||||
"--jq", ".headRefName",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", ghError(err, "gh pr view failed")
|
||||
}
|
||||
|
||||
branch := strings.TrimSpace(string(output))
|
||||
if branch == "" {
|
||||
return "", fmt.Errorf("could not determine branch for PR #%s", prNumber)
|
||||
}
|
||||
|
||||
log.Infof("PR #%s is on branch: %s", prNumber, branch)
|
||||
return findLatestRunForBranch(branch)
|
||||
}
|
||||
|
||||
// downloadTraceArtifacts downloads playwright trace artifacts for a run.
|
||||
// Returns the path to the download directory.
|
||||
func downloadTraceArtifacts(runID string, project string) (string, error) {
|
||||
cacheKey := runID
|
||||
if project != "" {
|
||||
cacheKey = runID + "-" + project
|
||||
}
|
||||
destDir := filepath.Join(os.TempDir(), "ods-traces", cacheKey)
|
||||
|
||||
// Reuse a previous download if traces exist
|
||||
if info, err := os.Stat(destDir); err == nil && info.IsDir() {
|
||||
traces, _ := findTraces(destDir)
|
||||
if len(traces) > 0 {
|
||||
log.Infof("Using cached download at %s", destDir)
|
||||
return destDir, nil
|
||||
}
|
||||
_ = os.RemoveAll(destDir)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(destDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory %s: %w", destDir, err)
|
||||
}
|
||||
|
||||
ghArgs := []string{"run", "download", runID, "--dir", destDir}
|
||||
|
||||
if project != "" {
|
||||
ghArgs = append(ghArgs, "--pattern", fmt.Sprintf("playwright-test-results-%s-*", project))
|
||||
} else {
|
||||
ghArgs = append(ghArgs, "--pattern", "playwright-test-results-*")
|
||||
}
|
||||
|
||||
log.Infof("Downloading trace artifacts...")
|
||||
log.Debugf("Running: gh %s", strings.Join(ghArgs, " "))
|
||||
|
||||
cmd := exec.Command("gh", ghArgs...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
_ = os.RemoveAll(destDir)
|
||||
return "", fmt.Errorf("gh run download failed: %w\nMake sure the run ID is correct and the artifacts haven't expired (30 day retention)", err)
|
||||
}
|
||||
|
||||
return destDir, nil
|
||||
}
|
||||
|
||||
// findTraces recursively finds all trace.zip files under a directory.
|
||||
func findTraces(root string) ([]string, error) {
|
||||
var traces []string
|
||||
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() && info.Name() == "trace.zip" {
|
||||
traces = append(traces, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return traces, err
|
||||
}
|
||||
|
||||
// findTraceInfos walks the download directory and returns structured trace info.
|
||||
// Expects: destDir/{artifact-dir}/{test-dir}/trace.zip
|
||||
func findTraceInfos(destDir, runID string) ([]traceInfo, error) {
|
||||
var traces []traceInfo
|
||||
err := filepath.Walk(destDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() || info.Name() != "trace.zip" {
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, _ := filepath.Rel(destDir, path)
|
||||
parts := strings.SplitN(rel, string(filepath.Separator), 3)
|
||||
|
||||
artifactDir := ""
|
||||
testDir := filepath.Base(filepath.Dir(path))
|
||||
if len(parts) >= 2 {
|
||||
artifactDir = parts[0]
|
||||
testDir = parts[1]
|
||||
}
|
||||
|
||||
traces = append(traces, traceInfo{
|
||||
Path: path,
|
||||
Project: extractProject(artifactDir, runID),
|
||||
TestDir: testDir,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
sort.Slice(traces, func(i, j int) bool {
|
||||
pi, pj := projectSortKey(traces[i].Project), projectSortKey(traces[j].Project)
|
||||
if pi != pj {
|
||||
return pi < pj
|
||||
}
|
||||
return traces[i].TestDir < traces[j].TestDir
|
||||
})
|
||||
|
||||
return traces, err
|
||||
}
|
||||
|
||||
// extractProject derives a project group from an artifact directory name.
|
||||
// e.g. "playwright-test-results-admin-12345" -> "admin"
|
||||
//
|
||||
// "playwright-test-results-admin-shard-1-12345" -> "admin-shard-1"
|
||||
func extractProject(artifactDir, runID string) string {
|
||||
name := strings.TrimPrefix(artifactDir, "playwright-test-results-")
|
||||
name = strings.TrimSuffix(name, "-"+runID)
|
||||
if name == "" {
|
||||
return artifactDir
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// projectSortKey returns a sort-friendly key that orders admin < exclusive < lite.
|
||||
func projectSortKey(project string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(project, "admin"):
|
||||
return "0-" + project
|
||||
case strings.HasPrefix(project, "exclusive"):
|
||||
return "1-" + project
|
||||
case strings.HasPrefix(project, "lite"):
|
||||
return "2-" + project
|
||||
default:
|
||||
return "3-" + project
|
||||
}
|
||||
}
|
||||
|
||||
// groupByProject returns an ordered list of unique project names found in traces.
|
||||
func groupByProject(traces []traceInfo) []string {
|
||||
seen := map[string]bool{}
|
||||
var projects []string
|
||||
for _, t := range traces {
|
||||
if !seen[t.Project] {
|
||||
seen[t.Project] = true
|
||||
projects = append(projects, t.Project)
|
||||
}
|
||||
}
|
||||
sort.Slice(projects, func(i, j int) bool {
|
||||
return projectSortKey(projects[i]) < projectSortKey(projects[j])
|
||||
})
|
||||
return projects
|
||||
}
|
||||
|
||||
// printTraceList displays traces grouped by project.
|
||||
func printTraceList(traces []traceInfo, projects []string) {
|
||||
fmt.Printf("\nFound %d trace(s) across %d project(s):\n", len(traces), len(projects))
|
||||
|
||||
idx := 1
|
||||
for _, proj := range projects {
|
||||
count := 0
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
count++
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n %s (%d):\n", proj, count)
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
fmt.Printf(" [%2d] %s\n", idx, t.TestDir)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// selectTraces tries the TUI picker first, falling back to a plain-text
|
||||
// prompt when the terminal cannot be initialised (e.g. piped output).
|
||||
func selectTraces(traces []traceInfo, projects []string) []traceInfo {
|
||||
// Build picker groups in the same order as the sorted traces slice.
|
||||
var groups []tui.PickerGroup
|
||||
for _, proj := range projects {
|
||||
var items []string
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
items = append(items, t.TestDir)
|
||||
}
|
||||
}
|
||||
groups = append(groups, tui.PickerGroup{Label: proj, Items: items})
|
||||
}
|
||||
|
||||
indices, err := tui.Pick(groups)
|
||||
if err != nil {
|
||||
// Terminal not available — fall back to text prompt
|
||||
log.Debugf("TUI picker unavailable: %v", err)
|
||||
printTraceList(traces, projects)
|
||||
return promptTraceSelection(traces, projects)
|
||||
}
|
||||
if indices == nil {
|
||||
return nil // user cancelled
|
||||
}
|
||||
|
||||
selected := make([]traceInfo, len(indices))
|
||||
for i, idx := range indices {
|
||||
selected[i] = traces[idx]
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// promptTraceSelection asks the user which traces to open via plain text.
|
||||
// Accepts numbers (1,3,5), ranges (1-5), "all", or a project name.
|
||||
func promptTraceSelection(traces []traceInfo, projects []string) []traceInfo {
|
||||
fmt.Printf("\nOpen which traces? (e.g. 1,3,5 | 1-5 | all | %s): ", strings.Join(projects, " | "))
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read input: %v", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" || strings.EqualFold(input, "all") {
|
||||
return traces
|
||||
}
|
||||
|
||||
// Check if input matches a project name
|
||||
for _, proj := range projects {
|
||||
if strings.EqualFold(input, proj) {
|
||||
var selected []traceInfo
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
selected = append(selected, t)
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as number/range selection
|
||||
indices := parseTraceSelection(input, len(traces))
|
||||
if len(indices) == 0 {
|
||||
log.Warn("No valid selection; opening all traces")
|
||||
return traces
|
||||
}
|
||||
|
||||
selected := make([]traceInfo, len(indices))
|
||||
for i, idx := range indices {
|
||||
selected[i] = traces[idx]
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// parseTraceSelection parses a comma-separated list of numbers and ranges into
|
||||
// 0-based indices. Input is 1-indexed (matches display). Out-of-range values
|
||||
// are silently ignored.
|
||||
func parseTraceSelection(input string, max int) []int {
|
||||
var result []int
|
||||
seen := map[int]bool{}
|
||||
|
||||
for _, part := range strings.Split(input, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if idx := strings.Index(part, "-"); idx > 0 {
|
||||
lo, err1 := strconv.Atoi(strings.TrimSpace(part[:idx]))
|
||||
hi, err2 := strconv.Atoi(strings.TrimSpace(part[idx+1:]))
|
||||
if err1 != nil || err2 != nil {
|
||||
continue
|
||||
}
|
||||
for i := lo; i <= hi; i++ {
|
||||
zi := i - 1
|
||||
if zi >= 0 && zi < max && !seen[zi] {
|
||||
result = append(result, zi)
|
||||
seen[zi] = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
n, err := strconv.Atoi(part)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
zi := n - 1
|
||||
if zi >= 0 && zi < max && !seen[zi] {
|
||||
result = append(result, zi)
|
||||
seen[zi] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// openTraces opens the selected traces with playwright show-trace,
|
||||
// running npx from the web/ directory to use the project's Playwright version.
|
||||
func openTraces(traces []traceInfo) {
|
||||
tracePaths := make([]string, len(traces))
|
||||
for i, t := range traces {
|
||||
tracePaths[i] = t.Path
|
||||
}
|
||||
|
||||
args := append([]string{"playwright", "show-trace"}, tracePaths...)
|
||||
|
||||
log.Infof("Opening %d trace(s) with playwright show-trace...", len(traces))
|
||||
cmd := exec.Command("npx", args...)
|
||||
|
||||
// Run from web/ to pick up the locally-installed Playwright version
|
||||
if root, err := paths.GitRoot(); err == nil {
|
||||
cmd.Dir = filepath.Join(root, "web")
|
||||
}
|
||||
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
// Normal exit (e.g. user closed the window) — just log and return
|
||||
// so the picker loop can continue.
|
||||
log.Debugf("playwright exited with code %d", exitErr.ExitCode())
|
||||
return
|
||||
}
|
||||
log.Errorf("playwright show-trace failed: %v\nMake sure Playwright is installed (npx playwright install)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ghError wraps a gh CLI error with stderr output.
|
||||
func ghError(err error, msg string) error {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return fmt.Errorf("%s: %w: %s", msg, err, string(exitErr.Stderr))
|
||||
}
|
||||
return fmt.Errorf("%s: %w", msg, err)
|
||||
}
|
||||
@@ -3,19 +3,13 @@ module github.com/onyx-dot-app/onyx/tools/ods
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
github.com/jmelahman/tag v0.5.2
|
||||
github.com/sirupsen/logrus v1.9.4
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/pflag v1.0.10
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/term v0.41.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,68 +1,30 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
|
||||
github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
|
||||
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -1,419 +0,0 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
)
|
||||
|
||||
// PickerGroup represents a labelled group of selectable items.
|
||||
type PickerGroup struct {
|
||||
Label string
|
||||
Items []string
|
||||
}
|
||||
|
||||
// entry is a single row in the picker (either a group header or an item).
|
||||
type entry struct {
|
||||
label string
|
||||
isHeader bool
|
||||
selected bool
|
||||
groupIdx int
|
||||
flatIdx int // index across all items (ignoring headers), -1 for headers
|
||||
}
|
||||
|
||||
// Pick shows a full-screen grouped multi-select picker.
|
||||
// All items start deselected. Returns the flat indices of selected items
|
||||
// (0-based, spanning all groups in order). Returns nil if cancelled.
|
||||
// Returns a non-nil error if the terminal cannot be initialised, in which
|
||||
// case the caller should fall back to a simpler prompt.
|
||||
func Pick(groups []PickerGroup) ([]int, error) {
|
||||
screen, err := tcell.NewScreen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := screen.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer screen.Fini()
|
||||
|
||||
entries := buildEntries(groups)
|
||||
totalItems := countItems(entries)
|
||||
cursor := firstSelectableIndex(entries)
|
||||
offset := 0
|
||||
|
||||
for {
|
||||
w, h := screen.Size()
|
||||
selectedCount := countSelected(entries)
|
||||
|
||||
drawPicker(screen, entries, groups, cursor, offset, w, h, selectedCount, totalItems)
|
||||
screen.Show()
|
||||
|
||||
ev := screen.PollEvent()
|
||||
switch ev := ev.(type) {
|
||||
case *tcell.EventResize:
|
||||
screen.Sync()
|
||||
case *tcell.EventKey:
|
||||
switch action := keyAction(ev); action {
|
||||
case actionQuit:
|
||||
return nil, nil
|
||||
case actionConfirm:
|
||||
if countSelected(entries) > 0 {
|
||||
return collectSelected(entries), nil
|
||||
}
|
||||
case actionUp:
|
||||
if cursor > 0 {
|
||||
cursor--
|
||||
}
|
||||
case actionDown:
|
||||
if cursor < len(entries)-1 {
|
||||
cursor++
|
||||
}
|
||||
case actionTop:
|
||||
cursor = 0
|
||||
case actionBottom:
|
||||
if len(entries) == 0 {
|
||||
cursor = 0
|
||||
} else {
|
||||
cursor = len(entries) - 1
|
||||
}
|
||||
case actionPageUp:
|
||||
listHeight := h - headerLines - footerLines
|
||||
cursor -= listHeight
|
||||
if cursor < 0 {
|
||||
cursor = 0
|
||||
}
|
||||
case actionPageDown:
|
||||
listHeight := h - headerLines - footerLines
|
||||
cursor += listHeight
|
||||
if cursor >= len(entries) {
|
||||
cursor = len(entries) - 1
|
||||
}
|
||||
case actionToggle:
|
||||
toggleAtCursor(entries, cursor)
|
||||
case actionAll:
|
||||
setAll(entries, true)
|
||||
case actionNone:
|
||||
setAll(entries, false)
|
||||
}
|
||||
|
||||
// Keep the cursor visible
|
||||
listHeight := h - headerLines - footerLines
|
||||
if listHeight < 1 {
|
||||
listHeight = 1
|
||||
}
|
||||
if cursor < offset {
|
||||
offset = cursor
|
||||
}
|
||||
if cursor >= offset+listHeight {
|
||||
offset = cursor - listHeight + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- actions ----------------------------------------------------------------
|
||||
|
||||
type action int
|
||||
|
||||
const (
|
||||
actionNoop action = iota
|
||||
actionQuit
|
||||
actionConfirm
|
||||
actionUp
|
||||
actionDown
|
||||
actionTop
|
||||
actionBottom
|
||||
actionPageUp
|
||||
actionPageDown
|
||||
actionToggle
|
||||
actionAll
|
||||
actionNone
|
||||
)
|
||||
|
||||
func keyAction(ev *tcell.EventKey) action {
|
||||
switch ev.Key() {
|
||||
case tcell.KeyEscape, tcell.KeyCtrlC:
|
||||
return actionQuit
|
||||
case tcell.KeyEnter:
|
||||
return actionConfirm
|
||||
case tcell.KeyUp:
|
||||
return actionUp
|
||||
case tcell.KeyDown:
|
||||
return actionDown
|
||||
case tcell.KeyHome:
|
||||
return actionTop
|
||||
case tcell.KeyEnd:
|
||||
return actionBottom
|
||||
case tcell.KeyPgUp:
|
||||
return actionPageUp
|
||||
case tcell.KeyPgDn:
|
||||
return actionPageDown
|
||||
case tcell.KeyRune:
|
||||
switch ev.Rune() {
|
||||
case 'q':
|
||||
return actionQuit
|
||||
case ' ':
|
||||
return actionToggle
|
||||
case 'j':
|
||||
return actionDown
|
||||
case 'k':
|
||||
return actionUp
|
||||
case 'g':
|
||||
return actionTop
|
||||
case 'G':
|
||||
return actionBottom
|
||||
case 'a':
|
||||
return actionAll
|
||||
case 'n':
|
||||
return actionNone
|
||||
}
|
||||
}
|
||||
return actionNoop
|
||||
}
|
||||
|
||||
// --- data helpers ------------------------------------------------------------
|
||||
|
||||
func buildEntries(groups []PickerGroup) []entry {
|
||||
var entries []entry
|
||||
flat := 0
|
||||
for gi, g := range groups {
|
||||
entries = append(entries, entry{
|
||||
label: g.Label,
|
||||
isHeader: true,
|
||||
groupIdx: gi,
|
||||
flatIdx: -1,
|
||||
})
|
||||
for _, item := range g.Items {
|
||||
entries = append(entries, entry{
|
||||
label: item,
|
||||
isHeader: false,
|
||||
selected: false,
|
||||
groupIdx: gi,
|
||||
flatIdx: flat,
|
||||
})
|
||||
flat++
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func firstSelectableIndex(entries []entry) int {
|
||||
for i, e := range entries {
|
||||
if !e.isHeader {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func countItems(entries []entry) int {
|
||||
n := 0
|
||||
for _, e := range entries {
|
||||
if !e.isHeader {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func countSelected(entries []entry) int {
|
||||
n := 0
|
||||
for _, e := range entries {
|
||||
if !e.isHeader && e.selected {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func collectSelected(entries []entry) []int {
|
||||
var result []int
|
||||
for _, e := range entries {
|
||||
if !e.isHeader && e.selected {
|
||||
result = append(result, e.flatIdx)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toggleAtCursor(entries []entry, cursor int) {
|
||||
if cursor < 0 || cursor >= len(entries) {
|
||||
return
|
||||
}
|
||||
e := entries[cursor]
|
||||
if e.isHeader {
|
||||
// Toggle entire group: if all selected -> deselect all, else select all
|
||||
allSelected := true
|
||||
for _, e2 := range entries {
|
||||
if !e2.isHeader && e2.groupIdx == e.groupIdx && !e2.selected {
|
||||
allSelected = false
|
||||
break
|
||||
}
|
||||
}
|
||||
for i := range entries {
|
||||
if !entries[i].isHeader && entries[i].groupIdx == e.groupIdx {
|
||||
entries[i].selected = !allSelected
|
||||
}
|
||||
}
|
||||
} else {
|
||||
entries[cursor].selected = !entries[cursor].selected
|
||||
}
|
||||
}
|
||||
|
||||
func setAll(entries []entry, selected bool) {
|
||||
for i := range entries {
|
||||
if !entries[i].isHeader {
|
||||
entries[i].selected = selected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- drawing ----------------------------------------------------------------
|
||||
|
||||
const (
|
||||
headerLines = 2 // title + blank line
|
||||
footerLines = 2 // blank line + keybinds
|
||||
)
|
||||
|
||||
var (
|
||||
styleDefault = tcell.StyleDefault
|
||||
styleTitle = tcell.StyleDefault.Bold(true)
|
||||
styleGroup = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal)
|
||||
styleGroupCur = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal).Reverse(true)
|
||||
styleCheck = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true)
|
||||
styleUncheck = tcell.StyleDefault.Dim(true)
|
||||
styleItem = tcell.StyleDefault
|
||||
styleItemCur = tcell.StyleDefault.Bold(true).Underline(true)
|
||||
styleCheckCur = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true).Underline(true)
|
||||
styleUncheckCur = tcell.StyleDefault.Dim(true).Underline(true)
|
||||
styleFooter = tcell.StyleDefault.Dim(true)
|
||||
)
|
||||
|
||||
func drawPicker(
|
||||
screen tcell.Screen,
|
||||
entries []entry,
|
||||
groups []PickerGroup,
|
||||
cursor, offset, w, h, selectedCount, totalItems int,
|
||||
) {
|
||||
screen.Clear()
|
||||
|
||||
// Title
|
||||
title := fmt.Sprintf(" Select traces to open (%d/%d selected)", selectedCount, totalItems)
|
||||
drawLine(screen, 0, 0, w, title, styleTitle)
|
||||
|
||||
// List area
|
||||
listHeight := h - headerLines - footerLines
|
||||
if listHeight < 1 {
|
||||
listHeight = 1
|
||||
}
|
||||
|
||||
for i := 0; i < listHeight; i++ {
|
||||
ei := offset + i
|
||||
if ei >= len(entries) {
|
||||
break
|
||||
}
|
||||
y := headerLines + i
|
||||
renderEntry(screen, entries, groups, ei, cursor, w, y)
|
||||
}
|
||||
|
||||
// Scrollbar hint
|
||||
if len(entries) > listHeight {
|
||||
drawScrollbar(screen, w-1, headerLines, listHeight, offset, len(entries))
|
||||
}
|
||||
|
||||
// Footer
|
||||
footerY := h - 1
|
||||
footer := " ↑/↓ move space toggle a all n none enter open q/esc quit"
|
||||
drawLine(screen, 0, footerY, w, footer, styleFooter)
|
||||
}
|
||||
|
||||
func renderEntry(screen tcell.Screen, entries []entry, groups []PickerGroup, ei, cursor, w, y int) {
|
||||
e := entries[ei]
|
||||
isCursor := ei == cursor
|
||||
|
||||
if e.isHeader {
|
||||
groupSelected := 0
|
||||
groupTotal := 0
|
||||
for _, e2 := range entries {
|
||||
if !e2.isHeader && e2.groupIdx == e.groupIdx {
|
||||
groupTotal++
|
||||
if e2.selected {
|
||||
groupSelected++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
label := fmt.Sprintf(" %s (%d/%d)", e.label, groupSelected, groupTotal)
|
||||
style := styleGroup
|
||||
if isCursor {
|
||||
style = styleGroupCur
|
||||
}
|
||||
drawLine(screen, 0, y, w, label, style)
|
||||
return
|
||||
}
|
||||
|
||||
// Item row: " [x] label" or " > [x] label"
|
||||
prefix := " "
|
||||
if isCursor {
|
||||
prefix = " > "
|
||||
}
|
||||
|
||||
check := "[ ]"
|
||||
cStyle := styleUncheck
|
||||
iStyle := styleItem
|
||||
if isCursor {
|
||||
cStyle = styleUncheckCur
|
||||
iStyle = styleItemCur
|
||||
}
|
||||
if e.selected {
|
||||
check = "[x]"
|
||||
cStyle = styleCheck
|
||||
if isCursor {
|
||||
cStyle = styleCheckCur
|
||||
}
|
||||
}
|
||||
|
||||
x := drawStr(screen, 0, y, w, prefix, iStyle)
|
||||
x = drawStr(screen, x, y, w, check, cStyle)
|
||||
drawStr(screen, x, y, w, " "+e.label, iStyle)
|
||||
}
|
||||
|
||||
func drawScrollbar(screen tcell.Screen, x, top, height, offset, total int) {
|
||||
if total <= height || height < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
thumbSize := max(1, height*height/total)
|
||||
thumbPos := top + offset*height/total
|
||||
|
||||
for y := top; y < top+height; y++ {
|
||||
ch := '│'
|
||||
style := styleDefault.Dim(true)
|
||||
if y >= thumbPos && y < thumbPos+thumbSize {
|
||||
ch = '┃'
|
||||
style = styleDefault
|
||||
}
|
||||
screen.SetContent(x, y, ch, nil, style)
|
||||
}
|
||||
}
|
||||
|
||||
// drawLine fills an entire row starting at x=startX, padding to width w.
|
||||
func drawLine(screen tcell.Screen, startX, y, w int, s string, style tcell.Style) {
|
||||
x := drawStr(screen, startX, y, w, s, style)
|
||||
// Clear the rest of the line
|
||||
for ; x < w; x++ {
|
||||
screen.SetContent(x, y, ' ', nil, style)
|
||||
}
|
||||
}
|
||||
|
||||
// drawStr writes a string at (x, y) up to maxX and returns the next x position.
|
||||
func drawStr(screen tcell.Screen, x, y, maxX int, s string, style tcell.Style) int {
|
||||
for _, ch := range s {
|
||||
if x >= maxX {
|
||||
break
|
||||
}
|
||||
screen.SetContent(x, y, ch, nil, style)
|
||||
x++
|
||||
}
|
||||
return x
|
||||
}
|
||||
@@ -342,9 +342,9 @@ visible text in the DOM (e.g., `title`, `description`, `label`) should be typed
|
||||
|
||||
```typescript
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
|
||||
// ✅ Good — new components accept string | RichStr and render via Text
|
||||
// ✅ Good — new components accept string | RichStr
|
||||
interface InfoCardProps {
|
||||
title: string | RichStr;
|
||||
description?: string | RichStr;
|
||||
@@ -353,9 +353,9 @@ interface InfoCardProps {
|
||||
function InfoCard({ title, description }: InfoCardProps) {
|
||||
return (
|
||||
<div>
|
||||
<Text font="main-ui-action">{title}</Text>
|
||||
<Text font="main-ui-action">{resolveStr(title)}</Text>
|
||||
{description && (
|
||||
<Text font="secondary-body" color="text-03">{description}</Text>
|
||||
<Text font="secondary-body" color="text-03">{resolveStr(description)}</Text>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -4,15 +4,11 @@ import {
|
||||
Interactive,
|
||||
type InteractiveStatelessProps,
|
||||
} from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { cn } from "@opal/utils";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -22,13 +18,13 @@ import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
type ButtonContentProps =
|
||||
| {
|
||||
icon?: IconFunctionComponent;
|
||||
children: string | RichStr;
|
||||
children: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: never;
|
||||
}
|
||||
| {
|
||||
icon: IconFunctionComponent;
|
||||
children?: string | RichStr;
|
||||
children?: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: boolean;
|
||||
};
|
||||
@@ -73,24 +69,15 @@ function Button({
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
responsiveHideText ? (
|
||||
<span className="hidden md:inline whitespace-nowrap">
|
||||
<Text
|
||||
font={isLarge ? "main-ui-body" : "secondary-body"}
|
||||
color="inherit"
|
||||
>
|
||||
{children}
|
||||
</Text>
|
||||
</span>
|
||||
) : (
|
||||
<Text
|
||||
font={isLarge ? "main-ui-body" : "secondary-body"}
|
||||
color="inherit"
|
||||
nowrap
|
||||
>
|
||||
{children}
|
||||
</Text>
|
||||
)
|
||||
<span
|
||||
className={cn(
|
||||
"whitespace-nowrap",
|
||||
isLarge ? "font-main-ui-body " : "font-secondary-body",
|
||||
responsiveHideText && "hidden md:inline"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</span>
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
|
||||
@@ -4,8 +4,7 @@ import {
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
@@ -17,12 +16,12 @@ import { Button } from "@opal/components/buttons/button/components";
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface FilterButtonProps
|
||||
extends Omit<InteractiveStatefulProps, "variant" | "state" | "children"> {
|
||||
extends Omit<InteractiveStatefulProps, "variant" | "state"> {
|
||||
/** Left icon — always visible. */
|
||||
icon: IconFunctionComponent;
|
||||
|
||||
/** Label text between icon and trailing indicator. */
|
||||
children: string | RichStr;
|
||||
children: string;
|
||||
|
||||
/** Whether the filter has an active selection. @default false */
|
||||
active?: boolean;
|
||||
@@ -69,9 +68,9 @@ function FilterButton({
|
||||
<Interactive.Container type="button">
|
||||
<div className="interactive-foreground flex flex-row items-center gap-1">
|
||||
{iconWrapper(Icon, "lg", true)}
|
||||
<Text font="main-ui-action" color="inherit" nowrap>
|
||||
<span className="whitespace-nowrap font-main-ui-action">
|
||||
{children}
|
||||
</Text>
|
||||
</span>
|
||||
<div style={{ visibility: active ? "hidden" : "visible" }}>
|
||||
{iconWrapper(ChevronIcon, "lg", true)}
|
||||
</div>
|
||||
|
||||
@@ -4,12 +4,7 @@ import {
|
||||
type InteractiveStatefulProps,
|
||||
type InteractiveStatefulInteraction,
|
||||
} from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { InteractiveContainerRoundingVariant } from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
@@ -33,17 +28,17 @@ type OpenButtonContentProps =
|
||||
| {
|
||||
foldable: true;
|
||||
icon: IconFunctionComponent;
|
||||
children: string | RichStr;
|
||||
children: string;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon?: IconFunctionComponent;
|
||||
children: string | RichStr;
|
||||
children: string;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon: IconFunctionComponent;
|
||||
children?: string | RichStr;
|
||||
children?: string;
|
||||
};
|
||||
|
||||
type OpenButtonVariant = "select-light" | "select-heavy" | "select-tinted";
|
||||
@@ -106,13 +101,14 @@ function OpenButton({
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
<Text
|
||||
font={isLarge ? "main-ui-body" : "secondary-body"}
|
||||
color="inherit"
|
||||
nowrap
|
||||
<span
|
||||
className={cn(
|
||||
"whitespace-nowrap",
|
||||
isLarge ? "font-main-ui-body" : "font-secondary-body"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</Text>
|
||||
</span>
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
@@ -181,7 +177,7 @@ function OpenButton({
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
<Text>{resolvedTooltip}</Text>
|
||||
{resolvedTooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
|
||||
@@ -4,12 +4,7 @@ import {
|
||||
useDisabled,
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
@@ -31,19 +26,19 @@ type SelectButtonContentProps =
|
||||
| {
|
||||
foldable: true;
|
||||
icon: IconFunctionComponent;
|
||||
children: string | RichStr;
|
||||
children: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon?: IconFunctionComponent;
|
||||
children: string | RichStr;
|
||||
children: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon: IconFunctionComponent;
|
||||
children?: string | RichStr;
|
||||
children?: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
};
|
||||
|
||||
@@ -84,10 +79,13 @@ function SelectButton({
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
<span className="opal-select-button-label">
|
||||
<Text font={isLarge ? "main-ui-body" : "secondary-body"} color="inherit">
|
||||
{children}
|
||||
</Text>
|
||||
<span
|
||||
className={cn(
|
||||
"opal-select-button-label",
|
||||
isLarge ? "font-main-ui-body" : "font-secondary-body"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</span>
|
||||
) : null;
|
||||
|
||||
@@ -139,7 +137,7 @@ function SelectButton({
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
<Text>{resolvedTooltip}</Text>
|
||||
{resolvedTooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
|
||||
@@ -4,9 +4,7 @@ import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgArrowRight, SvgChevronLeft, SvgChevronRight } from "@opal/icons";
|
||||
import { containerSizeVariants } from "@opal/shared";
|
||||
import type { RichStr, WithoutStyles } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import { toPlainString } from "@opal/components/text/InlineMarkdown";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import * as PopoverPrimitive from "@radix-ui/react-popover";
|
||||
import {
|
||||
@@ -40,7 +38,7 @@ interface SimplePaginationProps
|
||||
/** Hides the `currentPage/totalPages` summary text between arrows. Default: `false`. */
|
||||
hidePages?: boolean;
|
||||
/** Unit label shown after the summary (e.g. `"pages"`). Always has 4px spacing. */
|
||||
units?: string | RichStr;
|
||||
units?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -65,7 +63,7 @@ interface CountPaginationProps
|
||||
/** Hides the current page number between the arrows. Default: `false`. */
|
||||
hidePages?: boolean;
|
||||
/** Unit label shown after the total count (e.g. `"items"`). Always has 4px spacing. */
|
||||
units?: string | RichStr;
|
||||
units?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -333,9 +331,7 @@ function PaginationSimple({
|
||||
}: SimplePaginationProps) {
|
||||
const handleChange = (page: number) => onChange?.(page);
|
||||
|
||||
const label = `${currentPage}/${totalPages}${
|
||||
units ? ` ${toPlainString(units)}` : ""
|
||||
}`;
|
||||
const label = `${currentPage}/${totalPages}${units ? ` ${units}` : ""}`;
|
||||
|
||||
return (
|
||||
<div {...props} className="flex items-center">
|
||||
@@ -389,16 +385,7 @@ function PaginationCount({
|
||||
{rangeStart}~{rangeEnd}
|
||||
<span className={textClasses(size, "muted")}>of</span>
|
||||
{totalItems}
|
||||
{units && (
|
||||
<span className="ml-1">
|
||||
<Text
|
||||
color="inherit"
|
||||
font={size === "sm" ? "secondary-body" : "main-ui-muted"}
|
||||
>
|
||||
{units}
|
||||
</Text>
|
||||
</span>
|
||||
)}
|
||||
{units && <span className="ml-1">{units}</span>}
|
||||
</span>
|
||||
|
||||
{/* Buttons: < [page] > */}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import "@opal/components/tag/styles.css";
|
||||
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -17,7 +16,7 @@ interface TagProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Tag label text. */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Color variant. Default: `"gray"`. */
|
||||
color?: TagColor;
|
||||
@@ -52,13 +51,14 @@ function Tag({ icon: Icon, title, color = "gray", size = "sm" }: TagProps) {
|
||||
<Icon className={cn("opal-auxiliary-tag-icon", config.text)} />
|
||||
</div>
|
||||
)}
|
||||
<span className={cn("opal-auxiliary-tag-title px-[2px]", config.text)}>
|
||||
<Text
|
||||
font={size === "md" ? "secondary-body" : "figure-small-value"}
|
||||
color="inherit"
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
<span
|
||||
className={cn(
|
||||
"opal-auxiliary-tag-title px-[2px]",
|
||||
size === "md" ? "font-secondary-body" : "font-figure-small-value",
|
||||
config.text
|
||||
)}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -10,7 +10,7 @@ import type { RichStr } from "@opal/types";
|
||||
|
||||
const SAFE_PROTOCOL = /^https?:|^mailto:|^tel:/i;
|
||||
|
||||
const ALLOWED_ELEMENTS = ["p", "br", "a", "strong", "em", "code", "del"];
|
||||
const ALLOWED_ELEMENTS = ["p", "a", "strong", "em", "code", "del"];
|
||||
|
||||
const INLINE_COMPONENTS = {
|
||||
p: ({ children }: { children?: ReactNode }) => <>{children}</>,
|
||||
@@ -41,11 +41,6 @@ interface InlineMarkdownProps {
|
||||
}
|
||||
|
||||
export default function InlineMarkdown({ content }: InlineMarkdownProps) {
|
||||
// Convert \n to CommonMark hard line breaks (two trailing spaces + newline).
|
||||
// react-markdown renders these as <br />, which inherits the parent's
|
||||
// line-height for font-appropriate spacing.
|
||||
const normalized = content.replace(/\n/g, " \n");
|
||||
|
||||
return (
|
||||
<ReactMarkdown
|
||||
components={INLINE_COMPONENTS}
|
||||
@@ -53,7 +48,7 @@ export default function InlineMarkdown({ content }: InlineMarkdownProps) {
|
||||
unwrapDisallowed
|
||||
remarkPlugins={[remarkGfm]}
|
||||
>
|
||||
{normalized}
|
||||
{content}
|
||||
</ReactMarkdown>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -90,15 +90,15 @@ import { markdown } from "@opal/utils";
|
||||
</Text>
|
||||
```
|
||||
|
||||
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`, `\n` (newline → `<br />`).
|
||||
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`.
|
||||
|
||||
Markdown rendering uses `react-markdown` internally, restricted to inline elements only.
|
||||
`http(s)` links open in a new tab; `mailto:` and `tel:` links open natively. Inline code
|
||||
inherits the parent font size and switches to the monospace family.
|
||||
|
||||
Newlines (`\n`) are converted to `<br />` elements that inherit the parent's line-height,
|
||||
so line spacing is proportional to the font size. For full block-level markdown (code blocks,
|
||||
headings, lists), use `MinimalMarkdown` instead.
|
||||
**Note:** This is inline-only markdown. Multi-paragraph content (`"Hello\n\nWorld"`) will
|
||||
collapse into a single run of text since paragraph wrappers are stripped. For block-level
|
||||
markdown, use `MinimalMarkdown` instead.
|
||||
|
||||
### Using `RichStr` in component props
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ type TextFont =
|
||||
| "figure-keystroke";
|
||||
|
||||
type TextColor =
|
||||
| "inherit"
|
||||
| "text-01"
|
||||
| "text-02"
|
||||
| "text-03"
|
||||
@@ -61,9 +60,6 @@ interface TextProps
|
||||
/** Prevent text wrapping. */
|
||||
nowrap?: boolean;
|
||||
|
||||
/** Truncate text to N lines with ellipsis. `1` uses simple truncation; `2+` uses `-webkit-line-clamp`. */
|
||||
maxLines?: number;
|
||||
|
||||
/** Plain string or `markdown()` for inline markdown. */
|
||||
children?: string | RichStr;
|
||||
}
|
||||
@@ -93,8 +89,7 @@ const FONT_CONFIG: Record<TextFont, string> = {
|
||||
"figure-keystroke": "font-figure-keystroke",
|
||||
};
|
||||
|
||||
const COLOR_CONFIG: Record<TextColor, string | null> = {
|
||||
inherit: null,
|
||||
const COLOR_CONFIG: Record<TextColor, string> = {
|
||||
"text-01": "text-text-01",
|
||||
"text-02": "text-text-02",
|
||||
"text-03": "text-text-03",
|
||||
@@ -120,29 +115,17 @@ function Text({
|
||||
color = "text-04",
|
||||
as: Tag = "span",
|
||||
nowrap,
|
||||
maxLines,
|
||||
children,
|
||||
...rest
|
||||
}: TextProps) {
|
||||
const resolvedClassName = cn(
|
||||
FONT_CONFIG[font],
|
||||
COLOR_CONFIG[color],
|
||||
nowrap && "whitespace-nowrap",
|
||||
maxLines === 1 && "truncate",
|
||||
maxLines && maxLines > 1 && "overflow-hidden"
|
||||
nowrap && "whitespace-nowrap"
|
||||
);
|
||||
|
||||
const style =
|
||||
maxLines && maxLines > 1
|
||||
? ({
|
||||
display: "-webkit-box",
|
||||
WebkitBoxOrient: "vertical",
|
||||
WebkitLineClamp: maxLines,
|
||||
} as React.CSSProperties)
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<Tag {...rest} className={resolvedClassName} style={style}>
|
||||
<Tag {...rest} className={resolvedClassName}>
|
||||
{children && resolveStr(children)}
|
||||
</Tag>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/animations/styles.css";
|
||||
import React, { createContext, useContext, useState, useCallback } from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/disabled/styles.css";
|
||||
import React, { createContext, useContext } from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 37 46"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={cn(className, "text-[#33C19E] dark:text-white")}
|
||||
{...props}
|
||||
>
|
||||
<title>Bifrost</title>
|
||||
<path
|
||||
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgBifrost;
|
||||
@@ -24,7 +24,6 @@ export { default as SvgAzure } from "@opal/icons/azure";
|
||||
export { default as SvgBarChart } from "@opal/icons/bar-chart";
|
||||
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
export { default as SvgBifrost } from "@opal/icons/bifrost";
|
||||
export { default as SvgBlocks } from "@opal/icons/blocks";
|
||||
export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBookmark } from "@opal/icons/bookmark";
|
||||
|
||||
@@ -3,11 +3,7 @@
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
import type { ContainerSizeVariants } from "@opal/types";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -39,10 +35,10 @@ interface ContentLgProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string | RichStr;
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -100,18 +96,18 @@ function ContentLg({
|
||||
ref,
|
||||
}: ContentLgProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(toPlainString(title));
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
|
||||
const config = CONTENT_LG_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(toPlainString(title));
|
||||
setEditValue(title);
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== toPlainString(title)) onTitleChange?.(value);
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
@@ -161,7 +157,7 @@ function ContentLg({
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(toPlainString(title));
|
||||
setEditValue(title);
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
@@ -178,9 +174,9 @@ function ContentLg({
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={toPlainString(title)}
|
||||
title={title}
|
||||
>
|
||||
{resolveStr(title)}
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -203,9 +199,9 @@ function ContentLg({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && toPlainString(description) && (
|
||||
{description && (
|
||||
<div className="opal-content-lg-description font-secondary-body text-text-03">
|
||||
{resolveStr(description)}
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -7,11 +7,7 @@ import SvgAlertCircle from "@opal/icons/alert-circle";
|
||||
import SvgAlertTriangle from "@opal/icons/alert-triangle";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import SvgXOctagon from "@opal/icons/x-octagon";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
@@ -29,6 +25,7 @@ interface ContentMdPresetConfig {
|
||||
iconColorClass: string;
|
||||
titleFont: string;
|
||||
lineHeight: string;
|
||||
gap: string;
|
||||
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
|
||||
editButtonSize: ContainerSizeVariants;
|
||||
editButtonPadding: string;
|
||||
@@ -44,10 +41,10 @@ interface ContentMdProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Optional description text below the title. */
|
||||
description?: string | RichStr;
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -88,6 +85,7 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
iconColorClass: "text-text-04",
|
||||
titleFont: "font-main-content-emphasis",
|
||||
lineHeight: "1.5rem",
|
||||
gap: "0.125rem",
|
||||
editButtonSize: "sm",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-content-muted",
|
||||
@@ -100,6 +98,7 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
iconColorClass: "text-text-03",
|
||||
titleFont: "font-main-ui-action",
|
||||
lineHeight: "1.25rem",
|
||||
gap: "0.25rem",
|
||||
editButtonSize: "xs",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-ui-muted",
|
||||
@@ -112,6 +111,7 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
iconColorClass: "text-text-04",
|
||||
titleFont: "font-secondary-action",
|
||||
lineHeight: "1rem",
|
||||
gap: "0.125rem",
|
||||
editButtonSize: "2xs",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-secondary-action",
|
||||
@@ -149,19 +149,19 @@ function ContentMd({
|
||||
ref,
|
||||
}: ContentMdProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(toPlainString(title));
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_MD_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(toPlainString(title));
|
||||
setEditValue(title);
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== toPlainString(title)) onTitleChange?.(value);
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
@@ -170,6 +170,7 @@ function ContentMd({
|
||||
ref={ref}
|
||||
className="opal-content-md"
|
||||
data-interactive={withInteractive || undefined}
|
||||
style={{ gap: config.gap }}
|
||||
>
|
||||
<div
|
||||
className="opal-content-md-header"
|
||||
@@ -214,7 +215,7 @@ function ContentMd({
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(toPlainString(title));
|
||||
setEditValue(title);
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
@@ -229,11 +230,11 @@ function ContentMd({
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
)}
|
||||
title={toPlainString(title)}
|
||||
title={title}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
{resolveStr(title)}
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -287,12 +288,12 @@ function ContentMd({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{description && toPlainString(description) && (
|
||||
{description && (
|
||||
<div
|
||||
className="opal-content-md-description font-secondary-body text-text-03"
|
||||
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
|
||||
>
|
||||
{resolveStr(description)}
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -34,7 +30,7 @@ interface ContentSmProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text (read-only — editing is not supported). */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Size preset. Default: `"main-ui"`. */
|
||||
sizePreset?: ContentSmSizePreset;
|
||||
@@ -122,9 +118,9 @@ function ContentSm({
|
||||
<span
|
||||
className={cn("opal-content-sm-title", config.titleFont)}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={toPlainString(title)}
|
||||
title={title}
|
||||
>
|
||||
{resolveStr(title)}
|
||||
{title}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -3,11 +3,7 @@
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
import type { ContainerSizeVariants } from "@opal/types";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -45,10 +41,10 @@ interface ContentXlProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string | RichStr;
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -120,18 +116,18 @@ function ContentXl({
|
||||
ref,
|
||||
}: ContentXlProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(toPlainString(title));
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
|
||||
const config = CONTENT_XL_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(toPlainString(title));
|
||||
setEditValue(title);
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== toPlainString(title)) onTitleChange?.(value);
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
@@ -218,7 +214,7 @@ function ContentXl({
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(toPlainString(title));
|
||||
setEditValue(title);
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
@@ -235,9 +231,9 @@ function ContentXl({
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={toPlainString(title)}
|
||||
title={title}
|
||||
>
|
||||
{resolveStr(title)}
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -260,9 +256,9 @@ function ContentXl({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && toPlainString(description) && (
|
||||
{description && (
|
||||
<div className="opal-content-xl-description font-secondary-body text-text-03">
|
||||
{resolveStr(description)}
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,39 +1,3 @@
|
||||
// ---------------------------------------------------------------------------
|
||||
// NOTE (@raunakab): Why Content uses resolveStr() instead of <Text>
|
||||
//
|
||||
// Content sub-components (ContentXl, ContentLg, ContentMd, ContentSm) render
|
||||
// titles and descriptions inside styled <span> elements that carry CSS classes
|
||||
// (e.g., `.opal-content-md-title`) for:
|
||||
//
|
||||
// 1. Truncation — `-webkit-box` + `-webkit-line-clamp` for single-line
|
||||
// clamping with ellipsis. This requires the text to be a DIRECT child
|
||||
// of the `-webkit-box` element. Wrapping it in a child <span> (which
|
||||
// is what <Text> renders) breaks the clamping behavior.
|
||||
//
|
||||
// 2. Pixel-exact sizing — the wrapper <span> has an explicit `height`
|
||||
// matching the font's `line-height`. Adding a child <Text> <span>
|
||||
// inside creates a double-span where the inner element's line-height
|
||||
// conflicts with the outer element's height, causing a ~4px vertical
|
||||
// offset.
|
||||
//
|
||||
// 3. Interactive color overrides — CSS selectors like
|
||||
// `.opal-content-md[data-interactive] .opal-content-md-title` set
|
||||
// `color: var(--interactive-foreground)`. <Text> with `color="inherit"`
|
||||
// can inherit this, but <Text> with any explicit color prop overrides
|
||||
// it. And the wrapper <span> needs the CSS class for the selector to
|
||||
// match — removing it breaks the cascade.
|
||||
//
|
||||
// 4. Horizontal padding — the title CSS class applies `padding: 0 0.125rem`
|
||||
// (2px). Since <Text> uses WithoutStyles (no className/style), this
|
||||
// padding cannot be applied to <Text> directly. A wrapper <div> was
|
||||
// attempted but introduced additional layout conflicts.
|
||||
//
|
||||
// For these reasons, Content uses `resolveStr()` from InlineMarkdown.tsx to
|
||||
// handle `string | RichStr` rendering. `resolveStr()` returns a ReactNode
|
||||
// that slots directly into the existing single <span> — no extra wrapper,
|
||||
// no layout conflicts, pixel-exact match with main.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
import "@opal/layouts/content/styles.css";
|
||||
import {
|
||||
ContentSm,
|
||||
@@ -53,7 +17,7 @@ import {
|
||||
type ContentMdProps,
|
||||
} from "@opal/layouts/content/ContentMd";
|
||||
import type { TagProps } from "@opal/components/tag/components";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { widthVariants } from "@opal/shared";
|
||||
import type { ExtremaSizeVariants } from "@opal/types";
|
||||
|
||||
@@ -75,10 +39,10 @@ interface ContentBaseProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string | RichStr;
|
||||
description?: string;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user