Compare commits

...

40 Commits

Author SHA1 Message Date
Bo-Onyx
afc163d7e5 fix(api memory): replace glibc with jemalloc for memory allocating (#9196) 2026-03-25 14:43:56 -07:00
Justin Tahara
0f214ca190 fix(ui): InputComboBox search for users/groups (#8928) 2026-03-04 12:43:39 -08:00
Justin Tahara
422ca91edc chore(ui): Update the Share Agent Modal (#8915) 2026-03-04 12:41:52 -08:00
Nikolas Garza
fe287eebb6 feat(slack): convert markdown tables to Slack-friendly format (#8999) 2026-03-04 11:54:41 -08:00
Justin Tahara
3f8ef8b465 fix(celery): Guardrail for User File Processing (#8633) 2026-03-01 10:29:55 -08:00
Justin Tahara
ed46504a1a fix(gong): Respecting Retry Timeout Header (#8866) 2026-02-27 14:22:34 -08:00
Nikolas Garza
7a24b34516 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 17:27:31 -08:00
dependabot[bot]
7a7ffa9051 chore(deps): Bump mistune from 0.8.4 to 3.1.4 in /backend (#6407)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 17:27:31 -08:00
Jamison Lahman
3053ab518c chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 16:26:55 -08:00
justin-tahara
be38d3500f Fixing mypy 2026-02-09 15:48:53 -08:00
Justin Tahara
753a3bc093 fix(posthog): Chat metrics for Cloud (#8278) 2026-02-09 15:48:53 -08:00
Raunak Bhagat
2ba8fafe78 fix: Add explicit sizings to icons (#8018) 2026-02-06 18:15:47 -08:00
Raunak Bhagat
b77b580ebd Cherry-pick card fix 2026-02-06 18:15:40 -08:00
Nikolas Garza
3eee98b932 fix: make it more clear how to add channels to fed slack config form (#8227) 2026-02-06 16:35:46 -08:00
Nikolas Garza
a97eb02fef fix(db): null out document set and persona ownership on user deletion (#8219) 2026-02-06 16:35:46 -08:00
Justin Tahara
c5061495a2 fix(ui): Inconsistent LLM Provider Logo (#8220) 2026-02-06 13:56:57 -08:00
Justin Tahara
c20b0789ae fix(ui): Additional LLM Config update (#8174) 2026-02-06 13:56:49 -08:00
Justin Tahara
d99848717b fix(ui): Ollama Model Selection (#8091) 2026-02-06 13:53:52 -08:00
Evan Lohn
aaca55c415 fix(salesforce): cleanup logic (#8175) 2026-02-06 13:52:46 -08:00
Justin Tahara
9d7ffd1e4a fix(ui): Updating Dropdown Modal component (#8033) 2026-02-06 11:39:48 -08:00
Justin Tahara
a249161827 chore(chat): Cleaning Error Codes + Tests (#8186) 2026-02-06 11:39:36 -08:00
Justin Tahara
e126346a91 fix(agents): Removing Label Dependency (#8189) 2026-02-06 11:03:16 -08:00
Justin Tahara
a96682fa73 fix(ui): Agent Saving with other people files (#8095) 2026-02-02 10:30:46 -08:00
Justin Tahara
3920371d56 feat(desktop): Ensure that UI reflects Light/Dark Toggle (#7684) 2026-02-02 10:30:36 -08:00
Wenxi Onyx
e5a257345c 2nd dummy commit (noop README change) to fix beta tag on docker 2026-01-31 11:17:12 -08:00
Wenxi Onyx
a49df511e2 dummy commit (noop README change) to fix beta tag on docker 2026-01-31 11:09:41 -08:00
Justin Tahara
d5d2a8a1a6 fix(desktop): Remove Global Shortcuts (#7914) 2026-01-30 13:46:26 -08:00
Justin Tahara
b2f46b264c fix(asana): Workspace Team ID mismatch (#7674) 2026-01-30 13:19:07 -08:00
Jamison Lahman
c6ad363fbd chore(mypy): fix mypy cache issues switching between HEAD and release (#7732) 2026-01-27 15:52:53 -08:00
Jamison Lahman
e313119f9a fix(citations): enable citation sidebar w/ web_search-only assistants (#7888) 2026-01-27 14:50:00 -08:00
Wenxi
3a2a542a03 fix: connector details back button should nav back (#7869) 2026-01-27 14:35:15 -08:00
Yuhong Sun
413aeba4a1 fix: Project Creation (#7851) 2026-01-27 14:34:59 -08:00
Wenxi
46028aa2bb fix: user count check (#7811) 2026-01-27 14:34:29 -08:00
Justin Tahara
454943c4a6 fix(llm): Hide private models from Agent Creation (#7873) 2026-01-27 14:33:40 -08:00
Justin Tahara
87946266de fix(redis): Adding more TTLs (#7886) 2026-01-27 14:32:14 -08:00
Jamison Lahman
144030c5ca chore(vscode): add non-clean seeded db restore (#7795) 2026-01-26 08:55:19 -08:00
SubashMohan
a557d76041 feat(ui): add new icons and enhance FadeDiv, Modal, Tabs, ExpandableTextDisplay (#7563)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 10:26:09 +00:00
SubashMohan
605e808158 fix(layout): adjust footer margin and prevent page refresh on chatsession drop (#7759) 2026-01-26 04:45:40 +00:00
roshan
8fec88c90d chore(deployment): remove no auth option from setup script (#7784) 2026-01-26 04:42:45 +00:00
Yuhong Sun
e54969a693 fix: LiteLLM Azure models don't stream (#7761) 2026-01-25 07:46:51 +00:00
144 changed files with 7049 additions and 1127 deletions

View File

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

21
.vscode/launch.json vendored
View File

@@ -605,6 +605,27 @@
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",

View File

@@ -42,7 +42,9 @@ RUN apt-get update && \
pkg-config \
gcc \
nano \
vim && \
vim \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
@@ -124,6 +126,13 @@ ENV PYTHONPATH=/app
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
# in long-running Python processes (API server, Celery workers).
# The soname is architecture-independent; the dynamic linker resolves
# the correct path from standard library directories.
# Placed after all RUN steps so build-time processes are unaffected.
ENV LD_PRELOAD=libjemalloc.so.2
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]

View File

@@ -12,6 +12,7 @@ from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_TTL = FENCE_TTL
logger = setup_logger()
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
# Create the Celery task
celery_app.send_task(

View File

@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
# Track user message in PostHog for analytics
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(
distinct_id=user.email if user else tenant_id,
event="user_message_sent",
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=(
user.email
if user and not getattr(user, "is_anonymous", False)
else tenant_id
),
event=MilestoneRecordType.USER_MESSAGE_SENT,
properties={
"origin": new_msg_req.origin.value,
"has_files": len(new_msg_req.file_descriptors) > 0,
"has_project": chat_session.project_id is not None,
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
"deep_research": new_msg_req.deep_research,
"tenant_id": tenant_id,
},
)

View File

@@ -153,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -341,6 +352,7 @@ class MilestoneRecordType(str, Enum):
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
USER_MESSAGE_SENT = "user_message_sent"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
@@ -423,6 +435,9 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.workspace_id = asana_workspace_id.strip()
if asana_project_ids:
project_ids = [
project_id.strip()
for project_id in asana_project_ids.split(",")
if project_id.strip()
]
self.project_ids_to_index = project_ids or None
else:
self.project_ids_to_index = None
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(

View File

@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=5,
total=10,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -6,6 +6,7 @@ import sys
import tempfile
import time
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from typing import Any
from typing import cast
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def _convert_to_metadata_value(value: Any) -> str | list[str]:
"""Convert a Salesforce field value to a valid metadata value.
Document metadata expects str | list[str], but Salesforce returns
various types (bool, float, int, etc.). This function ensures all
values are properly converted to strings.
"""
if isinstance(value, list):
return [str(item) for item in value]
return str(value)
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
# # gc.collect()
# return all_types
def _yield_doc_batches(
self,
sf_db: OnyxSalesforceSQLite,
type_to_processed: dict[str, int],
changed_ids_to_type: dict[str, str],
parent_types: set[str],
increment_parents_changed: Callable[[], None],
) -> GenerateDocumentsOutput:
""" """
docs_to_yield: list[Document] = []
docs_to_yield_bytes = 0
last_log_time = 0.0
for (
parent_type,
parent_id,
examined_ids,
) in sf_db.get_changed_parent_ids_by_type(
changed_ids=list(changed_ids_to_type.keys()),
parent_types=parent_types,
):
now = time.monotonic()
processed = examined_ids - 1
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
logger.info(
f"Processing stats: {type_to_processed} "
f"file_size={sf_db.file_size} "
f"processed={processed} "
f"remaining={len(changed_ids_to_type) - processed}"
)
last_log_time = now
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
parent_object = sf_db.get_record(parent_id, parent_type)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
# use the db to create a document we can yield
doc = convert_sf_object_to_doc(
sf_db,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
parent_object.data[sf_attribute]
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
increment_parents_changed()
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
def _full_sync(
self,
temp_dir: str,
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
docs_to_yield: list[Document] = []
changed_ids_to_type: dict[str, str] = {}
parents_changed = 0
examined_ids = 0
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
f"records={num_records}"
)
# yield an empty list to keep the connector alive
yield docs_to_yield
new_ids = sf_db.update_from_csv(
object_type=object_type,
csv_download_path=csv_path,
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
)
# Step 3 - extract and index docs
docs_to_yield_bytes = 0
last_log_time = 0.0
for (
parent_type,
parent_id,
examined_ids,
) in sf_db.get_changed_parent_ids_by_type(
changed_ids=list(changed_ids_to_type.keys()),
parent_types=ctx.parent_types,
):
now = time.monotonic()
processed = examined_ids - 1
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
logger.info(
f"Processing stats: {type_to_processed} "
f"file_size={sf_db.file_size} "
f"processed={processed} "
f"remaining={len(changed_ids_to_type) - processed}"
)
last_log_time = now
type_to_processed[parent_type] = (
type_to_processed.get(parent_type, 0) + 1
)
parent_object = sf_db.get_record(parent_id, parent_type)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
# use the db to create a document we can yield
doc = convert_sf_object_to_doc(
sf_db,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = parent_object.data[
sf_attribute
]
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
def increment_parents_changed() -> None:
nonlocal parents_changed
parents_changed += 1
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
yield from self._yield_doc_batches(
sf_db,
type_to_processed,
changed_ids_to_type,
ctx.parent_types,
increment_parents_changed,
)
except Exception:
logger.exception("Unexpected exception")
raise
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
if sf_attribute in record:
doc.metadata[canonical_attribute] = record[sf_attribute]
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
record[sf_attribute]
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
return return_context
def load_from_state(self) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._full_sync(temp_dir)
# nuke the db since we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
os.remove(sqlite_db_path)
return self._full_sync(BASE_DATA_PATH)
# Always use a temp directory for SQLite - the database is rebuilt
# from scratch each time via CSV downloads, so there's no caching benefit
# from persisting it. Using temp dirs also avoids collisions between
# multiple CC pairs and eliminates stale WAL/SHM file issues.
# TODO(evan): make this thing checkpointed and persist/load db from filestore
with tempfile.TemporaryDirectory() as temp_dir:
yield from self._full_sync(temp_dir)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
"""Poll source will synchronize updated parent objects one by one."""
if start == 0:
# nuke the db if we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
)
os.remove(sqlite_db_path)
return self._delta_sync(BASE_DATA_PATH, start, end)
# Always use a temp directory - see comment in load_from_state()
with tempfile.TemporaryDirectory() as temp_dir:
return self._delta_sync(temp_dir, start, end)
yield from self._delta_sync(temp_dir, start, end)
def retrieve_all_slim_docs_perm_sync(
self,

View File

@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
logger = setup_logger()
SQLITE_DISK_IO_ERROR = "disk I/O error"
class OnyxSalesforceSQLite:
"""Notes on context management using 'with self.conn':
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
def apply_schema(self) -> None:
"""Initialize the SQLite database with required tables if they don't exist.
Non-destructive operation.
Non-destructive operation. If a disk I/O error is encountered (often due
to stale WAL/SHM files from a previous crash), this method will attempt
to recover by removing the corrupted files and recreating the database.
"""
try:
self._apply_schema_impl()
except sqlite3.OperationalError as e:
if SQLITE_DISK_IO_ERROR not in str(e):
raise
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
self._recover_from_corruption()
self._apply_schema_impl()
def _recover_from_corruption(self) -> None:
"""Recover from SQLite corruption by removing all database files and reconnecting."""
logger.info(f"Removing corrupted SQLite files: {self.filename}")
# Close existing connection
self.close()
# Remove all SQLite files (main db, WAL, SHM)
remove_sqlite_db_files(self.filename)
# Reconnect - this will create a fresh database
self.connect()
logger.info("SQLite recovery complete, fresh database created")
def _apply_schema_impl(self) -> None:
"""Internal implementation of apply_schema."""
if self._conn is None:
raise RuntimeError("Database connection is closed")

View File

@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
return os.path.join(directory, "salesforce_db.sqlite")
def remove_sqlite_db_files(db_path: str) -> None:
    """Remove SQLite database and all associated files (WAL, SHM).

    SQLite in WAL mode creates additional files:
    - .sqlite-wal: Write-ahead log
    - .sqlite-shm: Shared memory file

    If these files become stale (e.g., after a crash), they can cause
    'disk I/O error' when trying to open the database. This function
    ensures all related files are removed. It is a no-op for files that
    do not exist, so it is safe to call repeatedly.

    Args:
        db_path: Path to the main SQLite database file; the companion
            "-wal" and "-shm" paths are derived from it.
    """
    files_to_remove = [
        db_path,
        f"{db_path}-wal",
        f"{db_path}-shm",
    ]
    for file_path in files_to_remove:
        # Attempt removal unconditionally and tolerate a missing file rather
        # than checking os.path.exists() first: the exists/remove pair is a
        # TOCTOU race if another process deletes the file in between.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
# NOTE: only used with shelves, deprecated at this point
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)

View File

@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
"Persona",
secondary=Persona__PersonaLabel.__table__,
back_populates="labels",
cascade="all, delete-orphan",
single_parent=True,
)

View File

@@ -917,7 +917,9 @@ def upsert_persona(
existing_persona.icon_name = icon_name
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None

View File

@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
@@ -327,6 +329,15 @@ def delete_user_from_db(
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
# Null out ownership on document sets and personas so they're
# preserved for other users instead of being cascade-deleted
db_session.query(DocumentSet).filter(
DocumentSet.user_id == user_to_delete.id
).update({DocumentSet.user_id: None})
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
{Persona.user_id: None}
)
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()

View File

@@ -685,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
def _patch_azure_responses_should_fake_stream() -> None:
"""
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
not in its database. This causes Azure custom model deployments to buffer the entire
response before yielding, resulting in poor time-to-first-token.
Azure's Responses API supports native streaming, so we override this to always use
real streaming (SyncResponsesAPIStreamingIterator).
"""
from litellm.llms.azure.responses.transformation import (
AzureOpenAIResponsesAPIConfig,
)
if (
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
== "_patched_should_fake_stream"
):
return
def _patched_should_fake_stream(
self: Any,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
# Azure Responses API supports native streaming - never fake it
return False
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -694,12 +728,13 @@ def apply_monkey_patches() -> None:
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
_patch_openai_responses_transform_response()
_patch_azure_responses_should_fake_stream()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -592,11 +592,8 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,65 +1,270 @@
from mistune import Markdown # type: ignore[import-untyped]
from mistune import Renderer
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
re.IGNORECASE,
)
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
_HTML_TAG_PATTERN = re.compile(
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
    """Remove HTML markup from a fragment of text.

    <br> and closing block-level tags (</p>, </div>, ...) become newlines so
    visual line structure survives; every other tag is dropped entirely.
    Autolinks such as <https://...> and <mailto:...> are left untouched.
    """
    with_breaks = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
    return _HTML_TAG_PATTERN.sub("", with_breaks)
def _transform_outside_code_blocks(
    message: str, transform: Callable[[str], str]
) -> str:
    """Run *transform* over every part of *message* that lies outside fenced
    (``` ... ```) code blocks; the code blocks themselves pass through verbatim.
    """
    outside_segments = _FENCED_CODE_BLOCK_PATTERN.split(message)
    fenced_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
    # re.split (pattern has no capture groups) always yields exactly one more
    # segment than there are matches, so interleave pairwise and then append
    # the transformed trailing segment(s).
    merged: list[str] = []
    for plain, code in zip(outside_segments, fenced_blocks):
        merged.append(transform(plain))
        merged.append(code)
    merged.extend(transform(seg) for seg in outside_segments[len(fenced_blocks) :])
    return "".join(merged)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
    """Wrap markdown link URLs in angle brackets so the parser handles special chars safely.

    Markdown link syntax [text](url) breaks when the URL contains unescaped
    parentheses, spaces, or other special characters. Wrapping the URL in angle
    brackets — [text](<url>) — tells the parser to treat everything inside as
    a literal URL. This applies to all links, not just citations.
    """
    # Fast path: no "](" means no markdown links at all.
    if "](" not in message:
        return message
    normalized_parts: list[str] = []
    cursor = 0
    # Walk link-by-link; everything between links is copied through unchanged.
    while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
        # Copy text up to and including the "[text](" opener.
        normalized_parts.append(message[cursor : match.end()])
        destination_start = match.end()
        destination, end_idx = _extract_link_destination(message, destination_start)
        if end_idx is None:
            # Unterminated link: emit the remainder as-is and stop.
            normalized_parts.append(message[destination_start:])
            return "".join(normalized_parts)
        already_wrapped = destination.startswith("<") and destination.endswith(">")
        if destination and not already_wrapped:
            destination = f"<{destination}>"
        normalized_parts.append(destination)
        normalized_parts.append(")")
        # Resume scanning just past the closing paren.
        cursor = end_idx + 1
    normalized_parts.append(message[cursor:])
    return "".join(normalized_parts)
def _convert_slack_links_to_markdown(message: str) -> str:
    """Rewrite Slack-style <url|text> links as standard markdown [text](url).

    LLMs occasionally emit Slack mrkdwn link syntax directly. Mistune does not
    understand it, so the angle brackets would later be escaped and the link
    rendered as literal text. Fenced code blocks are left untouched.
    """

    def _rewrite(fragment: str) -> str:
        # Group 1 = url, group 2 = display text (see _SLACK_LINK_PATTERN).
        return _SLACK_LINK_PATTERN.sub(r"[\2](\1)", fragment)

    return _transform_outside_code_blocks(message, _rewrite)
def format_slack_message(message: str | None) -> str:
return Markdown(renderer=SlackRenderer()).render(message)
if message is None:
return ""
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result.rstrip("\n")
class SlackRenderer(Renderer):
class SlackRenderer(HTMLRenderer):
"""Renders markdown as Slack mrkdwn format instead of HTML.
Overrides all HTMLRenderer methods that produce HTML tags to ensure
no raw HTML ever appears in Slack messages.
"""
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def __init__(self) -> None:
super().__init__()
self._table_headers: list[str] = []
self._current_row_cells: list[str] = []
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
return text
def header(self, text: str, level: int, raw: str | None = None) -> str:
return f"*{text}*\n"
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
return f"*{text}*\n\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
def double_emphasis(self, text: str) -> str:
def strong(self, text: str) -> str:
return f"*{text}*"
def strikethrough(self, text: str) -> str:
return f"~{text}~"
def list(self, body: str, ordered: bool = True) -> str:
lines = body.split("\n")
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
lines = text.split("\n")
count = 0
for i, line in enumerate(lines):
if line.startswith("li: "):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
return "\n".join(lines) + "\n"
def list_item(self, text: str) -> str:
return f"li: {text}\n"
def link(self, link: str, title: str | None, content: str | None) -> str:
escaped_link = self.escape_special(link)
if content:
return f"<{escaped_link}|{content}>"
def link(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
if text:
return f"<{escaped_url}|{text}>"
if title:
return f"<{escaped_link}|{title}>"
return f"<{escaped_link}>"
return f"<{escaped_url}|{title}>"
return f"<{escaped_url}>"
def image(self, src: str, title: str | None, text: str | None) -> str:
escaped_src = self.escape_special(src)
def image(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
display_text = title or text
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
def codespan(self, text: str) -> str:
return f"`{text}`"
def block_code(self, text: str, lang: str | None) -> str:
return f"```\n{text}\n```\n"
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
return "\n"
def thematic_break(self) -> str:
return "---\n\n"
def block_quote(self, text: str) -> str:
lines = text.strip().split("\n")
quoted = "\n".join(f">{line}" for line in lines)
return quoted + "\n\n"
def block_html(self, html: str) -> str:
return _sanitize_html(html) + "\n\n"
def block_error(self, text: str) -> str:
return f"```\n{text}\n```\n\n"
def text(self, text: str) -> str:
# Only escape the three entities Slack recognizes: & < >
# HTMLRenderer.text() also escapes " to &quot; which Slack renders
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
def table_cell(
    self, text: str, align: str | None = None, head: bool = False  # noqa: ARG002
) -> str:
    """Buffer a table cell's text instead of emitting any output.

    Header cells accumulate in ``self._table_headers``; body cells in
    ``self._current_row_cells`` for table_row() to consume. Returns "" so
    no HTML reaches the rendered message; *align* is ignored because Slack
    mrkdwn has no cell alignment.
    """
    if head:
        self._table_headers.append(text.strip())
    else:
        self._current_row_cells.append(text.strip())
    return ""
def table_head(self, text: str) -> str: # noqa: ARG002
self._current_row_cells = []
return ""
def table_row(self, text: str) -> str:  # noqa: ARG002
    """Render one buffered table row as a vertical "card" for Slack.

    Consumes the cells collected by table_cell(): the first column becomes a
    bold title line and each remaining column becomes a "Header: value" line
    (using the matching entry in ``self._table_headers`` when one exists).
    The *text* argument (the HTML renderer's concatenated cell output) is
    intentionally ignored — cells return "" and buffer themselves instead.
    """
    cells = self._current_row_cells
    # Reset the buffer so the next row starts clean.
    self._current_row_cells = []
    # First column becomes the bold title, remaining columns are bulleted fields
    lines: list[str] = []
    if cells:
        title = cells[0]
        if title:
            # Avoid double-wrapping if cell already contains bold markup
            if title.startswith("*") and title.endswith("*") and len(title) > 1:
                lines.append(title)
            else:
                lines.append(f"*{title}*")
        for i, cell in enumerate(cells[1:], start=1):
            if i < len(self._table_headers):
                lines.append(f"{self._table_headers[i]}: {cell}")
            else:
                lines.append(f"{cell}")
    return "\n".join(lines) + "\n\n"
def table_body(self, text: str) -> str:
return text
def table(self, text: str) -> str:
self._table_headers = []
self._current_row_cells = []
return text + "\n"
def paragraph(self, text: str) -> str:
return f"{text}\n"
def autolink(self, link: str, is_email: bool) -> str:
return link if is_email else self.link(link, None, None)
return f"{text}\n\n"

View File

@@ -32,6 +32,7 @@ class RedisConnectorDelete:
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
TASKSET_TTL = FENCE_TTL
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
self.redis.sadd(self.taskset_key, custom_task_id)
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
# Priority on sync's triggered by new indexing should be medium
celery_app.send_task(

View File

@@ -45,6 +45,7 @@ class RedisConnectorPrune:
) # connectorpruning_generator_complete
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
TASKSET_TTL = FENCE_TTL
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
# used to signal the overall workflow is still active
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
# add to the tracking taskset in redis BEFORE creating the celery task.
self.redis.sadd(self.taskset_key, custom_task_id)
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(

View File

@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_PREFIX = PREFIX + "_taskset"
TASKSET_TTL = FENCE_TTL
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,

View File

@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_PREFIX = PREFIX + "_taskset"
TASKSET_TTL = FENCE_TTL
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,

View File

@@ -84,7 +84,8 @@ def patch_document_set(
user=user,
target_group_ids=document_set_update_request.groups,
object_is_public=document_set_update_request.is_public,
object_is_owned_by_user=user and document_set.user_id == user.id,
object_is_owned_by_user=user
and (document_set.user_id is None or document_set.user_id == user.id),
)
try:
update_document_set(
@@ -125,7 +126,8 @@ def delete_document_set(
db_session=db_session,
user=user,
object_is_public=document_set.is_public,
object_is_owned_by_user=user and document_set.user_id == user.id,
object_is_owned_by_user=user
and (document_set.user_id is None or document_set.user_id == user.id),
)
try:

View File

@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
@router.get("/", tags=PUBLIC_API_TAGS)
@router.get("", tags=PUBLIC_API_TAGS)
def get_projects(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),

View File

@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import user_needs_to_be_verified
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
from onyx.configs.constants import AuthType
from onyx.configs.constants import DEV_VERSION_PATTERN
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import STABLE_VERSION_PATTERN
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
async def get_auth_type() -> AuthTypeResponse:
user_count = await get_user_count()
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
# associated with until they auth.
has_users = True
if AUTH_TYPE != AuthType.CLOUD:
user_count = await get_user_count()
has_users = user_count > 0
return AuthTypeResponse(
auth_type=AUTH_TYPE,
requires_verification=user_needs_to_be_verified(),
anonymous_user_enabled=anonymous_user_enabled(),
password_min_length=PASSWORD_MIN_LENGTH,
has_users=user_count > 0,
has_users=has_users,
)

View File

@@ -410,26 +410,20 @@ def list_llm_provider_basics(
all_providers = fetch_existing_llm_providers(db_session)
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
is_admin = user and user.role == UserRole.ADMIN
is_admin = user is not None and user.role == UserRole.ADMIN
accessible_providers = []
for provider in all_providers:
# Include all public providers
if provider.is_public:
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
continue
# Include restricted providers user has access to via groups
if is_admin:
# Admins see all providers
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif provider.groups:
# User must be in at least one of the provider's groups
if user_group_ids.intersection({g.id for g in provider.groups}):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif not provider.personas:
# No restrictions = accessible
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes all public providers
# - Includes providers user can access via group membership
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
end_time = datetime.now(timezone.utc)

View File

@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import remove_chat_message_feedback
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -266,7 +267,35 @@ def get_chat_session(
include_deleted=include_deleted,
)
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
try:
# If we failed to get a chat session, try to retrieve the session with
# less restrictive filters in order to identify what exactly mismatched
# so we can bubble up an accurate error code and message.
existing_chat_session = get_chat_session_by_id(
chat_session_id=session_id,
user_id=None,
db_session=db_session,
is_shared=False,
include_deleted=True,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
if not include_deleted and existing_chat_session.deleted:
raise HTTPException(status_code=404, detail="Chat session has been deleted")
if is_shared:
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
raise HTTPException(
status_code=403, detail="Chat session is not shared"
)
elif user_id is not None and existing_chat_session.user_id not in (
user_id,
None,
):
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=404, detail="Chat session not found")
# for chat-seeding: if the session is unassigned, assign it now. This is done here
# to avoid another back and forth between FE -> BE before starting the first

View File

@@ -580,7 +580,7 @@ def translate_assistant_message_to_packets(
# Determine stop reason - check if message indicates user cancelled
stop_reason: str | None = None
if chat_message.message:
if "Generation was stopped" in chat_message.message:
if "generation was stopped" in chat_message.message.lower():
stop_reason = "user_cancelled"
# Add overall stop packet at the end

View File

@@ -573,7 +573,7 @@ mcp==1.25.0
# onyx
mdurl==0.1.2
# via markdown-it-py
mistune==0.8.4
mistune==3.2.0
# via onyx
more-itertools==10.8.0
# via

View File

@@ -298,7 +298,7 @@ numpy==2.4.1
# pandas-stubs
# shapely
# voyageai
onyx-devtools==0.4.0
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via

View File

@@ -0,0 +1,281 @@
"""
External dependency unit tests for user file processing queue protections.
Verifies that the three mechanisms added to check_user_file_processing work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in PROCESSING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that process_single_user_file clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_user_file_processing,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file,
)
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
    """Persist and return a new UserFile row already in PROCESSING status."""
    user_file = UserFile(
        id=uuid4(),
        user_id=user_id,
        file_id=f"test_file_{uuid4().hex[:8]}",
        name=f"test_{uuid4().hex[:8]}.txt",
        file_type="text/plain",
        status=UserFileStatus.PROCESSING,
    )
    db_session.add(user_file)
    # Commit so code that opens its own session (the beat task) can see the
    # row, then refresh to pick up any server-side defaults.
    db_session.commit()
    db_session.refresh(user_file)
    return user_file
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestQueueDepthBackpressure:
    """Protection 1: skip all enqueuing when the broker queue is too deep."""
    def test_no_tasks_enqueued_when_queue_over_limit(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """When the queue depth exceeds the limit the beat cycle is skipped."""
        # A PROCESSING file exists, so without backpressure the beat task
        # would normally enqueue work for it.
        user = create_test_user(db_session, "bp_user")
        _create_processing_user_file(db_session, user.id)
        mock_app = MagicMock()
        # Report a queue depth one over the configured maximum so the beat
        # cycle bails out before enqueuing anything.
        with (
            _patch_task_app(check_user_file_processing, mock_app),
            patch(
                _PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
            ),
        ):
            check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
        # Despite the pending file, no task may have been submitted.
        mock_app.send_task.assert_not_called()
class TestPerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "guard_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "guard_set_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
f"Guard key TTL {ttl}s is outside the expected range "
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
)
finally:
redis_client.delete(guard_key)
class TestTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "expires_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires")
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), (
"Task must be submitted with the correct expires value to prevent "
"stale task accumulation"
)
finally:
redis_client.delete(guard_key)
class TestWorkerClearsGuardKey:
    """process_single_user_file removes the guard key when it picks up a task."""
    def test_guard_key_deleted_on_pickup(
        self,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """The guard key is deleted before the worker does any real work.

        We simulate an already-locked file so process_single_user_file returns
        early — but, crucially, only after deleting the guard key.
        """
        user_file_id = str(uuid4())
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(user_file_id)
        # Simulate the guard key set when the beat enqueued the task
        redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
        assert redis_client.exists(guard_key), "Guard key must exist before pickup"
        # Hold the per-file processing lock so the worker exits early without
        # touching the database or file store.
        lock_key = _user_file_lock_key(user_file_id)
        processing_lock = redis_client.lock(lock_key, timeout=10)
        acquired = processing_lock.acquire(blocking=False)
        assert acquired, "Should be able to acquire the processing lock for this test"
        try:
            # The worker should delete the guard key, then bail out because
            # the processing lock is already held.
            process_single_user_file.run(
                user_file_id=user_file_id,
                tenant_id=TEST_TENANT_ID,
            )
        finally:
            if processing_lock.owned():
                processing_lock.release()
        assert not redis_client.exists(
            guard_key
        ), "Guard key should be deleted when the worker picks up the task"

View File

@@ -476,8 +476,8 @@ class ChatSessionManager:
else GENERAL_HEADERS
),
)
# Chat session should return 400 if it doesn't exist
return response.status_code == 400
# Chat session should return 404 if it doesn't exist or is deleted
return response.status_code == 404
@staticmethod
def verify_soft_deleted(

View File

@@ -31,7 +31,7 @@ class ProjectManager:
) -> List[UserProjectSnapshot]:
"""Get all projects for a user via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/",
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
@@ -56,7 +56,7 @@ class ProjectManager:
) -> bool:
"""Verify that a project has been deleted by ensuring it's not in list."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/",
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()

View File

@@ -0,0 +1,185 @@
from uuid import uuid4
import pytest
import requests
from requests import HTTPError
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
@pytest.fixture(scope="module", autouse=True)
def reset_for_module() -> None:
    """Reset all data once before running any tests in this module."""
    reset_all()
@pytest.fixture
def second_user(admin_user: DATestUser) -> DATestUser:
    """Provide a second BASIC-role user, creating it or logging in if it already exists.

    Depends on ``admin_user`` so an admin already exists and this user is
    created with the BASIC role.

    Raises:
        HTTPError: re-raised when the creation failure is anything other than
            a recognizable "user already exists" response.
    """
    try:
        return UserManager.create(name="second_basic_user")
    except HTTPError as e:
        response = e.response
        if response is None:
            raise
        if response.status_code not in (400, 409):
            raise
        # NOTE: a non-JSON body raises ValueError here; it propagates with the
        # original HTTPError chained as context. (The previous
        # `try/except ValueError: raise` was a no-op catch-and-reraise.)
        payload = response.json()
        detail = payload.get("detail")
        if not _is_user_already_exists_detail(detail):
            raise
        print("Second basic user already exists; logging in instead.")
        return UserManager.login_as_user(
            DATestUser(
                id="",
                email=build_email("second_basic_user"),
                password=DEFAULT_PASSWORD,
                headers=GENERAL_HEADERS,
                role=UserRole.BASIC,
                is_active=True,
            )
        )
def _is_user_already_exists_detail(detail: object) -> bool:
if isinstance(detail, str):
normalized = detail.lower()
return (
"already exists" in normalized
or "register_user_already_exists" in normalized
)
if isinstance(detail, dict):
code = detail.get("code")
if isinstance(code, str) and code.lower() == "register_user_already_exists":
return True
message = detail.get("message")
if isinstance(message, str) and "already exists" in message.lower():
return True
return False
def _get_chat_session(
    chat_session_id: str,
    user: DATestUser,
    is_shared: bool | None = None,
    include_deleted: bool | None = None,
) -> requests.Response:
    """GET /chat/get-chat-session/{id} as ``user``.

    Optional boolean flags are serialized to lowercase strings ("true"/"false")
    and only included in the query string when explicitly provided.
    """
    query: dict[str, str] = {
        name: str(flag).lower()
        for name, flag in (
            ("is_shared", is_shared),
            ("include_deleted", include_deleted),
        )
        if flag is not None
    }
    return requests.get(
        f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
        params=query,
        headers=user.headers,
        cookies=user.cookies,
    )
def _set_sharing_status(
    chat_session_id: str, sharing_status: str, user: DATestUser
) -> requests.Response:
    """PATCH the chat session's sharing_status (e.g. "public"/"private") as ``user``."""
    url = f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}"
    payload = {"sharing_status": sharing_status}
    return requests.patch(
        url,
        json=payload,
        headers=user.headers,
        cookies=user.cookies,
    )
def test_private_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify private sessions are only accessible by the owner and never via share link."""
    # Create a private chat session owned by basic_user.
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    # Owner can access the private session normally.
    response = _get_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 200
    # Share link should be forbidden when the session is private.
    response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
    assert response.status_code == 403
    # Other users cannot access private sessions directly.
    response = _get_chat_session(str(chat_session.id), second_user)
    assert response.status_code == 403
    # Other users also cannot access private sessions via share link.
    response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
    assert response.status_code == 403
def test_public_shared_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify shared sessions are accessible only via share link for non-owners."""
    # Create a private session, then mark it public.
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    response = _set_sharing_status(str(chat_session.id), "public", basic_user)
    assert response.status_code == 200
    # Owner can access normally.
    response = _get_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 200
    # Owner can also access via share link.
    response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
    assert response.status_code == 200
    # Non-owner cannot access without share link.
    response = _get_chat_session(str(chat_session.id), second_user)
    assert response.status_code == 403
    # Non-owner can access with share link for public sessions.
    response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
    assert response.status_code == 200
def test_deleted_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify deleted sessions return 404, with include_deleted gated by access checks."""
    # Create and soft-delete a session.
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    deletion_success = ChatSessionManager.soft_delete(
        chat_session=chat_session, user_performing_action=basic_user
    )
    assert deletion_success is True
    # Deleted sessions are not accessible normally.
    response = _get_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 404
    # Owner can fetch deleted session only with include_deleted.
    response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
    assert response.status_code == 200
    assert response.json().get("deleted") is True
    # Non-owner should be blocked even with include_deleted.
    response = _get_chat_session(
        str(chat_session.id), second_user, include_deleted=True
    )
    assert response.status_code == 403
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
    """A random, never-created session ID must yield HTTP 404."""
    missing_id = str(uuid4())
    resp = _get_chat_session(missing_id, basic_user)
    assert resp.status_code == 404

View File

@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
assert fallback_llm.config.model_name == default_provider.default_model_name
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
    users: tuple[DATestUser, DATestUser],
) -> None:
    """Test that the /llm/provider endpoint correctly excludes non-public providers
    with no group/persona restrictions.

    This tests the fix for the bug where non-public providers with no restrictions
    were incorrectly shown to all users instead of being admin-only.
    """
    admin_user, basic_user = users
    # Create a public provider (should be visible to all)
    public_provider = LLMProviderManager.create(
        name="public-provider",
        is_public=True,
        set_as_default=True,
        user_performing_action=admin_user,
    )
    # Create a non-public provider with no restrictions (should be admin-only)
    non_public_provider = LLMProviderManager.create(
        name="non-public-unrestricted",
        is_public=False,
        groups=[],
        personas=[],
        set_as_default=False,
        user_performing_action=admin_user,
    )
    # Non-admin user calls the /llm/provider endpoint
    response = requests.get(
        f"{API_SERVER_URL}/llm/provider",
        headers=basic_user.headers,
    )
    assert response.status_code == 200
    providers = response.json()
    provider_names = [p["name"] for p in providers]
    # Public provider should be visible
    assert public_provider.name in provider_names
    # Non-public provider with no restrictions should NOT be visible to non-admin
    assert non_public_provider.name not in provider_names
    # Admin user should see both providers
    admin_response = requests.get(
        f"{API_SERVER_URL}/llm/provider",
        headers=admin_user.headers,
    )
    assert admin_response.status_code == 200
    admin_providers = admin_response.json()
    admin_provider_names = [p["name"] for p in admin_providers]
    assert public_provider.name in admin_provider_names
    assert non_public_provider.name in admin_provider_names
def test_provider_delete_clears_persona_references(reset: None) -> None:
"""Test that deleting a provider automatically clears persona references."""
admin_user = UserManager.create(name="admin_user")

View File

@@ -0,0 +1,65 @@
from uuid import uuid4
import requests
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.persona import PersonaLabelManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.test_models import DATestPersonaLabel
from tests.integration.common_utils.test_models import DATestUser
def test_update_persona_with_null_label_ids_preserves_labels(
    reset: None, admin_user: DATestUser
) -> None:
    """PATCHing a persona with label_ids=None must leave its existing labels intact."""
    persona_label = PersonaLabelManager.create(
        label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
        user_performing_action=admin_user,
    )
    assert persona_label.id is not None
    # Persona starts out tagged with the freshly created label.
    persona = PersonaManager.create(
        label_ids=[persona_label.id],
        user_performing_action=admin_user,
    )
    updated_description = f"{persona.description}-updated"
    # Build a full upsert payload mirroring the current persona, but with
    # label_ids explicitly None (serialized as JSON null via exclude_none=False).
    update_request = PersonaUpsertRequest(
        name=persona.name,
        description=updated_description,
        system_prompt=persona.system_prompt or "",
        task_prompt=persona.task_prompt or "",
        datetime_aware=persona.datetime_aware,
        document_set_ids=persona.document_set_ids,
        num_chunks=persona.num_chunks,
        is_public=persona.is_public,
        recency_bias=persona.recency_bias,
        llm_filter_extraction=persona.llm_filter_extraction,
        llm_relevance_filter=persona.llm_relevance_filter,
        llm_model_provider_override=persona.llm_model_provider_override,
        llm_model_version_override=persona.llm_model_version_override,
        tool_ids=persona.tool_ids,
        users=[],
        groups=[],
        label_ids=None,
    )
    response = requests.patch(
        f"{API_SERVER_URL}/persona/{persona.id}",
        json=update_request.model_dump(mode="json", exclude_none=False),
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    response.raise_for_status()
    fetched = requests.get(
        f"{API_SERVER_URL}/persona/{persona.id}",
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    fetched.raise_for_status()
    fetched_persona = fetched.json()
    # The description change was applied...
    assert fetched_persona["description"] == updated_description
    # ...and the original label association survived the null label_ids.
    fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
    assert persona_label.id in fetched_label_ids

View File

@@ -0,0 +1,50 @@
"""Tests for Asana connector configuration parsing."""
import pytest
from onyx.connectors.asana.connector import AsanaConnector
@pytest.mark.parametrize(
    "project_ids,expected",
    [
        (None, None),
        ("", None),
        (" ", None),
        (" 123 ", ["123"]),
        (" 123 , , 456 , ", ["123", "456"]),
    ],
)
def test_asana_connector_project_ids_normalization(
    project_ids: str | None, expected: list[str] | None
) -> None:
    """Comma-separated project IDs are trimmed, empty entries dropped, blanks -> None."""
    connector = AsanaConnector(
        asana_workspace_id=" 1153293530468850 ",
        asana_project_ids=project_ids,
        asana_team_id=" 1210918501948021 ",
    )
    # Workspace and team IDs are also whitespace-trimmed on construction.
    assert connector.workspace_id == "1153293530468850"
    assert connector.project_ids_to_index == expected
    assert connector.asana_team_id == "1210918501948021"
@pytest.mark.parametrize(
    "team_id,expected",
    [
        (None, None),
        ("", None),
        (" ", None),
        (" 1210918501948021 ", "1210918501948021"),
    ],
)
def test_asana_connector_team_id_normalization(
    team_id: str | None, expected: str | None
) -> None:
    """Team ID is whitespace-trimmed; empty/blank values normalize to None."""
    connector = AsanaConnector(
        asana_workspace_id="1153293530468850",
        asana_project_ids=None,
        asana_team_id=team_id,
    )
    assert connector.asana_team_id == expected

View File

@@ -0,0 +1,506 @@
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
class TestConvertToMetadataValue:
    """Tests for the _convert_to_metadata_value helper function."""

    def test_string_value(self) -> None:
        """String values should be returned as-is."""
        assert _convert_to_metadata_value("hello") == "hello"
        assert _convert_to_metadata_value("") == ""

    def test_boolean_true(self) -> None:
        """Boolean True should be converted to string 'True'."""
        assert _convert_to_metadata_value(True) == "True"

    def test_boolean_false(self) -> None:
        """Boolean False should be converted to string 'False'."""
        assert _convert_to_metadata_value(False) == "False"

    def test_integer_value(self) -> None:
        """Integer values should be converted to string."""
        assert _convert_to_metadata_value(42) == "42"
        assert _convert_to_metadata_value(0) == "0"
        assert _convert_to_metadata_value(-100) == "-100"

    def test_float_value(self) -> None:
        """Float values should be converted to string."""
        assert _convert_to_metadata_value(3.14) == "3.14"
        assert _convert_to_metadata_value(0.0) == "0.0"
        assert _convert_to_metadata_value(-2.5) == "-2.5"

    def test_list_of_strings(self) -> None:
        """List of strings should remain as list of strings."""
        result = _convert_to_metadata_value(["a", "b", "c"])
        assert result == ["a", "b", "c"]

    def test_list_of_mixed_types(self) -> None:
        """List with mixed types should have all items converted to strings."""
        result = _convert_to_metadata_value([1, True, 3.14, "text"])
        assert result == ["1", "True", "3.14", "text"]

    def test_empty_list(self) -> None:
        """Empty list should return empty list."""
        assert _convert_to_metadata_value([]) == []
class TestYieldDocBatches:
    """Tests for the _yield_doc_batches method of SalesforceConnector."""

    @pytest.fixture
    def connector(self) -> SalesforceConnector:
        """Create a SalesforceConnector instance with mocked sf_client."""
        connector = SalesforceConnector(
            batch_size=10,
            requested_objects=["Opportunity"],
        )
        # Mock the sf_client property
        mock_sf_client = MagicMock()
        mock_sf_client.sf_instance = "test.salesforce.com"
        connector._sf_client = mock_sf_client
        return connector

    @pytest.fixture
    def mock_sf_db(self) -> MagicMock:
        """Create a mock OnyxSalesforceSQLite object."""
        return MagicMock()

    def _create_salesforce_object(
        self,
        object_id: str,
        object_type: str,
        data: dict[str, Any],
    ) -> SalesforceObject:
        """Helper to create a SalesforceObject with required fields."""
        # Ensure required fields are present
        data.setdefault(ID_FIELD, object_id)
        data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
        data.setdefault(NAME_FIELD, f"Test {object_type}")
        return SalesforceObject(id=object_id, type=object_type, data=data)

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_metadata_type_conversion_for_opportunity(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that Opportunity metadata fields are properly type-converted."""
        parent_id = "006bm000006kyDpAAI"
        parent_type = "Opportunity"
        # Create a parent object with various data types in the fields
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "Test Opportunity",
            MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
            "Account": "Acme Corp",  # string - should become "account" metadata
            "FiscalQuarter": 2,  # int - should be converted to "2"
            "FiscalYear": 2024,  # int - should be converted to "2024"
            "IsClosed": False,  # bool - should be converted to "False"
            "StageName": "Prospecting",  # string
            "Type": "New Business",  # string
            "Amount": 50000.50,  # float - should be converted to "50000.50"
            "CloseDate": "2024-06-30",  # string
            "Probability": 75,  # int - should be converted to "75"
            "CreatedDate": "2024-01-01T00:00:00.000Z",  # string
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        # Setup mock sf_db
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        # Create a mock document that convert_sf_object_to_doc will return
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="Test Opportunity",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        # Track parent changes
        parents_changed = 0

        def increment() -> None:
            nonlocal parents_changed
            parents_changed += 1

        # Call _yield_doc_batches
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                increment,
            )
        )
        # Verify we got one batch with one document
        assert len(batches) == 1
        docs = batches[0]
        assert len(docs) == 1
        doc = docs[0]
        assert isinstance(doc, Document)
        # Verify metadata type conversions
        # All values should be strings (or list of strings)
        assert doc.metadata["object_type"] == "Opportunity"
        assert doc.metadata["account"] == "Acme Corp"  # string stays string
        assert doc.metadata["fiscal_quarter"] == "2"  # int -> str
        assert doc.metadata["fiscal_year"] == "2024"  # int -> str
        assert doc.metadata["is_closed"] == "False"  # bool -> str
        assert doc.metadata["stage_name"] == "Prospecting"  # string stays string
        assert doc.metadata["type"] == "New Business"  # string stays string
        assert (
            doc.metadata["amount"] == "50000.5"
        )  # float -> str (Python drops trailing zeros)
        assert doc.metadata["close_date"] == "2024-06-30"  # string stays string
        assert doc.metadata["probability"] == "75"  # int -> str
        assert doc.metadata["name"] == "Test Opportunity"  # NAME_FIELD
        # Verify parent was counted
        assert parents_changed == 1
        assert type_to_processed[parent_type] == 1

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_missing_optional_metadata_fields(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that missing optional metadata fields are not added."""
        parent_id = "006bm000006kyDqAAI"
        parent_type = "Opportunity"
        # Create parent object with only some fields
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "Minimal Opportunity",
            MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
            "StageName": "Closed Won",
            # Notably missing: Amount, Probability, FiscalQuarter, etc.
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="Minimal Opportunity",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        doc = batches[0][0]
        assert isinstance(doc, Document)
        # Only present fields should be in metadata
        assert "stage_name" in doc.metadata
        assert doc.metadata["stage_name"] == "Closed Won"
        assert "name" in doc.metadata
        assert doc.metadata["name"] == "Minimal Opportunity"
        # Missing fields should not be in metadata
        assert "amount" not in doc.metadata
        assert "probability" not in doc.metadata
        assert "fiscal_quarter" not in doc.metadata
        assert "fiscal_year" not in doc.metadata
        assert "is_closed" not in doc.metadata

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_contact_metadata_fields(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test metadata conversion for Contact object type."""
        parent_id = "003bm00000EjHCjAAN"
        parent_type = "Contact"
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "John Doe",
            MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
            "Account": "Globex Corp",
            "CreatedDate": "2024-01-01T00:00:00.000Z",
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="John Doe",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        doc = batches[0][0]
        assert isinstance(doc, Document)
        # Verify Contact-specific metadata
        assert doc.metadata["object_type"] == "Contact"
        assert doc.metadata["account"] == "Globex Corp"
        assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
        assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_no_default_attributes_for_unknown_type(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that unknown object types only get object_type metadata."""
        parent_id = "001bm00000fd9Z3AAI"
        parent_type = "CustomObject__c"
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "Custom Record",
            MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
            "CustomField__c": "custom value",
            "NumberField__c": 123,
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="Custom Record",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        doc = batches[0][0]
        assert isinstance(doc, Document)
        # Only object_type should be set for unknown types
        assert doc.metadata["object_type"] == "CustomObject__c"
        # Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
        assert "CustomField__c" not in doc.metadata
        assert "NumberField__c" not in doc.metadata

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_skips_missing_parent_objects(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that missing parent objects are skipped gracefully."""
        parent_id = "006bm000006kyDrAAI"
        parent_type = "Opportunity"
        # get_record returns None for missing object
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = None
        mock_sf_db.file_size = 1024
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        parents_changed = 0

        def increment() -> None:
            nonlocal parents_changed
            parents_changed += 1

        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                increment,
            )
        )
        # Should yield one empty batch
        assert len(batches) == 1
        assert len(batches[0]) == 0
        # convert_sf_object_to_doc should not have been called
        mock_convert.assert_not_called()
        # Parents changed should still be 0
        assert parents_changed == 0

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_multiple_documents_batching(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that multiple documents are correctly batched."""
        # Create 3 parent objects
        parent_ids = [
            "006bm000006kyDsAAI",
            "006bm000006kyDtAAI",
            "006bm000006kyDuAAI",
        ]
        parent_type = "Opportunity"
        parent_objects = [
            self._create_salesforce_object(
                pid,
                parent_type,
                {
                    ID_FIELD: pid,
                    NAME_FIELD: f"Opportunity {i}",
                    MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
                    "IsClosed": i % 2 == 0,  # alternating bool values
                    "Amount": 1000.0 * (i + 1),
                },
            )
            for i, pid in enumerate(parent_ids)
        ]
        # Setup mock to return all three
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
        )
        mock_sf_db.get_record.side_effect = parent_objects
        mock_sf_db.file_size = 1024
        # Create mock documents
        mock_docs = [
            Document(
                id=f"SALESFORCE_{pid}",
                sections=[],
                source=DocumentSource.SALESFORCE,
                semantic_identifier=f"Opportunity {i}",
                metadata={},
            )
            for i, pid in enumerate(parent_ids)
        ]
        mock_convert.side_effect = mock_docs
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {pid: parent_type for pid in parent_ids}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        # With batch_size=10, all 3 docs should be in one batch
        assert len(batches) == 1
        assert len(batches[0]) == 3
        # Verify each document has correct metadata
        for i, doc in enumerate(batches[0]):
            assert isinstance(doc, Document)
            assert doc.metadata["object_type"] == "Opportunity"
            assert doc.metadata["is_closed"] == str(i % 2 == 0)
            assert doc.metadata["amount"] == str(1000.0 * (i + 1))
        assert type_to_processed[parent_type] == 3

View File

@@ -0,0 +1,135 @@
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User__UserGroup
from onyx.db.users import delete_user_from_db
def _mock_user(
user_id: UUID | None = None, email: str = "test@example.com"
) -> MagicMock:
user = MagicMock()
user.id = user_id or uuid4()
user.email = email
user.oauth_accounts = []
return user
def _make_query_chain() -> MagicMock:
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
chain = MagicMock()
chain.filter.return_value = chain
return chain
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_nulls_out_document_set_ownership(
    _mock_ee: Any, _mock_remove_invited: Any
) -> None:
    """Owned DocumentSets/Personas must be orphaned (user_id -> None), not deleted."""
    user = _mock_user()
    db_session = MagicMock()
    # One query-chain mock per ORM model, created lazily on first query(model).
    query_chains: dict[type, MagicMock] = {}

    def query_side_effect(model: type) -> MagicMock:
        if model not in query_chains:
            query_chains[model] = _make_query_chain()
        return query_chains[model]

    db_session.query.side_effect = query_side_effect
    delete_user_from_db(user, db_session)
    # Verify DocumentSet.user_id is nulled out (update, not delete)
    doc_set_chain = query_chains[DocumentSet]
    doc_set_chain.filter.assert_called()
    doc_set_chain.filter.return_value.update.assert_called_once_with(
        {DocumentSet.user_id: None}
    )
    # Verify Persona.user_id is nulled out (update, not delete)
    persona_chain = query_chains[Persona]
    persona_chain.filter.assert_called()
    persona_chain.filter.return_value.update.assert_called_once_with(
        {Persona.user_id: None}
    )
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_cleans_up_join_tables(
    _mock_ee: Any, _mock_remove_invited: Any
) -> None:
    """Rows in association/join tables are hard-deleted when the user is removed."""
    user = _mock_user()
    db_session = MagicMock()
    # One query-chain mock per ORM model, created lazily on first query(model).
    query_chains: dict[type, MagicMock] = {}

    def query_side_effect(model: type) -> MagicMock:
        if model not in query_chains:
            query_chains[model] = _make_query_chain()
        return query_chains[model]

    db_session.query.side_effect = query_side_effect
    delete_user_from_db(user, db_session)
    # Join tables should be deleted (not updated)
    for model in [DocumentSet__User, Persona__User, User__UserGroup, SamlAccount]:
        chain = query_chains[model]
        chain.filter.return_value.delete.assert_called_once()
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_commits_and_removes_invited(
    _mock_ee: Any, mock_remove_invited: Any
) -> None:
    """Deletion removes the ORM row, commits, and drops the invited-users entry.

    Note: patch decorators apply bottom-up, so the first parameter receives the
    fetch_ee_implementation_or_noop mock and the second the
    remove_user_from_invited_users mock.
    """
    user = _mock_user(email="deleted@example.com")
    db_session = MagicMock()
    db_session.query.return_value = _make_query_chain()
    delete_user_from_db(user, db_session)
    db_session.delete.assert_called_once_with(user)
    db_session.commit.assert_called_once()
    mock_remove_invited.assert_called_once_with("deleted@example.com")
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_deletes_oauth_accounts(
    _mock_ee: Any, _mock_remove_invited: Any
) -> None:
    """Every linked OAuth account object is deleted alongside the user."""
    user = _mock_user()
    oauth1 = MagicMock()
    oauth2 = MagicMock()
    user.oauth_accounts = [oauth1, oauth2]
    db_session = MagicMock()
    db_session.query.return_value = _make_query_chain()
    delete_user_from_db(user, db_session)
    db_session.delete.assert_any_call(oauth1)
    db_session.delete.assert_any_call(oauth2)

View File

@@ -0,0 +1,205 @@
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
from onyx.onyxbot.slack.formatting import _sanitize_html
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.utils.text_processing import decode_escapes
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
    """A citation URL containing parentheses gets wrapped in angle brackets."""
    message = (
        "See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
    )
    normalized = _normalize_link_destinations(message)
    assert (
        "See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
        == normalized
    )
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
    """URLs already wrapped in <...> are left unchanged (no double-wrapping)."""
    message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
    normalized = _normalize_link_destinations(message)
    assert message == normalized
def test_normalize_citation_link_handles_multiple_links() -> None:
    """Each parenthesized citation URL in a message is wrapped independently."""
    message = (
        "[[1]](https://example.com/(USA)%20Guide.pdf) "
        "[[2]](https://example.com/Plan(s)%20Overview.pdf)"
    )
    normalized = _normalize_link_destinations(message)
    assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
    assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
    """End-to-end: a parenthesized URL survives formatting as one Slack link."""
    message = (
        "Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
    )
    formatted = format_slack_message(message)
    # Render the way Slack would after unescaping/interaction stripping.
    rendered = decode_escapes(remove_slack_text_interactions(formatted))
    assert (
        "<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
        in rendered
    )
    # The URL must not be split at the inner ")" leaving a dangling tail.
    assert "|[1]>%20Access%20ID%20Card" not in rendered
def test_slack_style_links_converted_to_clickable_links() -> None:
    """Slack-style <url|label> links pass through without HTML escaping."""
    result = format_slack_message(
        "Visit <https://example.com/page|Example Page> for details."
    )
    assert "<https://example.com/page|Example Page>" in result
    assert "&lt;" not in result
def test_slack_style_links_preserved_inside_code_blocks() -> None:
    """Link conversion leaves the contents of fenced code blocks alone."""
    result = _convert_slack_links_to_markdown(
        "```\n<https://example.com|click>\n```"
    )
    assert "<https://example.com|click>" in result
def test_html_tags_stripped_outside_code_blocks() -> None:
    """HTML is sanitized outside code fences but preserved inside them."""
    raw = "Hello<br/>world ```<div>code</div>``` after"
    result = _transform_outside_code_blocks(raw, _sanitize_html)
    # The <br/> outside the fence is removed; the <div> inside survives.
    assert "<br" not in result
    assert "<div>code</div>" in result
def test_format_slack_message_block_spacing() -> None:
    """A single blank line between paragraphs is preserved verbatim."""
    result = format_slack_message("Paragraph one.\n\nParagraph two.")
    assert result == "Paragraph one.\n\nParagraph two."
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
    """Formatting must not append a blank line after a closing code fence."""
    result = format_slack_message("```python\nprint('hi')\n```")
    assert result.endswith("print('hi')\n```")
def test_format_slack_message_ampersand_not_double_escaped() -> None:
    """Ampersands are escaped once for Slack; double quotes are not escaped."""
    result = format_slack_message('She said "hello" & goodbye.')
    assert "&amp;" in result
    assert "&quot;" not in result
# -- Table rendering tests --
def test_table_renders_as_vertical_cards() -> None:
    """A markdown table becomes per-row 'cards' keyed by the first column."""
    table = (
        "| Feature | Status | Owner |\n"
        "|---------|--------|-------|\n"
        "| Auth | Done | Alice |\n"
        "| Search | In Progress | Bob |\n"
    )
    result = format_slack_message(table)
    # Each row: bolded first-column value, then bulleted remaining columns.
    assert "*Auth*\n • Status: Done\n • Owner: Alice" in result
    assert "*Search*\n • Status: In Progress\n • Owner: Bob" in result
    # Cards separated by blank line
    assert "Owner: Alice\n\n*Search*" in result
    # No raw pipe-and-dash table syntax
    assert "---|" not in result
def test_table_single_column() -> None:
    """A one-column table still renders each row as a bolded card title."""
    result = format_slack_message("| Name |\n|------|\n| Alice |\n| Bob |\n")
    for title in ("*Alice*", "*Bob*"):
        assert title in result
def test_table_embedded_in_text() -> None:
    """Prose before and after a table survives; the table itself is carded."""
    raw = (
        "Here are the results:\n\n"
        "| Item | Count |\n"
        "|------|-------|\n"
        "| Apples | 5 |\n"
        "\n"
        "That's all."
    )
    result = format_slack_message(raw)
    for fragment in (
        "Here are the results:",
        "*Apples*\n • Count: 5",
        "That's all.",
    ):
        assert fragment in result
def test_table_with_formatted_cells() -> None:
    """Markdown inside cells is converted without double-wrapping bold text."""
    raw = (
        "| Name | Link |\n"
        "|------|------|\n"
        "| **Alice** | [profile](https://example.com) |\n"
    )
    result = format_slack_message(raw)
    # Bold cell should not double-wrap: *Alice* not **Alice**
    assert "*Alice*" in result
    assert "**Alice**" not in result
    # Markdown link in a cell becomes a Slack-style link.
    assert "<https://example.com|profile>" in result
def test_table_with_alignment_specifiers() -> None:
    """Alignment colons in the separator row are treated as plain separators."""
    raw = (
        "| Left | Center | Right |\n"
        "|:-----|:------:|------:|\n"
        "| a | b | c |\n"
    )
    result = format_slack_message(raw)
    assert "*a*\n • Center: b\n • Right: c" in result
def test_two_tables_in_same_message_use_independent_headers() -> None:
    """Each table in a message uses its own header row for card labels."""
    raw = (
        "| A | B |\n"
        "|---|---|\n"
        "| 1 | 2 |\n"
        "\n"
        "| X | Y | Z |\n"
        "|---|---|---|\n"
        "| p | q | r |\n"
    )
    result = format_slack_message(raw)
    # First table's labels come from "A | B", second's from "X | Y | Z".
    assert "*1*\n • B: 2" in result
    assert "*p*\n • Y: q\n • Z: r" in result
def test_table_empty_first_column_no_bare_asterisks() -> None:
    """An empty title cell must not emit bare '**' in the rendered output."""
    result = format_slack_message(
        "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
    )
    # Empty title should not produce "**" (bare asterisks)
    assert "**" not in result
    assert " • Status: Done" in result

View File

@@ -0,0 +1,57 @@
from typing import Any
from unittest.mock import Mock
from onyx.configs.constants import MilestoneRecordType
from onyx.utils import telemetry as telemetry_utils
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
    """mt_cloud_telemetry is a no-op when the deployment is not multi-tenant."""
    versioned_fetch = Mock()
    monkeypatch.setattr(
        telemetry_utils,
        "fetch_versioned_implementation_with_fallback",
        versioned_fetch,
    )
    # mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
    monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
    telemetry_utils.mt_cloud_telemetry(
        tenant_id="tenant-1",
        distinct_id="user@example.com",
        event=MilestoneRecordType.USER_MESSAGE_SENT,
        properties={"origin": "web"},
    )
    # The versioned implementation must never even be looked up.
    versioned_fetch.assert_not_called()
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
    monkeypatch: Any,
) -> None:
    """In multi-tenant mode the versioned event_telemetry implementation is
    fetched and invoked with tenant_id merged into the event properties."""
    captured_event_telemetry = Mock()
    versioned_fetch = Mock(return_value=captured_event_telemetry)
    monkeypatch.setattr(
        telemetry_utils,
        "fetch_versioned_implementation_with_fallback",
        versioned_fetch,
    )
    # mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
    monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
    telemetry_utils.mt_cloud_telemetry(
        tenant_id="tenant-1",
        distinct_id="user@example.com",
        event=MilestoneRecordType.USER_MESSAGE_SENT,
        properties={"origin": "web"},
    )
    # The implementation is resolved through the versioned-implementation helper.
    versioned_fetch.assert_called_once_with(
        module="onyx.utils.telemetry",
        attribute="event_telemetry",
        fallback=telemetry_utils.noop_fallback,
    )
    # tenant_id is folded into the properties passed to event_telemetry.
    captured_event_telemetry.assert_called_once_with(
        "user@example.com",
        MilestoneRecordType.USER_MESSAGE_SENT,
        {"origin": "web", "tenant_id": "tenant-1"},
    )

View File

@@ -582,29 +582,33 @@ else
fi
# Ask for authentication schema
echo ""
print_info "Which authentication schema would you like to set up?"
echo ""
echo "1) Basic - Username/password authentication"
echo "2) No Auth - Open access (development/testing)"
echo ""
read -p "Choose an option (1-2) [default 1]: " -r AUTH_CHOICE
echo ""
# echo ""
# print_info "Which authentication schema would you like to set up?"
# echo ""
# echo "1) Basic - Username/password authentication"
# echo "2) No Auth - Open access (development/testing)"
# echo ""
# read -p "Choose an option (1) [default 1]: " -r AUTH_CHOICE
# echo ""
case "${AUTH_CHOICE:-1}" in
1)
AUTH_SCHEMA="basic"
print_info "Selected: Basic authentication"
;;
2)
AUTH_SCHEMA="disabled"
print_info "Selected: No authentication"
;;
*)
AUTH_SCHEMA="basic"
print_info "Invalid choice, using basic authentication"
;;
esac
# case "${AUTH_CHOICE:-1}" in
# 1)
# AUTH_SCHEMA="basic"
# print_info "Selected: Basic authentication"
# ;;
# # 2)
# # AUTH_SCHEMA="disabled"
# # print_info "Selected: No authentication"
# # ;;
# *)
# AUTH_SCHEMA="basic"
# print_info "Invalid choice, using basic authentication"
# ;;
# esac
# TODO (jessica): Uncomment this once no auth users still have an account
# Use basic auth by default
AUTH_SCHEMA="basic"
# Create .env file from template
print_info "Creating .env file with your selections..."

3
desktop/.gitignore vendored
View File

@@ -22,3 +22,6 @@ npm-debug.log*
# Local env files
.env
.env.local
# Generated files
src-tauri/gen/schemas/acl-manifests.json

View File

@@ -706,16 +706,6 @@ dependencies = [
"typeid",
]
[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
]
[[package]]
name = "fdeflate"
version = "0.3.7"
@@ -993,16 +983,6 @@ dependencies = [
"version_check",
]
[[package]]
name = "gethostname"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
dependencies = [
"rustix",
"windows-link 0.2.1",
]
[[package]]
name = "getrandom"
version = "0.1.16"
@@ -1122,24 +1102,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]]
name = "global-hotkey"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
dependencies = [
"crossbeam-channel",
"keyboard-types",
"objc2 0.6.3",
"objc2-app-kit 0.3.2",
"once_cell",
"serde",
"thiserror 2.0.17",
"windows-sys 0.59.0",
"x11rb",
"xkeysym",
]
[[package]]
name = "gobject-sys"
version = "0.18.0"
@@ -1713,12 +1675,6 @@ dependencies = [
"libc",
]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "litemap"
version = "0.8.1"
@@ -2248,7 +2204,6 @@ dependencies = [
"serde_json",
"tauri",
"tauri-build",
"tauri-plugin-global-shortcut",
"tauri-plugin-shell",
"tauri-plugin-window-state",
"tokio",
@@ -2878,19 +2833,6 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustversion"
version = "1.0.22"
@@ -3605,21 +3547,6 @@ dependencies = [
"walkdir",
]
[[package]]
name = "tauri-plugin-global-shortcut"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
dependencies = [
"global-hotkey",
"log",
"serde",
"serde_json",
"tauri",
"tauri-plugin",
"thiserror 2.0.17",
]
[[package]]
name = "tauri-plugin-shell"
version = "2.3.3"
@@ -5021,29 +4948,6 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "x11rb"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
dependencies = [
"gethostname",
"rustix",
"x11rb-protocol",
]
[[package]]
name = "x11rb-protocol"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
[[package]]
name = "xkeysym"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
[[package]]
name = "yoke"
version = "0.8.1"

View File

@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
[dependencies]
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
tauri-plugin-shell = "2.0"
tauri-plugin-global-shortcut = "2.0"
tauri-plugin-window-state = "2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

File diff suppressed because one or more lines are too long

View File

@@ -2354,72 +2354,6 @@
"const": "core:window:deny-unminimize",
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
},
{
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
"type": "string",
"const": "global-shortcut:default",
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
},
{
"description": "Enables the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-is-registered",
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
},
{
"description": "Enables the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register",
"markdownDescription": "Enables the register command without any pre-configured scope."
},
{
"description": "Enables the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register-all",
"markdownDescription": "Enables the register_all command without any pre-configured scope."
},
{
"description": "Enables the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister",
"markdownDescription": "Enables the unregister command without any pre-configured scope."
},
{
"description": "Enables the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister-all",
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
},
{
"description": "Denies the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-is-registered",
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
},
{
"description": "Denies the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register",
"markdownDescription": "Denies the register command without any pre-configured scope."
},
{
"description": "Denies the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register-all",
"markdownDescription": "Denies the register_all command without any pre-configured scope."
},
{
"description": "Denies the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister",
"markdownDescription": "Denies the unregister command without any pre-configured scope."
},
{
"description": "Denies the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister-all",
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
},
{
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
"type": "string",

View File

@@ -2354,72 +2354,6 @@
"const": "core:window:deny-unminimize",
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
},
{
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
"type": "string",
"const": "global-shortcut:default",
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
},
{
"description": "Enables the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-is-registered",
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
},
{
"description": "Enables the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register",
"markdownDescription": "Enables the register command without any pre-configured scope."
},
{
"description": "Enables the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register-all",
"markdownDescription": "Enables the register_all command without any pre-configured scope."
},
{
"description": "Enables the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister",
"markdownDescription": "Enables the unregister command without any pre-configured scope."
},
{
"description": "Enables the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister-all",
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
},
{
"description": "Denies the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-is-registered",
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
},
{
"description": "Denies the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register",
"markdownDescription": "Denies the register command without any pre-configured scope."
},
{
"description": "Denies the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register-all",
"markdownDescription": "Denies the register_all command without any pre-configured scope."
},
{
"description": "Denies the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister",
"markdownDescription": "Denies the unregister command without any pre-configured scope."
},
{
"description": "Denies the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister-all",
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
},
{
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
"type": "string",

View File

@@ -20,7 +20,6 @@ use tauri::Wry;
use tauri::{
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
};
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
use url::Url;
#[cfg(target_os = "macos")]
use tokio::time::sleep;
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
window.start_dragging().map_err(|e| e.to_string())
}
// ============================================================================
// Shortcuts Setup
// ============================================================================
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
let app_handle = app.clone();
// Avoid hijacking the system-wide Cmd+R on macOS.
#[cfg(target_os = "macos")]
let shortcuts = [
new_chat,
back,
forward,
new_window_shortcut,
show_app,
open_settings_shortcut,
];
#[cfg(not(target_os = "macos"))]
let shortcuts = [
new_chat,
reload,
back,
forward,
new_window_shortcut,
show_app,
open_settings_shortcut,
];
app.global_shortcut().on_shortcuts(
shortcuts,
move |_app, shortcut, _event| {
if shortcut == &new_chat {
trigger_new_chat(&app_handle);
}
if let Some(window) = app_handle.get_webview_window("main") {
if shortcut == &reload {
let _ = window.eval("window.location.reload()");
} else if shortcut == &back {
let _ = window.eval("window.history.back()");
} else if shortcut == &forward {
let _ = window.eval("window.history.forward()");
} else if shortcut == &open_settings_shortcut {
open_settings(&app_handle);
}
}
if shortcut == &new_window_shortcut {
trigger_new_window(&app_handle);
} else if shortcut == &show_app {
focus_main_window(&app_handle);
}
},
)?;
Ok(())
}
// ============================================================================
// Menu Setup
// ============================================================================
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
TRAY_MENU_OPEN_APP_ID,
"Open Onyx",
true,
Some("CmdOrCtrl+Shift+Space"),
None::<&str>,
)?;
let open_chat = MenuItem::with_id(
app,
@@ -666,7 +598,6 @@ fn main() {
tauri::Builder::default()
.plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
.plugin(tauri_plugin_window_state::Builder::default().build())
.manage(ConfigState {
config: RwLock::new(config),
@@ -698,11 +629,6 @@ fn main() {
.setup(move |app| {
let app_handle = app.handle();
// Setup global shortcuts
if let Err(e) = setup_shortcuts(&app_handle) {
eprintln!("Failed to setup shortcuts: {}", e);
}
if let Err(e) = setup_app_menu(&app_handle) {
eprintln!("Failed to setup menu: {}", e);
}

View File

@@ -22,6 +22,17 @@
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
}
.dark {
--background-900: #1a1a1a;
--background-800: #262626;
--text-light-05: rgba(255, 255, 255, 0.95);
--text-light-03: rgba(255, 255, 255, 0.6);
--white-10: rgba(255, 255, 255, 0.08);
--white-15: rgba(255, 255, 255, 0.12);
--white-20: rgba(255, 255, 255, 0.15);
--white-30: rgba(255, 255, 255, 0.25);
}
* {
box-sizing: border-box;
margin: 0;
@@ -30,7 +41,11 @@
body {
font-family: var(--font-hanken-grotesk);
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
background: linear-gradient(
135deg,
var(--background-900) 0%,
var(--background-800) 100%
);
min-height: 100vh;
color: var(--text-light-05);
display: flex;
@@ -39,6 +54,9 @@
padding: 20px;
-webkit-user-select: none;
user-select: none;
transition:
background 0.3s ease,
color 0.3s ease;
}
.titlebar {
@@ -69,16 +87,19 @@
}
.settings-panel {
background: linear-gradient(
to bottom,
rgba(255, 255, 255, 0.95),
rgba(245, 245, 245, 0.95)
);
background: var(--background-800);
backdrop-filter: blur(24px);
border-radius: 16px;
border: 1px solid var(--white-10);
overflow: hidden;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
transition:
background 0.3s ease,
border 0.3s ease;
}
.dark .settings-panel {
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
}
.settings-header {
@@ -93,17 +114,19 @@
width: 40px;
height: 40px;
border-radius: 12px;
background: white;
background: var(--background-900);
display: flex;
align-items: center;
justify-content: center;
overflow: hidden;
transition: background 0.3s ease;
}
.settings-icon svg {
width: 24px;
height: 24px;
color: #000;
color: var(--text-light-05);
transition: color 0.3s ease;
}
.settings-title {
@@ -134,9 +157,10 @@
}
.settings-group {
background: rgba(0, 0, 0, 0.03);
background: var(--background-900);
border-radius: 16px;
padding: 4px;
transition: background 0.3s ease;
}
.setting-row {
@@ -176,7 +200,7 @@
border: 1px solid var(--white-10);
border-radius: 8px;
font-size: 14px;
background: rgba(0, 0, 0, 0.05);
background: var(--background-800);
color: var(--text-light-05);
font-family: var(--font-hanken-grotesk);
transition: all 0.2s;
@@ -186,8 +210,8 @@
.input-field:focus {
outline: none;
border-color: var(--white-30);
background: rgba(0, 0, 0, 0.08);
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
background: var(--background-900);
box-shadow: 0 0 0 2px var(--white-10);
}
.input-field::placeholder {
@@ -231,7 +255,7 @@
left: 0;
right: 0;
bottom: 0;
background-color: rgba(0, 0, 0, 0.15);
background-color: var(--white-15);
transition: 0.3s;
border-radius: 24px;
}
@@ -243,14 +267,18 @@
width: 18px;
left: 3px;
bottom: 3px;
background-color: white;
background-color: var(--background-800);
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
transition: 0.3s;
border-radius: 50%;
}
.dark .toggle-slider:before {
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
}
input:checked + .toggle-slider {
background-color: rgba(0, 0, 0, 0.3);
background-color: var(--white-30);
}
input:checked + .toggle-slider:before {
@@ -288,14 +316,15 @@
}
kbd {
background: rgba(0, 0, 0, 0.1);
border: 1px solid var(--white-10);
background: var(--white-10);
border: 1px solid var(--white-15);
border-radius: 4px;
padding: 2px 6px;
font-family: monospace;
font-weight: 500;
color: var(--text-light-05);
font-size: 11px;
transition: all 0.3s ease;
}
</style>
</head>
@@ -372,10 +401,34 @@
const errorMessage = document.getElementById("errorMessage");
const saveBtn = document.getElementById("saveBtn");
// Theme detection based on system preferences
function applySystemTheme() {
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
function updateTheme(e) {
if (e.matches) {
document.documentElement.classList.add("dark");
document.body.classList.add("dark");
} else {
document.documentElement.classList.remove("dark");
document.body.classList.remove("dark");
}
}
// Apply initial theme
updateTheme(darkModeQuery);
// Listen for changes
darkModeQuery.addEventListener("change", updateTheme);
}
function showSettings() {
document.body.classList.add("show-settings");
}
// Apply system theme immediately
applySystemTheme();
// Initialize the app
async function init() {
try {

View File

@@ -113,6 +113,23 @@
document.head.appendChild(style);
}
function updateTitleBarTheme(isDark) {
const titleBar = document.getElementById(TITLEBAR_ID);
if (!titleBar) return;
if (isDark) {
titleBar.style.background =
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
} else {
titleBar.style.background =
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
}
}
function buildTitleBar() {
const titleBar = document.createElement("div");
titleBar.id = TITLEBAR_ID;
@@ -134,6 +151,11 @@
}
});
// Apply initial styles matching current theme
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
// Apply styles matching Onyx design system with translucent glass effect
titleBar.style.cssText = `
position: fixed;
@@ -156,8 +178,12 @@
-webkit-backdrop-filter: blur(18px) saturate(180%);
-webkit-app-region: drag;
padding: 0 12px;
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
`;
// Apply correct theme
updateTitleBarTheme(isDark);
return titleBar;
}
@@ -168,6 +194,11 @@
const existing = document.getElementById(TITLEBAR_ID);
if (existing?.parentElement === document.body) {
// Update theme on existing titlebar
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
updateTitleBarTheme(isDark);
return;
}
@@ -178,6 +209,14 @@
const titleBar = buildTitleBar();
document.body.insertBefore(titleBar, document.body.firstChild);
injectStyles();
// Ensure theme is applied immediately after mount
setTimeout(() => {
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
updateTitleBarTheme(isDark);
}, 0);
}
function syncViewportHeight() {
@@ -194,9 +233,66 @@
}
}
function observeThemeChanges() {
let lastKnownTheme = null;
function checkAndUpdateTheme() {
// Check both html and body for dark class (some apps use body)
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
if (lastKnownTheme !== isDark) {
lastKnownTheme = isDark;
updateTitleBarTheme(isDark);
}
}
// Immediate check on setup
checkAndUpdateTheme();
// Watch for theme changes on the HTML element
const themeObserver = new MutationObserver(() => {
checkAndUpdateTheme();
});
themeObserver.observe(document.documentElement, {
attributes: true,
attributeFilter: ["class"],
});
// Also observe body if it exists
if (document.body) {
const bodyObserver = new MutationObserver(() => {
checkAndUpdateTheme();
});
bodyObserver.observe(document.body, {
attributes: true,
attributeFilter: ["class"],
});
}
// Also check periodically in case classList is manipulated directly
// or the theme loads asynchronously after page load
const intervalId = setInterval(() => {
checkAndUpdateTheme();
}, 300);
// Clean up after 30 seconds once theme should be stable
setTimeout(() => {
clearInterval(intervalId);
// But keep checking every 2 seconds for manual theme changes
setInterval(() => {
checkAndUpdateTheme();
}, 2000);
}, 30000);
}
function init() {
mountTitleBar();
syncViewportHeight();
observeThemeChanges();
window.addEventListener("resize", syncViewportHeight, { passive: true });
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
passive: true,

View File

@@ -119,7 +119,7 @@ backend = [
"shapely==2.0.6",
"stripe==10.12.0",
"urllib3==2.6.3",
"mistune==0.8.4",
"mistune==3.2.0",
"sendgrid==6.12.5",
"exa_py==1.15.4",
"braintrust==0.3.9",
@@ -142,7 +142,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.4.0",
"onyx-devtools==0.6.2",
"openapi-generator-cli==7.17.0",
"pandas-stubs~=2.3.3",
"pre-commit==3.2.2",

26
uv.lock generated
View File

@@ -3897,11 +3897,11 @@ wheels = [
[[package]]
name = "mistune"
version = "0.8.4"
version = "3.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2d/a4/509f6e7783ddd35482feda27bc7f72e65b5e7dc910eca4ab2164daf9c577/mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", size = 58322, upload-time = "2018-10-11T06:59:27.908Z" }
sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/09/ec/4b43dae793655b7d8a25f76119624350b4d65eb663459eb9603d7f1f0345/mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4", size = 16220, upload-time = "2018-10-11T06:59:26.044Z" },
{ url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
]
[[package]]
@@ -4766,7 +4766,7 @@ requires-dist = [
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.25.0" },
{ name = "mistune", marker = "extra == 'backend'", specifier = "==0.8.4" },
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
@@ -4775,7 +4775,7 @@ requires-dist = [
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.5.9" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.4.0" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.2" },
{ name = "openai", specifier = "==2.14.0" },
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
@@ -4878,20 +4878,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
[[package]]
name = "onyx-devtools"
version = "0.4.0"
version = "0.6.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "fastapi" },
{ name = "openapi-generator-cli" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/d8/f68d15c12d27d4525d10697ac7e2d67d6122fb59ccab219afb2973bc33ad/onyx_devtools-0.4.0-py3-none-any.whl", hash = "sha256:3eb821bce7ec8651d57e937d4d8483e1c2c4bc51df8cbab2dbcc05e3740ec96c", size = 2870841, upload-time = "2026-01-23T04:44:32.206Z" },
{ url = "https://files.pythonhosted.org/packages/28/04/6376342389494b51fd89e554dfdaf0d3809b8d1473bc9b72abd2d7dba21e/onyx_devtools-0.4.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:144e518abad3031ffef189445a69356fca1da2a4fb40c7b8431550133bfc4eef", size = 2890308, upload-time = "2026-01-23T04:44:37.674Z" },
{ url = "https://files.pythonhosted.org/packages/cc/c1/859b32fb3eff7e67179d971ace36313ae64e7fc9a242b45e606138b0041f/onyx_devtools-0.4.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0cc74d561f08a9c894adf8de79855b4fc72eb70e823a75e29db7f625ad366bd7", size = 2696160, upload-time = "2026-01-23T04:44:30.647Z" },
{ url = "https://files.pythonhosted.org/packages/59/1b/f1e3f574e9917779d22e3fcb28f8ac1888c250e7452a523f64a6ab8a1759/onyx_devtools-0.4.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:d69de76a97d7f9ff8c473afffbf544a65265645d726f3d70cc12dbbd7e364222", size = 2602134, upload-time = "2026-01-23T04:44:31.716Z" },
{ url = "https://files.pythonhosted.org/packages/79/4a/a5d11640fdc23c9bf0e8617ce13793a587e49a64be2d20badf7e9b045e0a/onyx_devtools-0.4.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:fa84980ce8830e35432831aadc19ff465dbc723605aa80c50e0debc58457b70f", size = 2870864, upload-time = "2026-01-23T04:44:31.5Z" },
{ url = "https://files.pythonhosted.org/packages/fc/9f/6a7e02fbf47bcaea4d02b0ed92bea6e2c09408be7654fb3b57a1ba9863f2/onyx_devtools-0.4.0-py3-none-win_amd64.whl", hash = "sha256:8451efe3e137157696decf8b60a19fb3f0c52ae9f2d9b7c5bc6e667900e7c61e", size = 2953545, upload-time = "2026-01-23T04:44:38.11Z" },
{ url = "https://files.pythonhosted.org/packages/52/42/f7a5b99ade06d215fb99de41181d51a9a984f83afb15afa15ce79ecab635/onyx_devtools-0.4.0-py3-none-win_arm64.whl", hash = "sha256:53a5942c922d7049650e934c43f9c057d046f8d53bc68935ebf7e93baa29afc3", size = 2665984, upload-time = "2026-01-23T04:44:29.399Z" },
{ url = "https://files.pythonhosted.org/packages/cc/20/d9f6089616044b0fb6e097cbae82122de24f3acd97820be4868d5c28ee3f/onyx_devtools-0.6.2-py3-none-any.whl", hash = "sha256:e48d14695d39d62ec3247a4c76ea56604bc5fb635af84c4ff3e9628bcc67b4fb", size = 3785941, upload-time = "2026-02-25T22:33:43.585Z" },
{ url = "https://files.pythonhosted.org/packages/d6/f5/f754a717f6b011050eb52ef09895cfa2f048f567f4aa3d5e0f773657dea4/onyx_devtools-0.6.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:505f9910a04868ab62d99bb483dc37c9f4ad94fa80e6ac0e6a10b86351c31420", size = 3832182, upload-time = "2026-02-25T22:33:43.283Z" },
{ url = "https://files.pythonhosted.org/packages/6a/35/6e653398c62078e87ebb0d03dc944df6691d92ca427c92867309d2d803b7/onyx_devtools-0.6.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:edec98e3acc0fa22cf9102c2070409ea7bcf99d7ded72bd8cb184ece8171c36a", size = 3576948, upload-time = "2026-02-25T22:33:42.962Z" },
{ url = "https://files.pythonhosted.org/packages/3c/97/cff707c5c3d2acd714365b1023f0100676abc99816a29558319e8ef01d5f/onyx_devtools-0.6.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97abab61216866cdccd8c0a7e27af328776083756ce4fb57c4bd723030449e3b", size = 3439359, upload-time = "2026-02-25T22:33:44.684Z" },
{ url = "https://files.pythonhosted.org/packages/fc/98/3b768d18e5599178834b966b447075626d224e048d6eb264d89d19abacb4/onyx_devtools-0.6.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:681b038ab6f1457409d14b2490782c7a8014fc0f0f1b9cd69bb2b7199f99aef1", size = 3785959, upload-time = "2026-02-25T22:33:44.342Z" },
{ url = "https://files.pythonhosted.org/packages/d6/38/9b047f9e61c14ccf22b8f386c7a57da3965f90737453f3a577a97da45cdf/onyx_devtools-0.6.2-py3-none-win_amd64.whl", hash = "sha256:a2063be6be104b50a7538cf0d26c7f7ab9159d53327dd6f3e91db05d793c95f3", size = 3878776, upload-time = "2026-02-25T22:33:45.229Z" },
{ url = "https://files.pythonhosted.org/packages/9d/0f/742f644bae84f5f8f7b500094a2f58da3ff8027fc739944622577e2e2850/onyx_devtools-0.6.2-py3-none-win_arm64.whl", hash = "sha256:00fb90a49a15c932b5cacf818b1b4918e5b5c574bde243dc1828b57690dd5046", size = 3501112, upload-time = "2026-02-25T22:33:41.512Z" },
]
[[package]]

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowDownDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowDownDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 9 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowLeftDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowLeftDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 14 9"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowRightDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowRightDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 14 9"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowUpDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowUpDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 9 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,7 +1,9 @@
import type { SVGProps } from "react";
import type { IconProps } from "@opal/types";
const SvgBracketCurly = (props: SVGProps<SVGSVGElement>) => (
const SvgBracketCurly = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 15 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgBranch = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M4.75001 5C5.71651 5 6.50001 4.2165 6.50001 3.25C6.50001 2.2835 5.7165 1.5 4.75 1.5C3.78351 1.5 3.00001 2.2835 3.00001 3.25C3.00001 4.2165 3.78351 5 4.75001 5ZM4.75001 5L4.75001 6.24999M4.75 11C3.7835 11 3 11.7835 3 12.75C3 13.7165 3.7835 14.5 4.75 14.5C5.7165 14.5 6.5 13.7165 6.5 12.75C6.5 11.7835 5.71649 11 4.75 11ZM4.75 11L4.75001 6.24999M10.5 8.74997C10.5 9.71646 11.2835 10.5 12.25 10.5C13.2165 10.5 14 9.71646 14 8.74997C14 7.78347 13.2165 7 12.25 7C11.2835 7 10.5 7.78347 10.5 8.74997ZM10.5 8.74997L7.25001 8.74999C5.8693 8.74999 4.75001 7.6307 4.75001 6.24999"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgBranch;

View File

@@ -0,0 +1,16 @@
import type { IconProps } from "@opal/types";
const SvgCircle = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<circle cx="8" cy="8" r="6" strokeWidth={1.5} />
</svg>
);
export default SvgCircle;

View File

@@ -1,10 +1,12 @@
import React from "react";
import type { IconProps } from "@opal/types";
const SvgClaude = (props: IconProps) => {
const SvgClaude = ({ size, ...props }: IconProps) => {
const clipId = React.useId();
return (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgClipboard = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgClipboard = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgCornerRightUpDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgCornerRightUpDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 9 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgDownload = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M14 10V12.6667C14 13.3929 13.3929 14 12.6667 14H3.33333C2.60711 14 2 13.3929 2 12.6667V10M4.66667 6.66667L8 10M8 10L11.3333 6.66667M8 10L8 2"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgDownload;

View File

@@ -24,6 +24,7 @@ export { default as SvgBookOpen } from "@opal/icons/book-open";
export { default as SvgBooksLineSmall } from "@opal/icons/books-line-small";
export { default as SvgBooksStackSmall } from "@opal/icons/books-stack-small";
export { default as SvgBracketCurly } from "@opal/icons/bracket-curly";
export { default as SvgBranch } from "@opal/icons/branch";
export { default as SvgBubbleText } from "@opal/icons/bubble-text";
export { default as SvgCalendar } from "@opal/icons/calendar";
export { default as SvgCheck } from "@opal/icons/check";
@@ -36,6 +37,7 @@ export { default as SvgChevronLeft } from "@opal/icons/chevron-left";
export { default as SvgChevronRight } from "@opal/icons/chevron-right";
export { default as SvgChevronUp } from "@opal/icons/chevron-up";
export { default as SvgChevronUpSmall } from "@opal/icons/chevron-up-small";
export { default as SvgCircle } from "@opal/icons/circle";
export { default as SvgClaude } from "@opal/icons/claude";
export { default as SvgClipboard } from "@opal/icons/clipboard";
export { default as SvgClock } from "@opal/icons/clock";
@@ -46,6 +48,7 @@ export { default as SvgCopy } from "@opal/icons/copy";
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
export { default as SvgCpu } from "@opal/icons/cpu";
export { default as SvgDevKit } from "@opal/icons/dev-kit";
export { default as SvgDownload } from "@opal/icons/download";
export { default as SvgDiscordMono } from "@opal/icons/DiscordMono";
export { default as SvgDownloadCloud } from "@opal/icons/download-cloud";
export { default as SvgEdit } from "@opal/icons/edit";
@@ -135,6 +138,7 @@ export { default as SvgStep3End } from "@opal/icons/step3-end";
export { default as SvgStop } from "@opal/icons/stop";
export { default as SvgStopCircle } from "@opal/icons/stop-circle";
export { default as SvgSun } from "@opal/icons/sun";
export { default as SvgTerminal } from "@opal/icons/terminal";
export { default as SvgTerminalSmall } from "@opal/icons/terminal-small";
export { default as SvgTextLinesSmall } from "@opal/icons/text-lines-small";
export { default as SvgThumbsDown } from "@opal/icons/thumbs-down";

View File

@@ -1,17 +1,11 @@
import type { IconProps } from "@opal/types";
const OnyxLogo = ({
width = 24,
height = 24,
className,
...props
}: IconProps) => (
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
<svg
width={width}
height={height}
width={size}
height={size}
viewBox="0 0 56 56"
xmlns="http://www.w3.org/2000/svg"
className={className}
stroke="currentColor"
{...props}
>
@@ -23,4 +17,4 @@ const OnyxLogo = ({
/>
</svg>
);
export default OnyxLogo;
export default SvgOnyxLogo;

View File

@@ -1,10 +1,12 @@
import React from "react";
import type { IconProps } from "@opal/types";
const SvgOpenAI = (props: IconProps) => {
const SvgOpenAI = ({ size, ...props }: IconProps) => {
const clipId = React.useId();
return (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -0,0 +1,22 @@
import type { IconProps } from "@opal/types";
const SvgTerminal = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2.66667 11.3333L6.66667 7.33331L2.66667 3.33331M8.00001 12.6666H13.3333"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgTerminal;

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgTwoLineSmall = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgTwoLineSmall = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { IconProps } from "@opal/types";
const SvgUserPlus = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgWallet = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgWallet = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -300,13 +300,7 @@ export default function Page() {
<>
<AdminPageTitle
title="Default Assistant"
icon={
<SvgOnyxLogo
width={32}
height={32}
className="my-auto stroke-text-04"
/>
}
icon={<SvgOnyxLogo size={32} className="my-auto stroke-text-04" />}
/>
<DefaultAssistantConfig />
</>

View File

@@ -31,6 +31,7 @@ import { fetchBedrockModels } from "../utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import Tabs from "@/refresh-components/Tabs";
import { cn } from "@/lib/utils";
export const BEDROCK_PROVIDER_NAME = "bedrock";
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
@@ -135,7 +136,7 @@ function BedrockFormInternals({
!formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;
return (
<Form className={LLM_FORM_CLASS_NAME}>
<Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
<DisplayNameField disabled={!!existingLlmProvider} />
<SelectorFormField
@@ -176,7 +177,7 @@ function BedrockFormInternals({
</Tabs.Content>
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4 w-full">
<TextFormField
name={FIELD_AWS_ACCESS_KEY_ID}
label="AWS Access Key ID"
@@ -191,7 +192,7 @@ function BedrockFormInternals({
</Tabs.Content>
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4 w-full">
<PasswordInputTypeInField
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
label="AWS Bedrock Long-term API Key"

View File

@@ -131,10 +131,15 @@ export function CustomForm({
return;
}
const selectedModelNames = modelConfigurations.map(
(config) => config.name
);
await submitLLMProvider({
providerName: values.provider,
values: {
...values,
selected_model_names: selectedModelNames,
custom_config: customConfigProcessing(
values.custom_config_list
),

View File

@@ -39,6 +39,8 @@ interface OllamaFormValues extends BaseLLMFormValues {
interface OllamaFormContentProps {
formikProps: FormikProps<OllamaFormValues>;
existingLlmProvider?: LLMProviderView;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
isTesting: boolean;
testError: string;
mutate: () => void;
@@ -49,15 +51,14 @@ interface OllamaFormContentProps {
function OllamaFormContent({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
isTesting,
testError,
mutate,
onClose,
isFormValid,
}: OllamaFormContentProps) {
const [availableModels, setAvailableModels] = useState<ModelConfiguration[]>(
existingLlmProvider?.model_configurations || []
);
const [isLoadingModels, setIsLoadingModels] = useState(true);
useEffect(() => {
@@ -70,16 +71,25 @@ function OllamaFormContent({
.then((data) => {
if (data.error) {
console.error("Error fetching models:", data.error);
setAvailableModels([]);
setFetchedModels([]);
return;
}
setAvailableModels(data.models);
setFetchedModels(data.models);
})
.finally(() => {
setIsLoadingModels(false);
});
}
}, [formikProps.values.api_base]);
}, [
formikProps.values.api_base,
existingLlmProvider?.name,
setFetchedModels,
]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || [];
return (
<Form className={LLM_FORM_CLASS_NAME}>
@@ -99,7 +109,7 @@ function OllamaFormContent({
/>
<DisplayModels
modelConfigurations={availableModels}
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
isLoading={isLoadingModels}
@@ -125,6 +135,8 @@ export function OllamaForm({
existingLlmProvider,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
return (
<ProviderFormEntrypointWrapper
providerName="Ollama"
@@ -189,7 +201,10 @@ export function OllamaForm({
providerName: OLLAMA_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations,
modelConfigurations:
fetchedModels.length > 0
? fetchedModels
: modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
@@ -205,6 +220,8 @@ export function OllamaForm({
<OllamaFormContent
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
isTesting={isTesting}
testError={testError}
mutate={mutate}

View File

@@ -68,11 +68,7 @@ export const WebProviderSetupModal = memo(
<SvgArrowExchange className="size-3 text-text-04" />
</div>
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
<SvgOnyxLogo
width={24}
height={24}
className="text-text-04 shrink-0"
/>
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
</div>
</div>
);

View File

@@ -1168,7 +1168,7 @@ export default function Page() {
alt: `${label} logo`,
fallback:
provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo width={16} height={16} />
<SvgOnyxLogo size={16} />
) : undefined,
size: 16,
isHighlighted: isCurrentCrawler,
@@ -1381,7 +1381,7 @@ export default function Page() {
} logo`,
fallback:
selectedContentProviderType === "onyx_web_crawler" ? (
<SvgOnyxLogo width={24} height={24} className="text-text-05" />
<SvgOnyxLogo size={24} className="text-text-05" />
) : undefined,
size: 24,
containerSize: 28,

View File

@@ -455,9 +455,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
/>
)}
<BackButton
behaviorOverride={() => router.push("/admin/indexing/status")}
/>
<BackButton />
<div
className="flex
items-center

View File

@@ -25,7 +25,6 @@ import { useDocumentSets } from "@/lib/hooks/useDocumentSets";
import { useAgents } from "@/hooks/useAgents";
import { ChatPopup } from "@/app/chat/components/ChatPopup";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { SEARCH_TOOL_ID } from "@/app/chat/components/tools/constants";
import { useUser } from "@/components/user/UserProvider";
import NoAssistantModal from "@/components/modals/NoAssistantModal";
import TextView from "@/components/chat/TextView";
@@ -382,9 +381,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
const retrievalEnabled = useMemo(() => {
if (liveAssistant) {
return liveAssistant.tools.some(
(tool) => tool.in_code_tool_id === SEARCH_TOOL_ID
);
return personaIncludesRetrieval(liveAssistant);
}
return false;
}, [liveAssistant]);

View File

@@ -860,6 +860,7 @@ export function useChatController({
stopReason: stopReason,
packets: packets,
packetsVersion: packetsVersion,
packetCount: packets.length,
},
],
// Pass the latest map state
@@ -886,6 +887,7 @@ export function useChatController({
toolCall: null,
parentNodeId: parentMessage?.nodeId || SYSTEM_NODE_ID,
packets: [],
packetCount: 0,
},
{
nodeId: initialAssistantNode.nodeId,
@@ -895,6 +897,7 @@ export function useChatController({
toolCall: null,
parentNodeId: initialUserNode.nodeId,
packets: [],
packetCount: 0,
stackTrace: stackTrace,
errorCode: errorCode,
isRetryable: isRetryable,

View File

@@ -141,6 +141,7 @@ export interface Message {
packets: Packet[];
// Version counter for efficient memo comparison (increments with each packet)
packetsVersion?: number;
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
// cached values for easy access
documents?: OnyxDocument[] | null;

View File

@@ -11,6 +11,7 @@ import { CitationMap } from "../../interfaces";
export enum RenderType {
HIGHLIGHT = "highlight",
FULL = "full",
COMPACT = "compact",
}
export interface FullChatState {
@@ -35,6 +36,9 @@ export interface RendererResult {
// used for things that should just show text w/o an icon or header
// e.g. ReasoningRenderer
expandedText?: JSX.Element;
// Whether this renderer supports compact mode (collapse button shown only when true)
supportsCompact?: boolean;
}
export type MessageRenderer<
@@ -48,5 +52,7 @@ export type MessageRenderer<
animate: boolean;
stopPacketSeen: boolean;
stopReason?: StopReason;
/** Whether this is the last step in the timeline (for connector line decisions) */
isLastStep?: boolean;
children: (result: RendererResult) => JSX.Element;
}>;

View File

@@ -68,10 +68,11 @@ export const CustomToolRenderer: MessageRenderer<CustomToolPacket, {}> = ({
const icon = FiTool;
if (renderType === RenderType.HIGHLIGHT) {
if (renderType === RenderType.COMPACT) {
return children({
icon,
status: status,
supportsCompact: true,
content: (
<div className="text-sm text-muted-foreground">
{isRunning && `${toolName} running...`}
@@ -84,6 +85,7 @@ export const CustomToolRenderer: MessageRenderer<CustomToolPacket, {}> = ({
return children({
icon,
status,
supportsCompact: true,
content: (
<div className="flex flex-col gap-3">
{/* File responses */}

View File

@@ -72,6 +72,7 @@ export const ImageToolRenderer: MessageRenderer<
return children({
icon: FiImage,
status: "Generating images...",
supportsCompact: false,
content: (
<div className="flex flex-col">
<div>
@@ -89,6 +90,7 @@ export const ImageToolRenderer: MessageRenderer<
status: `Generated ${images.length} image${
images.length !== 1 ? "s" : ""
}`,
supportsCompact: false,
content: (
<div className="flex flex-col my-1">
{images.length > 0 ? (
@@ -122,6 +124,7 @@ export const ImageToolRenderer: MessageRenderer<
return children({
icon: FiImage,
status: status,
supportsCompact: false,
content: <div></div>,
});
}
@@ -131,6 +134,7 @@ export const ImageToolRenderer: MessageRenderer<
return children({
icon: FiImage,
status: "Generating image...",
supportsCompact: false,
content: (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<div className="flex gap-0.5">
@@ -154,6 +158,7 @@ export const ImageToolRenderer: MessageRenderer<
return children({
icon: FiImage,
status: "Image generation failed",
supportsCompact: false,
content: (
<div className="text-sm text-red-600 dark:text-red-400">
Image generation failed
@@ -166,6 +171,7 @@ export const ImageToolRenderer: MessageRenderer<
return children({
icon: FiImage,
status: `Generated ${images.length} image${images.length > 1 ? "s" : ""}`,
supportsCompact: false,
content: (
<div className="text-sm text-muted-foreground">
Generated {images.length} image
@@ -178,6 +184,7 @@ export const ImageToolRenderer: MessageRenderer<
return children({
icon: FiImage,
status: "Image generation",
supportsCompact: false,
content: (
<div className="text-sm text-muted-foreground">Image generation</div>
),

View File

@@ -0,0 +1,441 @@
"use client";
import React, { FunctionComponent, useMemo, useCallback } from "react";
import { StopReason } from "@/app/chat/services/streamingModels";
import { FullChatState } from "../interfaces";
import { TurnGroup, TransformedStep } from "./transformers";
import { cn } from "@/lib/utils";
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
import { SvgCheckCircle, SvgStopCircle } from "@opal/icons";
import { IconProps } from "@opal/types";
import {
TimelineRendererComponent,
TimelineRendererResult,
} from "./TimelineRendererComponent";
import Text from "@/refresh-components/texts/Text";
import { ParallelTimelineTabs } from "./ParallelTimelineTabs";
import { StepContainer } from "./StepContainer";
import {
useTimelineExpansion,
useTimelineMetrics,
useTimelineHeader,
} from "@/app/chat/message/messageComponents/timeline/hooks";
import {
isResearchAgentPackets,
stepSupportsCompact,
} from "@/app/chat/message/messageComponents/timeline/packetHelpers";
import {
StreamingHeader,
CollapsedHeader,
ExpandedHeader,
StoppedHeader,
ParallelStreamingHeader,
} from "@/app/chat/message/messageComponents/timeline/headers";
// =============================================================================
// TimelineStep Component - Memoized to prevent re-renders
// =============================================================================
interface TimelineStepProps {
step: TransformedStep;
chatState: FullChatState;
stopPacketSeen: boolean;
stopReason?: StopReason;
isLastStep: boolean;
isFirstStep: boolean;
isSingleStep: boolean;
}
//will be removed on cleanup
const noopCallback = () => {};
const TimelineStep = React.memo(function TimelineStep({
step,
chatState,
stopPacketSeen,
stopReason,
isLastStep,
isFirstStep,
isSingleStep,
}: TimelineStepProps) {
// Stable render callback - doesn't need to change between renders
const renderStep = useCallback(
({
icon,
status,
content,
isExpanded,
onToggle,
isLastStep: rendererIsLastStep,
supportsCompact,
}: TimelineRendererResult) =>
isResearchAgentPackets(step.packets) ? (
content
) : (
<StepContainer
stepIcon={icon as FunctionComponent<IconProps> | undefined}
header={status}
isExpanded={isExpanded}
onToggle={onToggle}
collapsible={true}
supportsCompact={supportsCompact}
isLastStep={rendererIsLastStep}
isFirstStep={isFirstStep}
hideHeader={isSingleStep}
>
{content}
</StepContainer>
),
[step.packets, isFirstStep, isSingleStep]
);
return (
<TimelineRendererComponent
packets={step.packets}
chatState={chatState}
onComplete={noopCallback}
animate={!stopPacketSeen}
stopPacketSeen={stopPacketSeen}
stopReason={stopReason}
defaultExpanded={true}
isLastStep={isLastStep}
>
{renderStep}
</TimelineRendererComponent>
);
});
// =============================================================================
// Main Component
// =============================================================================
export interface AgentTimelineProps {
/** Turn groups from usePacketProcessor */
turnGroups: TurnGroup[];
/** Chat state for rendering content */
chatState: FullChatState;
/** Whether the stop packet has been seen */
stopPacketSeen?: boolean;
/** Reason for stopping (if stopped) */
stopReason?: StopReason;
/** Whether final answer is coming (affects last connector) */
finalAnswerComing?: boolean;
/** Whether there is display content after timeline */
hasDisplayContent?: boolean;
/** Content to render after timeline (final message + toolbar) - slot pattern */
children?: React.ReactNode;
/** Whether the timeline is collapsible */
collapsible?: boolean;
/** Title of the button to toggle the timeline */
buttonTitle?: string;
/** Additional class names */
className?: string;
/** Test ID for e2e testing */
"data-testid"?: string;
/** Unique tool names (pre-computed for performance) */
uniqueToolNames?: string[];
}
export function AgentTimeline({
turnGroups,
chatState,
stopPacketSeen = false,
stopReason,
finalAnswerComing = false,
hasDisplayContent = false,
collapsible = true,
buttonTitle,
className,
"data-testid": testId,
uniqueToolNames = [],
}: AgentTimelineProps) {
// Header text and state flags
const { headerText, hasPackets, userStopped } = useTimelineHeader(
turnGroups,
stopReason
);
// Memoized metrics derived from turn groups
const {
totalSteps,
isSingleStep,
uniqueTools,
lastTurnGroup,
lastStep,
lastStepIsResearchAgent,
lastStepSupportsCompact,
} = useTimelineMetrics(turnGroups, uniqueToolNames, userStopped);
// Expansion state management
const { isExpanded, handleToggle, parallelActiveTab, setParallelActiveTab } =
useTimelineExpansion(stopPacketSeen, lastTurnGroup);
// Stable callbacks to avoid creating new functions on every render
const noopComplete = useCallback(() => {}, []);
const renderContentOnly = useCallback(
({ content }: TimelineRendererResult) => content,
[]
);
// Parallel step analysis for collapsed streaming view
const parallelActiveStep = useMemo(() => {
if (!lastTurnGroup?.isParallel) return null;
return (
lastTurnGroup.steps.find((s) => s.key === parallelActiveTab) ??
lastTurnGroup.steps[0]
);
}, [lastTurnGroup, parallelActiveTab]);
const parallelActiveStepSupportsCompact = useMemo(() => {
if (!parallelActiveStep) return false;
return (
stepSupportsCompact(parallelActiveStep.packets) &&
!isResearchAgentPackets(parallelActiveStep.packets)
);
}, [parallelActiveStep]);
// Collapsed streaming: show compact content below header
const showCollapsedCompact =
!stopPacketSeen &&
!isExpanded &&
lastStep &&
!lastTurnGroup?.isParallel &&
!lastStepIsResearchAgent &&
lastStepSupportsCompact;
// Parallel tabs in header only when collapsed (expanded view has tabs in content)
const showParallelTabs =
!stopPacketSeen &&
!isExpanded &&
lastTurnGroup?.isParallel &&
lastTurnGroup.steps.length > 0;
// Collapsed parallel compact content
const showCollapsedParallel =
showParallelTabs && !isExpanded && parallelActiveStepSupportsCompact;
// Done indicator conditions
const showDoneIndicator =
stopPacketSeen && isExpanded && !userStopped && !lastStepIsResearchAgent;
// Header selection based on state
const renderHeader = () => {
if (!stopPacketSeen) {
if (showParallelTabs && lastTurnGroup) {
return (
<ParallelStreamingHeader
steps={lastTurnGroup.steps}
activeTab={parallelActiveTab}
onTabChange={setParallelActiveTab}
collapsible={collapsible}
isExpanded={isExpanded}
onToggle={handleToggle}
/>
);
}
return (
<StreamingHeader
headerText={headerText}
collapsible={collapsible}
buttonTitle={buttonTitle}
isExpanded={isExpanded}
onToggle={handleToggle}
/>
);
}
if (userStopped) {
return (
<StoppedHeader
totalSteps={totalSteps}
collapsible={collapsible}
isExpanded={isExpanded}
onToggle={handleToggle}
/>
);
}
if (!isExpanded) {
return (
<CollapsedHeader
uniqueTools={uniqueTools}
totalSteps={totalSteps}
collapsible={collapsible}
onToggle={handleToggle}
/>
);
}
return <ExpandedHeader collapsible={collapsible} onToggle={handleToggle} />;
};
// Empty state: no packets, still streaming
if (!hasPackets && !hasDisplayContent) {
return (
<div className={cn("flex flex-col", className)}>
<div className="flex w-full h-9">
<div className="flex justify-center items-center size-9">
<AgentAvatar agent={chatState.assistant} size={24} />
</div>
<div className="flex w-full h-full items-center px-2">
<Text
as="p"
mainUiAction
text03
className="animate-shimmer bg-[length:200%_100%] bg-[linear-gradient(90deg,var(--shimmer-base)_10%,var(--shimmer-highlight)_40%,var(--shimmer-base)_70%)] bg-clip-text text-transparent"
>
{headerText}
</Text>
</div>
</div>
</div>
);
}
// Display content only (no timeline steps)
if (hasDisplayContent && !hasPackets) {
return (
<div className={cn("flex flex-col", className)}>
<div className="flex w-full h-9">
<div className="flex justify-center items-center size-9">
<AgentAvatar agent={chatState.assistant} size={24} />
</div>
</div>
</div>
);
}
return (
<div className={cn("flex flex-col", className)}>
{/* Header row */}
<div className="flex w-full h-9">
<div className="flex justify-center items-center size-9">
<AgentAvatar agent={chatState.assistant} size={24} />
</div>
<div
className={cn(
"flex w-full h-full items-center justify-between px-2",
(!stopPacketSeen || userStopped || isExpanded) &&
"bg-background-tint-00 rounded-t-12",
!isExpanded &&
!showCollapsedCompact &&
!showCollapsedParallel &&
"rounded-b-12"
)}
>
{renderHeader()}
</div>
</div>
{/* Collapsed streaming view - single step compact mode */}
{showCollapsedCompact && lastStep && (
<div className="flex w-full">
<div className="w-9" />
<div className="w-full bg-background-tint-00 rounded-b-12 px-2 pb-2">
<TimelineRendererComponent
key={`${lastStep.key}-compact`}
packets={lastStep.packets}
chatState={chatState}
onComplete={noopComplete}
animate={true}
stopPacketSeen={false}
stopReason={stopReason}
defaultExpanded={false}
isLastStep={true}
>
{renderContentOnly}
</TimelineRendererComponent>
</div>
</div>
)}
{/* Collapsed streaming view - parallel tools compact mode */}
{showCollapsedParallel && parallelActiveStep && (
<div className="flex w-full">
<div className="w-9" />
<div className="w-full bg-background-tint-00 rounded-b-12 px-2 pb-2">
<TimelineRendererComponent
key={`${parallelActiveStep.key}-compact`}
packets={parallelActiveStep.packets}
chatState={chatState}
onComplete={noopComplete}
animate={true}
stopPacketSeen={false}
stopReason={stopReason}
defaultExpanded={false}
isLastStep={true}
>
{renderContentOnly}
</TimelineRendererComponent>
</div>
</div>
)}
{/* Expanded timeline view */}
{isExpanded && (
<div className="w-full">
{turnGroups.map((turnGroup, turnIdx) =>
turnGroup.isParallel ? (
<ParallelTimelineTabs
key={turnGroup.turnIndex}
turnGroup={turnGroup}
chatState={chatState}
stopPacketSeen={stopPacketSeen}
stopReason={stopReason}
isLastTurnGroup={turnIdx === turnGroups.length - 1}
/>
) : (
turnGroup.steps.map((step, stepIdx) => {
const stepIsLast =
turnIdx === turnGroups.length - 1 &&
stepIdx === turnGroup.steps.length - 1 &&
!showDoneIndicator &&
!userStopped;
const stepIsFirst = turnIdx === 0 && stepIdx === 0;
return (
<TimelineStep
key={step.key}
step={step}
chatState={chatState}
stopPacketSeen={stopPacketSeen}
stopReason={stopReason}
isLastStep={stepIsLast}
isFirstStep={stepIsFirst}
isSingleStep={isSingleStep}
/>
);
})
)
)}
{/* Done indicator */}
{stopPacketSeen && isExpanded && !userStopped && (
<StepContainer
stepIcon={SvgCheckCircle}
header="Done"
isLastStep={true}
isFirstStep={false}
>
{null}
</StepContainer>
)}
{/* Stopped indicator */}
{stopPacketSeen && isExpanded && userStopped && (
<StepContainer
stepIcon={SvgStopCircle}
header="Stopped"
isLastStep={true}
isFirstStep={false}
>
{null}
</StepContainer>
)}
</div>
)}
</div>
);
}
export default AgentTimeline;

View File

@@ -0,0 +1,130 @@
"use client";
import React, {
useState,
useMemo,
useCallback,
FunctionComponent,
} from "react";
import { StopReason } from "@/app/chat/services/streamingModels";
import { FullChatState } from "../interfaces";
import { TurnGroup } from "./transformers";
import { getToolName, getToolIcon } from "../toolDisplayHelpers";
import {
TimelineRendererComponent,
TimelineRendererResult,
} from "./TimelineRendererComponent";
import Tabs from "@/refresh-components/Tabs";
import { SvgBranch } from "@opal/icons";
import { StepContainer } from "./StepContainer";
import { isResearchAgentPackets } from "@/app/chat/message/messageComponents/timeline/packetHelpers";
import { IconProps } from "@/components/icons/icons";
export interface ParallelTimelineTabsProps {
/** Turn group containing parallel steps */
turnGroup: TurnGroup;
/** Chat state for rendering content */
chatState: FullChatState;
/** Whether the stop packet has been seen */
stopPacketSeen: boolean;
/** Reason for stopping (if stopped) */
stopReason?: StopReason;
/** Whether this is the last turn group (affects connector line) */
isLastTurnGroup: boolean;
/** Additional class names */
className?: string;
}
/**
 * Renders one parallel turn group in the expanded timeline: a pill tab per
 * parallel step plus the currently selected step's content beneath the tabs.
 * NOTE(review): the `className` prop is accepted but never applied to any
 * element — confirm whether it should be forwarded to the root container.
 */
export function ParallelTimelineTabs({
  turnGroup,
  chatState,
  stopPacketSeen,
  stopReason,
  isLastTurnGroup,
  className,
}: ParallelTimelineTabsProps) {
  // Default to the first step's key; empty string when the group has no steps.
  const [activeTab, setActiveTab] = useState(turnGroup.steps[0]?.key ?? "");
  // Find the active step based on selected tab.
  // NOTE(review): if step keys change between renders, activeTab can go stale
  // and activeStep becomes undefined (content renders with empty packets) —
  // confirm upstream keys are stable for the lifetime of this group.
  const activeStep = useMemo(
    () => turnGroup.steps.find((step) => step.key === activeTab),
    [turnGroup.steps, activeTab]
  );
  // Stable callbacks to avoid creating new functions on every render.
  const noopComplete = useCallback(() => {}, []);
  // Research-agent steps render their own chrome, so return bare content;
  // every other step is wrapped in the standard StepContainer shell.
  const renderTabContent = useCallback(
    ({
      icon,
      status,
      content,
      isExpanded,
      onToggle,
      isLastStep,
    }: TimelineRendererResult) =>
      isResearchAgentPackets(activeStep?.packets ?? []) ? (
        content
      ) : (
        <StepContainer
          stepIcon={icon as FunctionComponent<IconProps> | undefined}
          header={status}
          isExpanded={isExpanded}
          onToggle={onToggle}
          collapsible={true}
          isLastStep={isLastStep}
          isFirstStep={false}
        >
          {content}
        </StepContainer>
      ),
    [activeStep?.packets]
  );
  return (
    <Tabs value={activeTab} onValueChange={setActiveTab}>
      <div className="flex flex-col w-full gap-1">
        <div className="flex w-full">
          {/* Left column: Icon + connector line */}
          <div className="flex flex-col items-center w-9 pt-2">
            <div className="size-4 flex items-center justify-center stroke-text-02">
              <SvgBranch className="w-4 h-4" />
            </div>
            {/* Connector line */}
            <div className="w-px flex-1 bg-border-01" />
          </div>
          {/* Right column: Tabs */}
          <div className="flex-1">
            <Tabs.List variant="pill">
              {turnGroup.steps.map((step) => (
                <Tabs.Trigger key={step.key} value={step.key} variant="pill">
                  <span className="flex items-center gap-1.5">
                    {getToolIcon(step.packets)}
                    {getToolName(step.packets)}
                  </span>
                </Tabs.Trigger>
              ))}
            </Tabs.List>
          </div>
        </div>
        {/* Content of the currently selected parallel step; keyed by tab so
            the renderer remounts (and resets its collapse state) on switch. */}
        <div className="w-full">
          <TimelineRendererComponent
            key={activeTab}
            packets={activeStep?.packets ?? []}
            chatState={chatState}
            onComplete={noopComplete}
            animate={!stopPacketSeen}
            stopPacketSeen={stopPacketSeen}
            stopReason={stopReason}
            defaultExpanded={true}
            isLastStep={isLastTurnGroup}
          >
            {renderTabContent}
          </TimelineRendererComponent>
        </div>
      </div>
    </Tabs>
  );
}
export default ParallelTimelineTabs;

View File

@@ -0,0 +1,108 @@
import React, { FunctionComponent } from "react";
import { cn } from "@/lib/utils";
import { SvgFold, SvgExpand } from "@opal/icons";
import Button from "@/refresh-components/buttons/Button";
import IconButton from "@/refresh-components/buttons/IconButton";
import { IconProps } from "@opal/types";
import Text from "@/refresh-components/texts/Text";
export interface StepContainerProps {
/** Main content */
children?: React.ReactNode;
/** Step icon component */
stepIcon?: FunctionComponent<IconProps>;
/** Header left slot */
header?: React.ReactNode;
/** Button title for toggle */
buttonTitle?: string;
/** Controlled expanded state */
isExpanded?: boolean;
/** Toggle callback */
onToggle?: () => void;
/** Whether collapse control is shown */
collapsible?: boolean;
/** Collapse button shown only when renderer supports compact mode */
supportsCompact?: boolean;
/** Additional class names */
className?: string;
/** Last step (no bottom connector) */
isLastStep?: boolean;
/** First step (top padding instead of connector) */
isFirstStep?: boolean;
/** Hide header (single-step timelines) */
hideHeader?: boolean;
}
/** Visual wrapper for timeline steps - icon, connector line, header, and content */
export function StepContainer({
  children,
  stepIcon: StepIconComponent,
  header,
  buttonTitle,
  isExpanded = true,
  onToggle,
  collapsible = true,
  supportsCompact = false,
  isLastStep = false,
  isFirstStep = false,
  className,
  hideHeader = false,
}: StepContainerProps) {
  // Collapse toggle appears only when the step is collapsible, the renderer
  // supports a compact mode, AND a toggle handler was provided.
  const showCollapseControls = collapsible && supportsCompact && onToggle;
  return (
    <div className={cn("flex w-full", className)}>
      {/* Left rail: step icon + vertical connector down to the next step */}
      <div
        className={cn("flex flex-col items-center w-9", isFirstStep && "pt-2")}
      >
        {/* Icon */}
        {!hideHeader && StepIconComponent && (
          <div className="py-1">
            <StepIconComponent className="size-4 stroke-text-02" />
          </div>
        )}
        {/* Connector line */}
        {!isLastStep && <div className="w-px flex-1 bg-border-01" />}
      </div>
      {/* Right column: header row (optional) + step content */}
      <div
        className={cn(
          "w-full bg-background-tint-00",
          isLastStep && "rounded-b-12"
        )}
      >
        {!hideHeader && (
          <div className="flex items-center justify-between px-2">
            {header && (
              <Text as="p" mainUiMuted text03>
                {header}
              </Text>
            )}
            {/* Labeled button when a title is given, otherwise icon-only */}
            {showCollapseControls &&
              (buttonTitle ? (
                <Button
                  tertiary
                  onClick={onToggle}
                  rightIcon={isExpanded ? SvgFold : SvgExpand}
                >
                  {buttonTitle}
                </Button>
              ) : (
                <IconButton
                  tertiary
                  onClick={onToggle}
                  icon={isExpanded ? SvgFold : SvgExpand}
                />
              ))}
          </div>
        )}
        <div className="px-2 pb-2">{children}</div>
      </div>
    </div>
  );
}
export default StepContainer;

View File

@@ -0,0 +1,116 @@
"use client";
import React, { useState, JSX } from "react";
import { Packet, StopReason } from "@/app/chat/services/streamingModels";
import { FullChatState, RenderType, RendererResult } from "../interfaces";
import { findRenderer } from "../renderMessageComponent";
/** Extended result that includes collapse state */
export interface TimelineRendererResult extends RendererResult {
/** Current expanded state */
isExpanded: boolean;
/** Toggle callback */
onToggle: () => void;
/** Current render type */
renderType: RenderType;
/** Whether this is the last step (passed through from props) */
isLastStep: boolean;
}
export interface TimelineRendererComponentProps {
/** Packets to render */
packets: Packet[];
/** Chat state for rendering */
chatState: FullChatState;
/** Completion callback */
onComplete: () => void;
/** Whether to animate streaming */
animate: boolean;
/** Whether stop packet has been seen */
stopPacketSeen: boolean;
/** Reason for stopping */
stopReason?: StopReason;
/** Initial expanded state */
defaultExpanded?: boolean;
/** Whether this is the last step in the timeline (for connector line decisions) */
isLastStep?: boolean;
/** Children render function - receives extended result with collapse state */
children: (result: TimelineRendererResult) => JSX.Element;
}
// Custom React.memo comparator: packets are an append-only stream, so a
// length check is enough to detect new ones; the remaining scalar props are
// compared by identity. chatState and onComplete are intentionally ignored
// (memoized upstream).
function arePropsEqual(
  prev: TimelineRendererComponentProps,
  next: TimelineRendererComponentProps
): boolean {
  if (prev.packets.length !== next.packets.length) {
    return false;
  }
  const watchedKeys = [
    "stopPacketSeen",
    "stopReason",
    "animate",
    "isLastStep",
    "defaultExpanded",
  ] as const;
  return watchedKeys.every((key) => prev[key] === next[key]);
}
export const TimelineRendererComponent = React.memo(
  function TimelineRendererComponent({
    packets,
    chatState,
    onComplete,
    animate,
    stopPacketSeen,
    stopReason,
    defaultExpanded = true,
    isLastStep,
    children,
  }: TimelineRendererComponentProps) {
    // Expanded/collapsed is the only local state; it drives FULL vs COMPACT.
    const [isExpanded, setIsExpanded] = useState(defaultExpanded);
    const handleToggle = () => setIsExpanded((prev) => !prev);
    const RendererFn = findRenderer({ packets });
    const renderType = isExpanded ? RenderType.FULL : RenderType.COMPACT;
    // No renderer matched these packets: hand the child an empty result so it
    // can still render its chrome (header, toggle) consistently.
    if (!RendererFn) {
      return children({
        icon: null,
        status: null,
        content: <></>,
        supportsCompact: false,
        isExpanded,
        onToggle: handleToggle,
        renderType,
        isLastStep: isLastStep ?? true,
      });
    }
    // NOTE(review): `packets as any` bridges renderer-specific packet unions;
    // a typed dispatch in findRenderer would remove this cast.
    return (
      <RendererFn
        packets={packets as any}
        state={chatState}
        onComplete={onComplete}
        animate={animate}
        renderType={renderType}
        stopPacketSeen={stopPacketSeen}
        stopReason={stopReason}
        isLastStep={isLastStep}
      >
        {/* Augment the renderer's result with local collapse state before
            forwarding to the caller's render function. */}
        {({ icon, status, content, expandedText, supportsCompact }) =>
          children({
            icon,
            status,
            content,
            expandedText,
            supportsCompact,
            isExpanded,
            onToggle: handleToggle,
            renderType,
            isLastStep: isLastStep ?? true,
          })
        }
      </RendererFn>
    );
  },
  arePropsEqual
);

View File

@@ -0,0 +1,49 @@
import React from "react";
import { SvgExpand } from "@opal/icons";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/refresh-components/texts/Text";
import type { UniqueTool } from "@/app/chat/message/messageComponents/timeline/hooks";
export interface CollapsedHeaderProps {
uniqueTools: UniqueTool[];
totalSteps: number;
collapsible: boolean;
onToggle: () => void;
}
/** Header when completed + collapsed - tools summary + step count */
export const CollapsedHeader = React.memo(function CollapsedHeader({
  uniqueTools,
  totalSteps,
  collapsible,
  onToggle,
}: CollapsedHeaderProps) {
  return (
    <>
      {/* One pill badge per unique tool used during the turn */}
      <div className="flex items-center gap-2">
        {uniqueTools.map((tool) => (
          <div
            key={tool.key}
            className="inline-flex items-center gap-1 rounded-08 p-1 bg-background-tint-02"
          >
            {tool.icon}
            <Text as="span" secondaryBody text04>
              {tool.name}
            </Text>
          </div>
        ))}
      </div>
      {/* The step-count button doubles as the expand control */}
      {collapsible && (
        <Button
          tertiary
          onClick={onToggle}
          rightIcon={SvgExpand}
          aria-label="Expand timeline"
          aria-expanded={false}
        >
          {totalSteps} {totalSteps === 1 ? "step" : "steps"}
        </Button>
      )}
    </>
  );
});

View File

@@ -0,0 +1,32 @@
import React from "react";
import { SvgFold } from "@opal/icons";
import IconButton from "@/refresh-components/buttons/IconButton";
import Text from "@/refresh-components/texts/Text";
export interface ExpandedHeaderProps {
collapsible: boolean;
onToggle: () => void;
}
/** Header when completed + expanded */
export const ExpandedHeader = React.memo(function ExpandedHeader({
  collapsible,
  onToggle,
}: ExpandedHeaderProps) {
  return (
    <>
      <Text as="p" mainUiAction text03>
        Thought for some time
      </Text>
      {/* Icon-only collapse control (always-expanded header state) */}
      {collapsible && (
        <IconButton
          tertiary
          onClick={onToggle}
          icon={SvgFold}
          aria-label="Collapse timeline"
          aria-expanded={true}
        />
      )}
    </>
  );
});

View File

@@ -0,0 +1,53 @@
import React from "react";
import { SvgFold, SvgExpand } from "@opal/icons";
import IconButton from "@/refresh-components/buttons/IconButton";
import Tabs from "@/refresh-components/Tabs";
import { TurnGroup } from "../transformers";
import { getToolIcon, getToolName } from "../../toolDisplayHelpers";
export interface ParallelStreamingHeaderProps {
steps: TurnGroup["steps"];
activeTab: string;
onTabChange: (tab: string) => void;
collapsible: boolean;
isExpanded: boolean;
onToggle: () => void;
}
/** Header during streaming with parallel tools - tabs only */
export const ParallelStreamingHeader = React.memo(
  function ParallelStreamingHeader({
    steps,
    activeTab,
    onTabChange,
    collapsible,
    isExpanded,
    onToggle,
  }: ParallelStreamingHeaderProps) {
    return (
      <Tabs value={activeTab} onValueChange={onTabChange}>
        <div className="flex items-center justify-between w-full gap-2">
          {/* One pill tab per parallel step, labeled with the tool's icon+name */}
          <Tabs.List variant="pill">
            {steps.map((step) => (
              <Tabs.Trigger key={step.key} value={step.key} variant="pill">
                <span className="flex items-center gap-1.5">
                  {getToolIcon(step.packets)}
                  {getToolName(step.packets)}
                </span>
              </Tabs.Trigger>
            ))}
          </Tabs.List>
          {/* Expand/collapse toggle for the whole timeline */}
          {collapsible && (
            <IconButton
              tertiary
              onClick={onToggle}
              icon={isExpanded ? SvgFold : SvgExpand}
              aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
              aria-expanded={isExpanded}
            />
          )}
        </div>
      </Tabs>
    );
  }
);

View File

@@ -0,0 +1,38 @@
import React from "react";
import { SvgFold, SvgExpand } from "@opal/icons";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/refresh-components/texts/Text";
export interface StoppedHeaderProps {
totalSteps: number;
collapsible: boolean;
isExpanded: boolean;
onToggle: () => void;
}
/** Header when user stopped/cancelled */
export const StoppedHeader = React.memo(function StoppedHeader({
  totalSteps,
  collapsible,
  isExpanded,
  onToggle,
}: StoppedHeaderProps) {
  return (
    <>
      <Text as="p" mainUiAction text03>
        Stopped Thinking
      </Text>
      {/* Step-count button doubles as the expand/collapse control */}
      {collapsible && (
        <Button
          tertiary
          onClick={onToggle}
          rightIcon={isExpanded ? SvgFold : SvgExpand}
          aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
          aria-expanded={isExpanded}
        >
          {totalSteps} {totalSteps === 1 ? "step" : "steps"}
        </Button>
      )}
    </>
  );
});

View File

@@ -0,0 +1,54 @@
import React from "react";
import { SvgFold, SvgExpand } from "@opal/icons";
import Button from "@/refresh-components/buttons/Button";
import IconButton from "@/refresh-components/buttons/IconButton";
import Text from "@/refresh-components/texts/Text";
export interface StreamingHeaderProps {
headerText: string;
collapsible: boolean;
buttonTitle?: string;
isExpanded: boolean;
onToggle: () => void;
}
/** Header during streaming - shimmer text with current activity */
export const StreamingHeader = React.memo(function StreamingHeader({
  headerText,
  collapsible,
  buttonTitle,
  isExpanded,
  onToggle,
}: StreamingHeaderProps) {
  return (
    <>
      {/* Shimmering activity label (animated gradient clipped to the text) */}
      <Text
        as="p"
        mainUiAction
        text03
        className="animate-shimmer bg-[length:200%_100%] bg-[linear-gradient(90deg,var(--shimmer-base)_10%,var(--shimmer-highlight)_40%,var(--shimmer-base)_70%)] bg-clip-text text-transparent"
      >
        {headerText}
      </Text>
      {/* Labeled button when a title is given, otherwise icon-only toggle */}
      {collapsible &&
        (buttonTitle ? (
          <Button
            tertiary
            onClick={onToggle}
            rightIcon={isExpanded ? SvgFold : SvgExpand}
            aria-expanded={isExpanded}
          >
            {buttonTitle}
          </Button>
        ) : (
          <IconButton
            tertiary
            onClick={onToggle}
            icon={isExpanded ? SvgFold : SvgExpand}
            aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
            aria-expanded={isExpanded}
          />
        ))}
    </>
  );
});

View File

@@ -0,0 +1,14 @@
export { StreamingHeader } from "./StreamingHeader";
export type { StreamingHeaderProps } from "./StreamingHeader";
export { CollapsedHeader } from "./CollapsedHeader";
export type { CollapsedHeaderProps } from "./CollapsedHeader";
export { ExpandedHeader } from "./ExpandedHeader";
export type { ExpandedHeaderProps } from "./ExpandedHeader";
export { StoppedHeader } from "./StoppedHeader";
export type { StoppedHeaderProps } from "./StoppedHeader";
export { ParallelStreamingHeader } from "./ParallelStreamingHeader";
export type { ParallelStreamingHeaderProps } from "./ParallelStreamingHeader";

View File

@@ -0,0 +1,11 @@
export { useTimelineExpansion } from "./useTimelineExpansion";
export type { TimelineExpansionState } from "./useTimelineExpansion";
export { useTimelineMetrics } from "./useTimelineMetrics";
export type { TimelineMetrics, UniqueTool } from "./useTimelineMetrics";
export { usePacketProcessor } from "./usePacketProcessor";
export type { UsePacketProcessorResult } from "./usePacketProcessor";
export { useTimelineHeader } from "./useTimelineHeader";
export type { TimelineHeaderResult } from "./useTimelineHeader";

View File

@@ -0,0 +1,439 @@
import {
Packet,
PacketType,
StreamingCitation,
StopReason,
CitationInfo,
SearchToolDocumentsDelta,
FetchToolDocuments,
TopLevelBranching,
Stop,
SearchToolStart,
CustomToolStart,
} from "@/app/chat/services/streamingModels";
import { CitationMap } from "@/app/chat/interfaces";
import { OnyxDocument } from "@/lib/search/interfaces";
import {
isActualToolCallPacket,
isToolPacket,
isDisplayPacket,
} from "@/app/chat/services/packetUtils";
import { parseToolKey } from "@/app/chat/message/messageComponents/toolDisplayHelpers";
// Re-export parseToolKey for consumers that import from this module
export { parseToolKey };
// ============================================================================
// Types
// ============================================================================
export interface ProcessorState {
nodeId: number;
lastProcessedIndex: number;
// Citations
citations: StreamingCitation[];
seenCitationDocIds: Set<string>;
citationMap: CitationMap;
// Documents
documentMap: Map<string, OnyxDocument>;
// Packet grouping
groupedPacketsMap: Map<string, Packet[]>;
seenGroupKeys: Set<string>;
groupKeysWithSectionEnd: Set<string>;
expectedBranches: Map<number, number>;
// Pre-categorized groups (populated during packet processing)
toolGroupKeys: Set<string>;
displayGroupKeys: Set<string>;
// Unique tool names tracking (populated during packet processing)
uniqueToolNames: Set<string>;
// Streaming status
finalAnswerComing: boolean;
stopPacketSeen: boolean;
stopReason: StopReason | undefined;
// Result arrays (built at end of processPackets)
toolGroups: GroupedPacket[];
potentialDisplayGroups: GroupedPacket[];
uniqueToolNamesArray: string[];
}
export interface GroupedPacket {
turn_index: number;
tab_index: number;
packets: Packet[];
}
// ============================================================================
// State Creation
// ============================================================================
/**
 * Build a fresh ProcessorState for a message node. All incremental work in
 * processPackets mutates this object in place; result arrays start empty.
 */
export function createInitialState(nodeId: number): ProcessorState {
  const freshState: ProcessorState = {
    nodeId,
    lastProcessedIndex: 0,
    // Citation tracking
    citations: [],
    seenCitationDocIds: new Set(),
    citationMap: {},
    // Document lookup by document_id
    documentMap: new Map(),
    // Packet grouping bookkeeping
    groupedPacketsMap: new Map(),
    seenGroupKeys: new Set(),
    groupKeysWithSectionEnd: new Set(),
    expectedBranches: new Map(),
    // Group categorization (filled as first packets of groups arrive)
    toolGroupKeys: new Set(),
    displayGroupKeys: new Set(),
    uniqueToolNames: new Set(),
    // Streaming status flags
    finalAnswerComing: false,
    stopPacketSeen: false,
    stopReason: undefined,
    // Derived result arrays (rebuilt at the end of processPackets)
    toolGroups: [],
    potentialDisplayGroups: [],
    uniqueToolNamesArray: [],
  };
  return freshState;
}
// ============================================================================
// Helper Functions
// ============================================================================
// Composite key identifying a packet's group: "<turn>-<tab>" (tab defaults to 0).
function getGroupKey(packet: Packet): string {
  const { turn_index, tab_index } = packet.placement;
  return `${turn_index}-${tab_index ?? 0}`;
}
/**
 * Append a synthetic SECTION_END packet to a group so downstream rendering
 * treats it as complete. Idempotent per group key. Note: the key is marked in
 * groupKeysWithSectionEnd even when no group exists for it yet.
 */
function injectSectionEnd(state: ProcessorState, groupKey: string): void {
  if (state.groupKeysWithSectionEnd.has(groupKey)) {
    return; // Already has SECTION_END
  }
  const { turn_index, tab_index } = parseToolKey(groupKey);
  // Synthetic terminator placed at the group's own turn/tab coordinates.
  const syntheticPacket: Packet = {
    placement: { turn_index, tab_index },
    obj: { type: PacketType.SECTION_END },
  };
  const existingGroup = state.groupedPacketsMap.get(groupKey);
  if (existingGroup) {
    existingGroup.push(syntheticPacket);
  }
  state.groupKeysWithSectionEnd.add(groupKey);
}
/**
 * Content packet types that indicate a group has meaningful content to display.
 * Groups whose packets never include one of these "start" types are filtered
 * out when building result arrays (see buildGroupsFromKeys).
 */
const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
  PacketType.MESSAGE_START,
  PacketType.SEARCH_TOOL_START,
  PacketType.IMAGE_GENERATION_TOOL_START,
  PacketType.PYTHON_TOOL_START,
  PacketType.CUSTOM_TOOL_START,
  PacketType.FETCH_TOOL_START,
  PacketType.REASONING_START,
  PacketType.DEEP_RESEARCH_PLAN_START,
  PacketType.RESEARCH_AGENT_START,
]);
/** True when at least one packet in the group is a content-bearing type. */
function hasContentPackets(packets: Packet[]): boolean {
  return packets.some((packet) =>
    CONTENT_PACKET_TYPES_SET.has(packet.obj.type as PacketType)
  );
}
/**
* Extract tool name from a packet for unique tool tracking.
* Returns null for non-tool packets.
*/
function getToolNameFromPacket(packet: Packet): string | null {
switch (packet.obj.type) {
case PacketType.SEARCH_TOOL_START: {
const searchPacket = packet.obj as SearchToolStart;
return searchPacket.is_internet_search ? "Web Search" : "Internal Search";
}
case PacketType.PYTHON_TOOL_START:
return "Code Interpreter";
case PacketType.FETCH_TOOL_START:
return "Open URLs";
case PacketType.CUSTOM_TOOL_START: {
const customPacket = packet.obj as CustomToolStart;
return customPacket.tool_name || "Custom Tool";
}
case PacketType.IMAGE_GENERATION_TOOL_START:
return "Generate Image";
case PacketType.DEEP_RESEARCH_PLAN_START:
return "Generate plan";
case PacketType.RESEARCH_AGENT_START:
return "Research agent";
case PacketType.REASONING_START:
return "Thinking";
default:
return null;
}
}
/**
 * Packet types that indicate final answer content is coming
 * (message text, image-generation, and python tool start/delta packets).
 * Seeing any of these flips finalAnswerComing — see handleStreamingStatusPacket.
 */
const FINAL_ANSWER_PACKET_TYPES_SET = new Set<PacketType>([
  PacketType.MESSAGE_START,
  PacketType.MESSAGE_DELTA,
  PacketType.IMAGE_GENERATION_TOOL_START,
  PacketType.IMAGE_GENERATION_TOOL_DELTA,
  PacketType.PYTHON_TOOL_START,
  PacketType.PYTHON_TOOL_DELTA,
]);
// ============================================================================
// Packet Handlers
// ============================================================================
// Record how many parallel branches to expect for this packet's turn.
// (Metadata only — the branching packet itself is never added to a group.)
function handleTopLevelBranching(state: ProcessorState, packet: Packet): void {
  const { num_parallel_branches } = packet.obj as TopLevelBranching;
  state.expectedBranches.set(
    packet.placement.turn_index,
    num_parallel_branches
  );
}
/**
 * When a packet arrives for a brand-new turn_index, close out every group
 * from earlier turns by injecting SECTION_END. A new tab_index within the
 * same turn does NOT close anything — parallel branches stay open.
 */
function handleTurnTransition(state: ProcessorState, packet: Packet): void {
  const currentTurnIndex = packet.placement.turn_index;
  // Get all previous turn indices from seen group keys
  const previousTurnIndices = new Set(
    Array.from(state.seenGroupKeys).map((key) => parseToolKey(key).turn_index)
  );
  const isNewTurnIndex = !previousTurnIndices.has(currentTurnIndex);
  // If we see a new turn_index (not just tab_index), inject SECTION_END for previous groups
  if (isNewTurnIndex && state.seenGroupKeys.size > 0) {
    state.seenGroupKeys.forEach((prevGroupKey) => {
      // has() check is redundant with the guard inside injectSectionEnd, but harmless.
      if (!state.groupKeysWithSectionEnd.has(prevGroupKey)) {
        injectSectionEnd(state, prevGroupKey);
      }
    });
  }
}
/** Record a CITATION_INFO packet in both the citation map and the deduped list. */
function handleCitationPacket(state: ProcessorState, packet: Packet): void {
  if (packet.obj.type !== PacketType.CITATION_INFO) {
    return;
  }
  const citationInfo = packet.obj as CitationInfo;
  // Add to citation map immediately for rendering
  state.citationMap[citationInfo.citation_number] = citationInfo.document_id;
  // Also add to citations array for CitedSourcesToggle (deduplicated by
  // document_id — only the first citation_num seen per document is kept)
  if (!state.seenCitationDocIds.has(citationInfo.document_id)) {
    state.seenCitationDocIds.add(citationInfo.document_id);
    state.citations.push({
      citation_num: citationInfo.citation_number,
      document_id: citationInfo.document_id,
    });
  }
}
function handleDocumentPacket(state: ProcessorState, packet: Packet): void {
if (packet.obj.type === PacketType.SEARCH_TOOL_DOCUMENTS_DELTA) {
const docDelta = packet.obj as SearchToolDocumentsDelta;
if (docDelta.documents) {
for (const doc of docDelta.documents) {
if (doc.document_id) {
state.documentMap.set(doc.document_id, doc);
}
}
}
} else if (packet.obj.type === PacketType.FETCH_TOOL_DOCUMENTS) {
const fetchDocuments = packet.obj as FetchToolDocuments;
if (fetchDocuments.documents) {
for (const doc of fetchDocuments.documents) {
if (doc.document_id) {
state.documentMap.set(doc.document_id, doc);
}
}
}
}
}
/** Latch finalAnswerComing once a final-answer-content packet type appears. */
function handleStreamingStatusPacket(
  state: ProcessorState,
  packet: Packet
): void {
  // Check if final answer is coming
  if (FINAL_ANSWER_PACKET_TYPES_SET.has(packet.obj.type as PacketType)) {
    state.finalAnswerComing = true;
  }
}
/**
 * On the first STOP packet: latch stopPacketSeen, record the stop reason,
 * and close every still-open group with a synthetic SECTION_END.
 * Subsequent STOP packets are ignored.
 */
function handleStopPacket(state: ProcessorState, packet: Packet): void {
  if (packet.obj.type !== PacketType.STOP || state.stopPacketSeen) {
    return;
  }
  state.stopPacketSeen = true;
  // Extract and store the stop reason
  const stopPacket = packet.obj as Stop;
  state.stopReason = stopPacket.stop_reason;
  // Inject SECTION_END for all group keys that don't have one
  state.seenGroupKeys.forEach((groupKey) => {
    if (!state.groupKeysWithSectionEnd.has(groupKey)) {
      injectSectionEnd(state, groupKey);
    }
  });
}
/**
 * Clear finalAnswerComing when a real tool call starts after message content.
 * Handles the case where we get a Message packet from Claude, and then tool
 * calling packets. We use isActualToolCallPacket instead of isToolPacket
 * to exclude reasoning packets - reasoning is just the model thinking,
 * not an actual tool call that would produce new content.
 */
function handleToolAfterMessagePacket(
  state: ProcessorState,
  packet: Packet
): void {
  // Only while still streaming: a stop means the answer really was final.
  if (
    state.finalAnswerComing &&
    !state.stopPacketSeen &&
    isActualToolCallPacket(packet)
  ) {
    state.finalAnswerComing = false;
  }
}
// Append a packet to its group's list, creating the group on first sight.
function addPacketToGroup(
  state: ProcessorState,
  packet: Packet,
  groupKey: string
): void {
  const group = state.groupedPacketsMap.get(groupKey);
  if (group === undefined) {
    state.groupedPacketsMap.set(groupKey, [packet]);
  } else {
    group.push(packet);
  }
}
// ============================================================================
// Main Processing Function
// ============================================================================
/**
 * Route a single packet through all stateful handlers: grouping,
 * categorization, citations, documents, and stream-status flags.
 */
function processPacket(state: ProcessorState, packet: Packet): void {
  if (!packet) return;
  // Handle TopLevelBranching packets - these tell us how many parallel branches to expect
  if (packet.obj.type === PacketType.TOP_LEVEL_BRANCHING) {
    handleTopLevelBranching(state, packet);
    // Don't add this packet to any group, it's just metadata
    return;
  }
  // Handle turn transitions (inject SECTION_END for previous groups)
  handleTurnTransition(state, packet);
  // Track group key
  const groupKey = getGroupKey(packet);
  state.seenGroupKeys.add(groupKey);
  // Track SECTION_END and ERROR packets (both indicate completion)
  if (
    packet.obj.type === PacketType.SECTION_END ||
    packet.obj.type === PacketType.ERROR
  ) {
    state.groupKeysWithSectionEnd.add(groupKey);
  }
  // Check if this is the first packet in the group (before adding)
  const existingGroup = state.groupedPacketsMap.get(groupKey);
  const isFirstPacket = !existingGroup;
  // Add packet to group
  addPacketToGroup(state, packet, groupKey);
  // Categorize on first packet of each group — a group's kind (tool/display)
  // is decided by its opening packet only.
  if (isFirstPacket) {
    if (isToolPacket(packet, false)) {
      state.toolGroupKeys.add(groupKey);
      // Track unique tool name
      const toolName = getToolNameFromPacket(packet);
      if (toolName) {
        state.uniqueToolNames.add(toolName);
      }
    }
    if (isDisplayPacket(packet)) {
      state.displayGroupKeys.add(groupKey);
    }
  }
  // Handle specific packet types
  handleCitationPacket(state, packet);
  handleDocumentPacket(state, packet);
  handleStreamingStatusPacket(state, packet);
  handleStopPacket(state, packet);
  handleToolAfterMessagePacket(state, packet);
}
/**
 * Incrementally fold rawPackets into state: only packets past
 * lastProcessedIndex are handled, so repeated calls are cheap. Mutates and
 * returns `state` (a fresh state when the packet list shrank, which signals
 * that upstream replaced the stream).
 */
export function processPackets(
  state: ProcessorState,
  rawPackets: Packet[]
): ProcessorState {
  // Handle reset (packets array shrunk - upstream replaced with shorter list)
  if (state.lastProcessedIndex > rawPackets.length) {
    state = createInitialState(state.nodeId);
  }
  // Process only new packets
  for (let i = state.lastProcessedIndex; i < rawPackets.length; i++) {
    const packet = rawPackets[i];
    if (packet) {
      processPacket(state, packet);
    }
  }
  state.lastProcessedIndex = rawPackets.length;
  // Build result arrays after processing (buildGroupsFromKeys produces fresh
  // array references so React change detection sees the update)
  state.toolGroups = buildGroupsFromKeys(state, state.toolGroupKeys);
  state.potentialDisplayGroups = buildGroupsFromKeys(
    state,
    state.displayGroupKeys
  );
  state.uniqueToolNamesArray = Array.from(state.uniqueToolNames);
  return state;
}
/**
 * Build a GroupedPacket array from a set of group keys. Groups without
 * meaningful content packets are dropped; the result is ordered by
 * turn_index, then tab_index. Packet arrays are shallow-copied so React
 * detects changes by reference.
 */
function buildGroupsFromKeys(
  state: ProcessorState,
  keys: Set<string>
): GroupedPacket[] {
  const groups: GroupedPacket[] = [];
  for (const key of keys) {
    const packets = state.groupedPacketsMap.get(key);
    if (!packets || !hasContentPackets(packets)) {
      continue;
    }
    const { turn_index, tab_index } = parseToolKey(key);
    // Copy so the group's array reference is fresh on every rebuild.
    groups.push({ turn_index, tab_index, packets: [...packets] });
  }
  return groups.sort(
    (a, b) => a.turn_index - b.turn_index || a.tab_index - b.tab_index
  );
}

View File

@@ -0,0 +1,153 @@
import { useRef, useState, useMemo, useCallback } from "react";
import {
Packet,
StreamingCitation,
StopReason,
} from "@/app/chat/services/streamingModels";
import { CitationMap } from "@/app/chat/interfaces";
import { OnyxDocument } from "@/lib/search/interfaces";
import {
ProcessorState,
GroupedPacket,
createInitialState,
processPackets,
} from "@/app/chat/message/messageComponents/timeline/hooks/packetProcessor";
import {
transformPacketGroups,
groupStepsByTurn,
TurnGroup,
} from "@/app/chat/message/messageComponents/timeline/transformers";
export interface UsePacketProcessorResult {
// Data
toolGroups: GroupedPacket[];
displayGroups: GroupedPacket[];
toolTurnGroups: TurnGroup[];
citations: StreamingCitation[];
citationMap: CitationMap;
documentMap: Map<string, OnyxDocument>;
// Status (derived from packets)
stopPacketSeen: boolean;
stopReason: StopReason | undefined;
hasSteps: boolean;
expectedBranchesPerTurn: Map<number, number>;
uniqueToolNames: string[];
// Completion: stopPacketSeen && renderComplete
isComplete: boolean;
// Callbacks
onRenderComplete: () => void;
markAllToolsDisplayed: () => void;
}
/**
 * Hook for processing streaming packets in AgentMessage.
 *
 * Architecture:
 * - Processor state in ref: incremental processing, synchronous, no double render
 * - Only true UI state: renderComplete (set by callback), forceShowAnswer (override)
 * - Everything else derived from packets
 *
 * Key insight: finalAnswerComing and stopPacketSeen are DERIVED from packets,
 * not independent state. Only renderComplete needs useState.
 *
 * NOTE(review): this hook intentionally mutates a ref and calls setState
 * during render ("adjust state while rendering" pattern). That relies on
 * processPackets being side-effect-free and safe to re-run on a re-render —
 * confirm against packetProcessor.
 *
 * @param rawPackets full packet list received so far; expected to only grow,
 *   except on a stream reset (detected below when the list shrinks)
 * @param nodeId identity of the message node; changing it resets all state
 */
export function usePacketProcessor(
  rawPackets: Packet[],
  nodeId: number
): UsePacketProcessorResult {
  // Processor in ref: incremental, synchronous, no double render
  const stateRef = useRef<ProcessorState>(createInitialState(nodeId));
  // Only TRUE UI state: "has renderer finished?"
  const [renderComplete, setRenderComplete] = useState(false);
  // Optional override to force showing answer
  const [forceShowAnswer, setForceShowAnswer] = useState(false);
  // Reset on nodeId change. setState during render schedules an immediate
  // re-render before effects/commit, so the reset is visible synchronously.
  if (stateRef.current.nodeId !== nodeId) {
    stateRef.current = createInitialState(nodeId);
    setRenderComplete(false);
    setForceShowAnswer(false);
  }
  // Track for transition detection (captured BEFORE processing new packets)
  const prevLastProcessed = stateRef.current.lastProcessedIndex;
  const prevFinalAnswerComing = stateRef.current.finalAnswerComing;
  // Detect stream reset (packets shrunk): a shorter list than what we already
  // processed means the stream restarted, so start over from a clean state.
  if (prevLastProcessed > rawPackets.length) {
    stateRef.current = createInitialState(nodeId);
    setRenderComplete(false);
    setForceShowAnswer(false);
  }
  // Process packets synchronously (incremental) - only if new packets arrived
  if (rawPackets.length > stateRef.current.lastProcessedIndex) {
    stateRef.current = processPackets(stateRef.current, rawPackets);
  }
  // Reset renderComplete on tool-after-message transition: finalAnswerComing
  // flipped true -> false, meaning a new tool phase started after an answer.
  if (prevFinalAnswerComing && !stateRef.current.finalAnswerComing) {
    setRenderComplete(false);
  }
  // Access state directly (result arrays are built in processPackets)
  const state = stateRef.current;
  // Derive displayGroups (not state!)
  const effectiveFinalAnswerComing = state.finalAnswerComing || forceShowAnswer;
  // Memo deps rely on processPackets producing new object/array references
  // whenever content changes — confirm that invariant holds in packetProcessor.
  const displayGroups = useMemo(() => {
    if (effectiveFinalAnswerComing || state.toolGroups.length === 0) {
      return state.potentialDisplayGroups;
    }
    return [];
  }, [
    effectiveFinalAnswerComing,
    state.toolGroups.length,
    state.potentialDisplayGroups,
  ]);
  // Transform toolGroups to timeline format
  const toolTurnGroups = useMemo(() => {
    const allSteps = transformPacketGroups(state.toolGroups);
    return groupStepsByTurn(allSteps);
  }, [state.toolGroups]);
  // Callback reads from ref: always current value, no ref needed in component
  const onRenderComplete = useCallback(() => {
    if (stateRef.current.finalAnswerComing) {
      setRenderComplete(true);
    }
  }, []);
  // Override: caller declares all tool output displayed, so show the answer.
  const markAllToolsDisplayed = useCallback(() => {
    setForceShowAnswer(true);
  }, []);
  return {
    // Data
    toolGroups: state.toolGroups,
    displayGroups,
    toolTurnGroups,
    citations: state.citations,
    citationMap: state.citationMap,
    documentMap: state.documentMap,
    // Status (derived from packets)
    stopPacketSeen: state.stopPacketSeen,
    stopReason: state.stopReason,
    hasSteps: toolTurnGroups.length > 0,
    expectedBranchesPerTurn: state.expectedBranches,
    uniqueToolNames: state.uniqueToolNamesArray,
    // Completion: stopPacketSeen && renderComplete
    isComplete: state.stopPacketSeen && renderComplete,
    // Callbacks
    onRenderComplete,
    markAllToolsDisplayed,
  };
}

View File

@@ -0,0 +1,50 @@
import { useState, useEffect, useCallback } from "react";
import { TurnGroup } from "../transformers";
/**
 * Expansion/tab state returned by {@link useTimelineExpansion}.
 */
export interface TimelineExpansionState {
  /** Whether the timeline is currently expanded. */
  isExpanded: boolean;
  /** Toggles {@link isExpanded}. */
  handleToggle: () => void;
  /** Key of the selected step tab within a parallel turn ("" when unset). */
  parallelActiveTab: string;
  /** Manually select a parallel step tab. */
  setParallelActiveTab: (tab: string) => void;
}
/**
 * Manages expansion state for the timeline.
 * Auto-collapses once streaming completes and keeps the parallel tab
 * selection pointed at a valid step of the last turn group.
 */
export function useTimelineExpansion(
  stopPacketSeen: boolean,
  lastTurnGroup: TurnGroup | undefined
): TimelineExpansionState {
  // Start expanded while the stream is still running.
  const [isExpanded, setIsExpanded] = useState(!stopPacketSeen);
  const [parallelActiveTab, setParallelActiveTab] = useState<string>("");

  const handleToggle = useCallback(() => {
    setIsExpanded((wasExpanded) => !wasExpanded);
  }, []);

  // Collapse automatically when the stream signals completion.
  useEffect(() => {
    if (stopPacketSeen) {
      setIsExpanded(false);
    }
  }, [stopPacketSeen]);

  // When the last turn group is parallel and the current selection is not one
  // of its step keys, fall back to the first step's key.
  useEffect(() => {
    if (!lastTurnGroup?.isParallel || lastTurnGroup.steps.length === 0) {
      return;
    }
    const [firstStep] = lastTurnGroup.steps;
    const selectionIsValid = lastTurnGroup.steps.some(
      (step) => step.key === parallelActiveTab
    );
    if (firstStep && !selectionIsValid) {
      setParallelActiveTab(firstStep.key);
    }
  }, [lastTurnGroup, parallelActiveTab]);

  return {
    isExpanded,
    handleToggle,
    parallelActiveTab,
    setParallelActiveTab,
  };
}

View File

@@ -0,0 +1,97 @@
import { useMemo } from "react";
import { TurnGroup } from "../transformers";
import {
PacketType,
SearchToolPacket,
StopReason,
CustomToolStart,
} from "@/app/chat/services/streamingModels";
import { constructCurrentSearchState } from "@/app/chat/message/messageComponents/timeline/renderers/search/searchStateUtils";
/**
 * Header state computed by {@link useTimelineHeader}.
 */
export interface TimelineHeaderResult {
  /** Text shown in the timeline header (e.g. "Searching web"). */
  headerText: string;
  /** True when at least one turn group exists. */
  hasPackets: boolean;
  /** True when the stream ended because the user cancelled it. */
  userStopped: boolean;
}
/**
 * Hook that determines timeline header state based on current activity.
 * Returns header text, whether there are packets, and whether user stopped.
 *
 * The header is derived from the first packet of the first step of the most
 * recent turn group; unknown or missing packet types fall back to
 * "Thinking...".
 */
export function useTimelineHeader(
  turnGroups: TurnGroup[],
  stopReason?: StopReason
): TimelineHeaderResult {
  return useMemo(() => {
    const hasPackets = turnGroups.length > 0;
    const userStopped = stopReason === StopReason.USER_CANCELLED;
    const fallback: TimelineHeaderResult = {
      headerText: "Thinking...",
      hasPackets,
      userStopped,
    };

    // Current activity = first step of the last (most recent) turn group.
    const currentTurn = turnGroups[turnGroups.length - 1];
    const currentStep = currentTurn?.steps[0];
    if (!currentStep?.packets?.length) {
      return fallback;
    }
    const firstPacket = currentStep.packets[0];
    if (!firstPacket) {
      return fallback;
    }
    const packetType = firstPacket.obj.type;

    // Search header depends on whether the search is web or internal.
    if (packetType === PacketType.SEARCH_TOOL_START) {
      const searchState = constructCurrentSearchState(
        currentStep.packets as SearchToolPacket[]
      );
      return {
        headerText: searchState.isInternetSearch
          ? "Searching web"
          : "Searching internal documents",
        hasPackets,
        userStopped,
      };
    }

    // Custom tools embed the tool name when available.
    if (packetType === PacketType.CUSTOM_TOOL_START) {
      const toolName = (firstPacket.obj as CustomToolStart).tool_name;
      return {
        headerText: toolName ? `Executing ${toolName}` : "Executing tool",
        hasPackets,
        userStopped,
      };
    }

    // Remaining packet types map to fixed header strings.
    const staticHeaders: Partial<Record<PacketType, string>> = {
      [PacketType.FETCH_TOOL_START]: "Opening URLs",
      [PacketType.PYTHON_TOOL_START]: "Executing code",
      [PacketType.IMAGE_GENERATION_TOOL_START]: "Generating images",
      [PacketType.REASONING_START]: "Thinking",
      [PacketType.DEEP_RESEARCH_PLAN_START]: "Generating plan",
      [PacketType.RESEARCH_AGENT_START]: "Researching",
    };
    return {
      headerText: staticHeaders[packetType] ?? "Thinking...",
      hasPackets,
      userStopped,
    };
  }, [turnGroups, stopReason]);
}

Some files were not shown because too many files have changed in this diff Show More