mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-27 10:32:41 +00:00
Compare commits
40 Commits
action-blo
...
v2.11.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afc163d7e5 | ||
|
|
0f214ca190 | ||
|
|
422ca91edc | ||
|
|
fe287eebb6 | ||
|
|
3f8ef8b465 | ||
|
|
ed46504a1a | ||
|
|
7a24b34516 | ||
|
|
7a7ffa9051 | ||
|
|
3053ab518c | ||
|
|
be38d3500f | ||
|
|
753a3bc093 | ||
|
|
2ba8fafe78 | ||
|
|
b77b580ebd | ||
|
|
3eee98b932 | ||
|
|
a97eb02fef | ||
|
|
c5061495a2 | ||
|
|
c20b0789ae | ||
|
|
d99848717b | ||
|
|
aaca55c415 | ||
|
|
9d7ffd1e4a | ||
|
|
a249161827 | ||
|
|
e126346a91 | ||
|
|
a96682fa73 | ||
|
|
3920371d56 | ||
|
|
e5a257345c | ||
|
|
a49df511e2 | ||
|
|
d5d2a8a1a6 | ||
|
|
b2f46b264c | ||
|
|
c6ad363fbd | ||
|
|
e313119f9a | ||
|
|
3a2a542a03 | ||
|
|
413aeba4a1 | ||
|
|
46028aa2bb | ||
|
|
454943c4a6 | ||
|
|
87946266de | ||
|
|
144030c5ca | ||
|
|
a557d76041 | ||
|
|
605e808158 | ||
|
|
8fec88c90d | ||
|
|
e54969a693 |
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
21
.vscode/launch.json
vendored
21
.vscode/launch.json
vendored
@@ -605,6 +605,27 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
|
||||
@@ -42,7 +42,9 @@ RUN apt-get update && \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
vim \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -124,6 +126,13 @@ ENV PYTHONPATH=/app
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
|
||||
# in long-running Python processes (API server, Celery workers).
|
||||
# The soname is architecture-independent; the dynamic linker resolves
|
||||
# the correct path from standard library directories.
|
||||
# Placed after all RUN steps so build-time processes are unaffected.
|
||||
ENV LD_PRELOAD=libjemalloc.so.2
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -12,6 +12,7 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
DOCUMENT_SYNC_PREFIX = "documentsync"
|
||||
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
|
||||
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
|
||||
r.delete(DOCUMENT_SYNC_FENCE_KEY)
|
||||
return
|
||||
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
|
||||
|
||||
# Add to the tracking taskset in Redis BEFORE creating the celery task
|
||||
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
|
||||
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
|
||||
|
||||
# Create the Celery task
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -153,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -341,6 +352,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
@@ -423,6 +435,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_to_metadata_value(value: Any) -> str | list[str]:
|
||||
"""Convert a Salesforce field value to a valid metadata value.
|
||||
|
||||
Document metadata expects str | list[str], but Salesforce returns
|
||||
various types (bool, float, int, etc.). This function ensures all
|
||||
values are properly converted to strings.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
return str(value)
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
# # gc.collect()
|
||||
# return all_types
|
||||
|
||||
def _yield_doc_batches(
|
||||
self,
|
||||
sf_db: OnyxSalesforceSQLite,
|
||||
type_to_processed: dict[str, int],
|
||||
changed_ids_to_type: dict[str, str],
|
||||
parent_types: set[str],
|
||||
increment_parents_changed: Callable[[], None],
|
||||
) -> GenerateDocumentsOutput:
|
||||
""" """
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
parent_object.data[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
increment_parents_changed()
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def _full_sync(
|
||||
self,
|
||||
temp_dir: str,
|
||||
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
examined_ids = 0
|
||||
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
f"records={num_records}"
|
||||
)
|
||||
|
||||
# yield an empty list to keep the connector alive
|
||||
yield docs_to_yield
|
||||
|
||||
new_ids = sf_db.update_from_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=ctx.parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = (
|
||||
type_to_processed.get(parent_type, 0) + 1
|
||||
)
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
def increment_parents_changed() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
yield from self._yield_doc_batches(
|
||||
sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
ctx.parent_types,
|
||||
increment_parents_changed,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
|
||||
if sf_attribute in record:
|
||||
doc.metadata[canonical_attribute] = record[sf_attribute]
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
record[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
return return_context
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if MULTI_TENANT:
|
||||
# if multi tenant, we cannot expect the sqlite db to be cached/present
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._full_sync(temp_dir)
|
||||
|
||||
# nuke the db since we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
|
||||
os.remove(sqlite_db_path)
|
||||
return self._full_sync(BASE_DATA_PATH)
|
||||
# Always use a temp directory for SQLite - the database is rebuilt
|
||||
# from scratch each time via CSV downloads, so there's no caching benefit
|
||||
# from persisting it. Using temp dirs also avoids collisions between
|
||||
# multiple CC pairs and eliminates stale WAL/SHM file issues.
|
||||
# TODO(evan): make this thing checkpointed and persist/load db from filestore
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield from self._full_sync(temp_dir)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source will synchronize updated parent objects one by one."""
|
||||
|
||||
if start == 0:
|
||||
# nuke the db if we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(
|
||||
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
|
||||
)
|
||||
os.remove(sqlite_db_path)
|
||||
|
||||
return self._delta_sync(BASE_DATA_PATH, start, end)
|
||||
|
||||
# Always use a temp directory - see comment in load_from_state()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
yield from self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
SQLITE_DISK_IO_ERROR = "disk I/O error"
|
||||
|
||||
|
||||
class OnyxSalesforceSQLite:
|
||||
"""Notes on context management using 'with self.conn':
|
||||
|
||||
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
|
||||
def apply_schema(self) -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist.
|
||||
|
||||
Non-destructive operation.
|
||||
Non-destructive operation. If a disk I/O error is encountered (often due
|
||||
to stale WAL/SHM files from a previous crash), this method will attempt
|
||||
to recover by removing the corrupted files and recreating the database.
|
||||
"""
|
||||
try:
|
||||
self._apply_schema_impl()
|
||||
except sqlite3.OperationalError as e:
|
||||
if SQLITE_DISK_IO_ERROR not in str(e):
|
||||
raise
|
||||
|
||||
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
|
||||
self._recover_from_corruption()
|
||||
self._apply_schema_impl()
|
||||
|
||||
def _recover_from_corruption(self) -> None:
|
||||
"""Recover from SQLite corruption by removing all database files and reconnecting."""
|
||||
logger.info(f"Removing corrupted SQLite files: {self.filename}")
|
||||
|
||||
# Close existing connection
|
||||
self.close()
|
||||
|
||||
# Remove all SQLite files (main db, WAL, SHM)
|
||||
remove_sqlite_db_files(self.filename)
|
||||
|
||||
# Reconnect - this will create a fresh database
|
||||
self.connect()
|
||||
|
||||
logger.info("SQLite recovery complete, fresh database created")
|
||||
|
||||
def _apply_schema_impl(self) -> None:
|
||||
"""Internal implementation of apply_schema."""
|
||||
if self._conn is None:
|
||||
raise RuntimeError("Database connection is closed")
|
||||
|
||||
|
||||
@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
|
||||
return os.path.join(directory, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def remove_sqlite_db_files(db_path: str) -> None:
|
||||
"""Remove SQLite database and all associated files (WAL, SHM).
|
||||
|
||||
SQLite in WAL mode creates additional files:
|
||||
- .sqlite-wal: Write-ahead log
|
||||
- .sqlite-shm: Shared memory file
|
||||
|
||||
If these files become stale (e.g., after a crash), they can cause
|
||||
'disk I/O error' when trying to open the database. This function
|
||||
ensures all related files are removed.
|
||||
"""
|
||||
files_to_remove = [
|
||||
db_path,
|
||||
f"{db_path}-wal",
|
||||
f"{db_path}-shm",
|
||||
]
|
||||
for file_path in files_to_remove:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
# NOTE: only used with shelves, deprecated at this point
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
|
||||
@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
@@ -327,6 +329,15 @@ def delete_user_from_db(
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
# Null out ownership on document sets and personas so they're
|
||||
# preserved for other users instead of being cascade-deleted
|
||||
db_session.query(DocumentSet).filter(
|
||||
DocumentSet.user_id == user_to_delete.id
|
||||
).update({DocumentSet.user_id: None})
|
||||
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
|
||||
@@ -685,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -694,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,65 +1,270 @@
|
||||
from mistune import Markdown # type: ignore[import-untyped]
|
||||
from mistune import Renderer
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
return Markdown(renderer=SlackRenderer()).render(message)
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(Renderer):
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def header(self, text: str, level: int, raw: str | None = None) -> str:
|
||||
return f"*{text}*\n"
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
|
||||
def double_emphasis(self, text: str) -> str:
|
||||
def strong(self, text: str) -> str:
|
||||
return f"*{text}*"
|
||||
|
||||
def strikethrough(self, text: str) -> str:
|
||||
return f"~{text}~"
|
||||
|
||||
def list(self, body: str, ordered: bool = True) -> str:
|
||||
lines = body.split("\n")
|
||||
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
|
||||
lines = text.split("\n")
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("li: "):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
|
||||
def link(self, link: str, title: str | None, content: str | None) -> str:
|
||||
escaped_link = self.escape_special(link)
|
||||
if content:
|
||||
return f"<{escaped_link}|{content}>"
|
||||
def link(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
if text:
|
||||
return f"<{escaped_url}|{text}>"
|
||||
if title:
|
||||
return f"<{escaped_link}|{title}>"
|
||||
return f"<{escaped_link}>"
|
||||
return f"<{escaped_url}|{title}>"
|
||||
return f"<{escaped_url}>"
|
||||
|
||||
def image(self, src: str, title: str | None, text: str | None) -> str:
|
||||
escaped_src = self.escape_special(src)
|
||||
def image(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
display_text = title or text
|
||||
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
|
||||
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
|
||||
|
||||
def codespan(self, text: str) -> str:
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, text: str, lang: str | None) -> str:
|
||||
return f"```\n{text}\n```\n"
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
|
||||
def autolink(self, link: str, is_email: bool) -> str:
|
||||
return link if is_email else self.link(link, None, None)
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -32,6 +32,7 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -45,6 +45,7 @@ class RedisConnectorPrune:
|
||||
) # connectorpruning_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
|
||||
@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -84,7 +84,8 @@ def patch_document_set(
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
object_is_public=document_set_update_request.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
try:
|
||||
update_document_set(
|
||||
@@ -125,7 +126,8 @@ def delete_document_set(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
object_is_public=document_set.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
@router.get("/", tags=PUBLIC_API_TAGS)
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import user_needs_to_be_verified
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
|
||||
|
||||
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
|
||||
async def get_auth_type() -> AuthTypeResponse:
|
||||
user_count = await get_user_count()
|
||||
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
|
||||
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
|
||||
# associated with until they auth.
|
||||
has_users = True
|
||||
if AUTH_TYPE != AuthType.CLOUD:
|
||||
user_count = await get_user_count()
|
||||
has_users = user_count > 0
|
||||
|
||||
return AuthTypeResponse(
|
||||
auth_type=AUTH_TYPE,
|
||||
requires_verification=user_needs_to_be_verified(),
|
||||
anonymous_user_enabled=anonymous_user_enabled(),
|
||||
password_min_length=PASSWORD_MIN_LENGTH,
|
||||
has_users=user_count > 0,
|
||||
has_users=has_users,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import remove_chat_message_feedback
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -266,7 +267,35 @@ def get_chat_session(
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
try:
|
||||
# If we failed to get a chat session, try to retrieve the session with
|
||||
# less restrictive filters in order to identify what exactly mismatched
|
||||
# so we can bubble up an accurate error code andmessage.
|
||||
existing_chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
is_shared=False,
|
||||
include_deleted=True,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
if not include_deleted and existing_chat_session.deleted:
|
||||
raise HTTPException(status_code=404, detail="Chat session has been deleted")
|
||||
|
||||
if is_shared:
|
||||
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Chat session is not shared"
|
||||
)
|
||||
elif user_id is not None and existing_chat_session.user_id not in (
|
||||
user_id,
|
||||
None,
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
# for chat-seeding: if the session is unassigned, assign it now. This is done here
|
||||
# to avoid another back and forth between FE -> BE before starting the first
|
||||
|
||||
@@ -580,7 +580,7 @@ def translate_assistant_message_to_packets(
|
||||
# Determine stop reason - check if message indicates user cancelled
|
||||
stop_reason: str | None = None
|
||||
if chat_message.message:
|
||||
if "Generation was stopped" in chat_message.message:
|
||||
if "generation was stopped" in chat_message.message.lower():
|
||||
stop_reason = "user_cancelled"
|
||||
|
||||
# Add overall stop packet at the end
|
||||
|
||||
@@ -573,7 +573,7 @@ mcp==1.25.0
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==0.8.4
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
|
||||
@@ -298,7 +298,7 @@ numpy==2.4.1
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.4.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
External dependency unit tests for user file processing queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_user_file_processing work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in PROCESSING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that process_single_user_file clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_user_file_processing,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in PROCESSING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.PROCESSING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "bp_user")
|
||||
_create_processing_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(
|
||||
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
|
||||
),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestPerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "guard_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "guard_set_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
|
||||
f"Guard key TTL {ttl}s is outside the expected range "
|
||||
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "expires_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires")
|
||||
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), (
|
||||
"Task must be submitted with the correct expires value to prevent "
|
||||
"stale task accumulation"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestWorkerClearsGuardKey:
|
||||
"""process_single_user_file removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so process_single_user_file returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file processing lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_lock_key(user_file_id)
|
||||
processing_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = processing_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the processing lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if processing_lock.owned():
|
||||
processing_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -476,8 +476,8 @@ class ChatSessionManager:
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
|
||||
@@ -31,7 +31,7 @@ class ProjectManager:
|
||||
) -> List[UserProjectSnapshot]:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -56,7 +56,7 @@ class ProjectManager:
|
||||
) -> bool:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def second_user(admin_user: DATestUser) -> DATestUser:
|
||||
# Ensure admin exists so this new user is created with BASIC role.
|
||||
try:
|
||||
return UserManager.create(name="second_basic_user")
|
||||
except HTTPError as e:
|
||||
response = e.response
|
||||
if response is None:
|
||||
raise
|
||||
if response.status_code not in (400, 409):
|
||||
raise
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
raise
|
||||
detail = payload.get("detail")
|
||||
if not _is_user_already_exists_detail(detail):
|
||||
raise
|
||||
print("Second basic user already exists; logging in instead.")
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("second_basic_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_user_already_exists_detail(detail: object) -> bool:
|
||||
if isinstance(detail, str):
|
||||
normalized = detail.lower()
|
||||
return (
|
||||
"already exists" in normalized
|
||||
or "register_user_already_exists" in normalized
|
||||
)
|
||||
if isinstance(detail, dict):
|
||||
code = detail.get("code")
|
||||
if isinstance(code, str) and code.lower() == "register_user_already_exists":
|
||||
return True
|
||||
message = detail.get("message")
|
||||
if isinstance(message, str) and "already exists" in message.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_chat_session(
|
||||
chat_session_id: str,
|
||||
user: DATestUser,
|
||||
is_shared: bool | None = None,
|
||||
include_deleted: bool | None = None,
|
||||
) -> requests.Response:
|
||||
params: dict[str, str] = {}
|
||||
if is_shared is not None:
|
||||
params["is_shared"] = str(is_shared).lower()
|
||||
if include_deleted is not None:
|
||||
params["include_deleted"] = str(include_deleted).lower()
|
||||
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
|
||||
params=params,
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def _set_sharing_status(
|
||||
chat_session_id: str, sharing_status: str, user: DATestUser
|
||||
) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}",
|
||||
json={"sharing_status": sharing_status},
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_private_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify private sessions are only accessible by the owner and never via share link."""
|
||||
# Create a private chat session owned by basic_user.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can access the private session normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Share link should be forbidden when the session is private.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users cannot access private sessions directly.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users also cannot access private sessions via share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_public_shared_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify shared sessions are accessible only via share link for non-owners."""
|
||||
# Create a private session, then mark it public.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
response = _set_sharing_status(str(chat_session.id), "public", basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can access normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can also access via share link.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Non-owner cannot access without share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Non-owner can access with share link for public sessions.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_deleted_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify deleted sessions return 404, with include_deleted gated by access checks."""
|
||||
# Create and soft-delete a session.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=chat_session, user_performing_action=basic_user
|
||||
)
|
||||
assert deletion_success is True
|
||||
|
||||
# Deleted sessions are not accessible normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Owner can fetch deleted session only with include_deleted.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("deleted") is True
|
||||
|
||||
# Non-owner should be blocked even with include_deleted.
|
||||
response = _get_chat_session(
|
||||
str(chat_session.id), second_user, include_deleted=True
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaLabelManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
reset: None, admin_user: DATestUser
|
||||
) -> None:
|
||||
persona_label = PersonaLabelManager.create(
|
||||
label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert persona_label.id is not None
|
||||
persona = PersonaManager.create(
|
||||
label_ids=[persona_label.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
updated_description = f"{persona.description}-updated"
|
||||
update_request = PersonaUpsertRequest(
|
||||
name=persona.name,
|
||||
description=updated_description,
|
||||
system_prompt=persona.system_prompt or "",
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
users=[],
|
||||
groups=[],
|
||||
label_ids=None,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=update_request.model_dump(mode="json", exclude_none=False),
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
fetched = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
fetched.raise_for_status()
|
||||
fetched_persona = fetched.json()
|
||||
|
||||
assert fetched_persona["description"] == updated_description
|
||||
fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
|
||||
assert persona_label.id in fetched_label_ids
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for Asana connector configuration parsing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_ids,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 123 ", ["123"]),
|
||||
(" 123 , , 456 , ", ["123", "456"]),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_project_ids_normalization(
|
||||
project_ids: str | None, expected: list[str] | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id=" 1153293530468850 ",
|
||||
asana_project_ids=project_ids,
|
||||
asana_team_id=" 1210918501948021 ",
|
||||
)
|
||||
|
||||
assert connector.workspace_id == "1153293530468850"
|
||||
assert connector.project_ids_to_index == expected
|
||||
assert connector.asana_team_id == "1210918501948021"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 1210918501948021 ", "1210918501948021"),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_team_id_normalization(
|
||||
team_id: str | None, expected: str | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id="1153293530468850",
|
||||
asana_project_ids=None,
|
||||
asana_team_id=team_id,
|
||||
)
|
||||
|
||||
assert connector.asana_team_id == expected
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
|
||||
class TestConvertToMetadataValue:
|
||||
"""Tests for the _convert_to_metadata_value helper function."""
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""String values should be returned as-is."""
|
||||
assert _convert_to_metadata_value("hello") == "hello"
|
||||
assert _convert_to_metadata_value("") == ""
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
"""Boolean True should be converted to string 'True'."""
|
||||
assert _convert_to_metadata_value(True) == "True"
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
"""Boolean False should be converted to string 'False'."""
|
||||
assert _convert_to_metadata_value(False) == "False"
|
||||
|
||||
def test_integer_value(self) -> None:
|
||||
"""Integer values should be converted to string."""
|
||||
assert _convert_to_metadata_value(42) == "42"
|
||||
assert _convert_to_metadata_value(0) == "0"
|
||||
assert _convert_to_metadata_value(-100) == "-100"
|
||||
|
||||
def test_float_value(self) -> None:
|
||||
"""Float values should be converted to string."""
|
||||
assert _convert_to_metadata_value(3.14) == "3.14"
|
||||
assert _convert_to_metadata_value(0.0) == "0.0"
|
||||
assert _convert_to_metadata_value(-2.5) == "-2.5"
|
||||
|
||||
def test_list_of_strings(self) -> None:
|
||||
"""List of strings should remain as list of strings."""
|
||||
result = _convert_to_metadata_value(["a", "b", "c"])
|
||||
assert result == ["a", "b", "c"]
|
||||
|
||||
def test_list_of_mixed_types(self) -> None:
|
||||
"""List with mixed types should have all items converted to strings."""
|
||||
result = _convert_to_metadata_value([1, True, 3.14, "text"])
|
||||
assert result == ["1", "True", "3.14", "text"]
|
||||
|
||||
def test_empty_list(self) -> None:
|
||||
"""Empty list should return empty list."""
|
||||
assert _convert_to_metadata_value([]) == []
|
||||
|
||||
|
||||
class TestYieldDocBatches:
|
||||
"""Tests for the _yield_doc_batches method of SalesforceConnector."""
|
||||
|
||||
@pytest.fixture
|
||||
def connector(self) -> SalesforceConnector:
|
||||
"""Create a SalesforceConnector instance with mocked sf_client."""
|
||||
connector = SalesforceConnector(
|
||||
batch_size=10,
|
||||
requested_objects=["Opportunity"],
|
||||
)
|
||||
# Mock the sf_client property
|
||||
mock_sf_client = MagicMock()
|
||||
mock_sf_client.sf_instance = "test.salesforce.com"
|
||||
connector._sf_client = mock_sf_client
|
||||
return connector
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sf_db(self) -> MagicMock:
|
||||
"""Create a mock OnyxSalesforceSQLite object."""
|
||||
return MagicMock()
|
||||
|
||||
def _create_salesforce_object(
|
||||
self,
|
||||
object_id: str,
|
||||
object_type: str,
|
||||
data: dict[str, Any],
|
||||
) -> SalesforceObject:
|
||||
"""Helper to create a SalesforceObject with required fields."""
|
||||
# Ensure required fields are present
|
||||
data.setdefault(ID_FIELD, object_id)
|
||||
data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
|
||||
data.setdefault(NAME_FIELD, f"Test {object_type}")
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_metadata_type_conversion_for_opportunity(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that Opportunity metadata fields are properly type-converted."""
|
||||
parent_id = "006bm000006kyDpAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create a parent object with various data types in the fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Test Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"Account": "Acme Corp", # string - should become "account" metadata
|
||||
"FiscalQuarter": 2, # int - should be converted to "2"
|
||||
"FiscalYear": 2024, # int - should be converted to "2024"
|
||||
"IsClosed": False, # bool - should be converted to "False"
|
||||
"StageName": "Prospecting", # string
|
||||
"Type": "New Business", # string
|
||||
"Amount": 50000.50, # float - should be converted to "50000.50"
|
||||
"CloseDate": "2024-06-30", # string
|
||||
"Probability": 75, # int - should be converted to "75"
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z", # string
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
# Setup mock sf_db
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create a mock document that convert_sf_object_to_doc will return
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Test Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
# Track parent changes
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# Call _yield_doc_batches
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify we got one batch with one document
|
||||
assert len(batches) == 1
|
||||
docs = batches[0]
|
||||
assert len(docs) == 1
|
||||
|
||||
doc = docs[0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify metadata type conversions
|
||||
# All values should be strings (or list of strings)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["account"] == "Acme Corp" # string stays string
|
||||
assert doc.metadata["fiscal_quarter"] == "2" # int -> str
|
||||
assert doc.metadata["fiscal_year"] == "2024" # int -> str
|
||||
assert doc.metadata["is_closed"] == "False" # bool -> str
|
||||
assert doc.metadata["stage_name"] == "Prospecting" # string stays string
|
||||
assert doc.metadata["type"] == "New Business" # string stays string
|
||||
assert (
|
||||
doc.metadata["amount"] == "50000.5"
|
||||
) # float -> str (Python drops trailing zeros)
|
||||
assert doc.metadata["close_date"] == "2024-06-30" # string stays string
|
||||
assert doc.metadata["probability"] == "75" # int -> str
|
||||
assert doc.metadata["name"] == "Test Opportunity" # NAME_FIELD
|
||||
|
||||
# Verify parent was counted
|
||||
assert parents_changed == 1
|
||||
assert type_to_processed[parent_type] == 1
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_missing_optional_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing optional metadata fields are not added."""
|
||||
parent_id = "006bm000006kyDqAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create parent object with only some fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Minimal Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"StageName": "Closed Won",
|
||||
# Notably missing: Amount, Probability, FiscalQuarter, etc.
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Minimal Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only present fields should be in metadata
|
||||
assert "stage_name" in doc.metadata
|
||||
assert doc.metadata["stage_name"] == "Closed Won"
|
||||
assert "name" in doc.metadata
|
||||
assert doc.metadata["name"] == "Minimal Opportunity"
|
||||
|
||||
# Missing fields should not be in metadata
|
||||
assert "amount" not in doc.metadata
|
||||
assert "probability" not in doc.metadata
|
||||
assert "fiscal_quarter" not in doc.metadata
|
||||
assert "fiscal_year" not in doc.metadata
|
||||
assert "is_closed" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_contact_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test metadata conversion for Contact object type."""
|
||||
parent_id = "003bm00000EjHCjAAN"
|
||||
parent_type = "Contact"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "John Doe",
|
||||
MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
|
||||
"Account": "Globex Corp",
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z",
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="John Doe",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify Contact-specific metadata
|
||||
assert doc.metadata["object_type"] == "Contact"
|
||||
assert doc.metadata["account"] == "Globex Corp"
|
||||
assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
|
||||
assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_no_default_attributes_for_unknown_type(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that unknown object types only get object_type metadata."""
|
||||
parent_id = "001bm00000fd9Z3AAI"
|
||||
parent_type = "CustomObject__c"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Custom Record",
|
||||
MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
|
||||
"CustomField__c": "custom value",
|
||||
"NumberField__c": 123,
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Custom Record",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only object_type should be set for unknown types
|
||||
assert doc.metadata["object_type"] == "CustomObject__c"
|
||||
# Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
|
||||
assert "CustomField__c" not in doc.metadata
|
||||
assert "NumberField__c" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_skips_missing_parent_objects(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing parent objects are skipped gracefully."""
|
||||
parent_id = "006bm000006kyDrAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# get_record returns None for missing object
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = None
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Should yield one empty batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 0
|
||||
|
||||
# convert_sf_object_to_doc should not have been called
|
||||
mock_convert.assert_not_called()
|
||||
|
||||
# Parents changed should still be 0
|
||||
assert parents_changed == 0
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_multiple_documents_batching(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that multiple documents are correctly batched."""
|
||||
# Create 3 parent objects
|
||||
parent_ids = [
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
]
|
||||
parent_type = "Opportunity"
|
||||
|
||||
parent_objects = [
|
||||
self._create_salesforce_object(
|
||||
pid,
|
||||
parent_type,
|
||||
{
|
||||
ID_FIELD: pid,
|
||||
NAME_FIELD: f"Opportunity {i}",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"IsClosed": i % 2 == 0, # alternating bool values
|
||||
"Amount": 1000.0 * (i + 1),
|
||||
},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
|
||||
# Setup mock to return all three
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
|
||||
)
|
||||
mock_sf_db.get_record.side_effect = parent_objects
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create mock documents
|
||||
mock_docs = [
|
||||
Document(
|
||||
id=f"SALESFORCE_{pid}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=f"Opportunity {i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
mock_convert.side_effect = mock_docs
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {pid: parent_type for pid in parent_ids}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
# With batch_size=10, all 3 docs should be in one batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 3
|
||||
|
||||
# Verify each document has correct metadata
|
||||
for i, doc in enumerate(batches[0]):
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["is_closed"] == str(i % 2 == 0)
|
||||
assert doc.metadata["amount"] == str(1000.0 * (i + 1))
|
||||
|
||||
assert type_to_processed[parent_type] == 3
|
||||
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.users import delete_user_from_db
|
||||
|
||||
|
||||
def _mock_user(
|
||||
user_id: UUID | None = None, email: str = "test@example.com"
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.oauth_accounts = []
|
||||
return user
|
||||
|
||||
|
||||
def _make_query_chain() -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
return chain
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_nulls_out_document_set_ownership(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Verify DocumentSet.user_id is nulled out (update, not delete)
|
||||
doc_set_chain = query_chains[DocumentSet]
|
||||
doc_set_chain.filter.assert_called()
|
||||
doc_set_chain.filter.return_value.update.assert_called_once_with(
|
||||
{DocumentSet.user_id: None}
|
||||
)
|
||||
|
||||
# Verify Persona.user_id is nulled out (update, not delete)
|
||||
persona_chain = query_chains[Persona]
|
||||
persona_chain.filter.assert_called()
|
||||
persona_chain.filter.return_value.update.assert_called_once_with(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_cleans_up_join_tables(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Join tables should be deleted (not updated)
|
||||
for model in [DocumentSet__User, Persona__User, User__UserGroup, SamlAccount]:
|
||||
chain = query_chains[model]
|
||||
chain.filter.return_value.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_commits_and_removes_invited(
|
||||
_mock_ee: Any, mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user(email="deleted@example.com")
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_called_once_with(user)
|
||||
db_session.commit.assert_called_once()
|
||||
mock_remove_invited.assert_called_once_with("deleted@example.com")
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_deletes_oauth_accounts(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
oauth1 = MagicMock()
|
||||
oauth2 = MagicMock()
|
||||
user.oauth_accounts = [oauth1, oauth2]
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_any_call(oauth1)
|
||||
db_session.delete.assert_any_call(oauth2)
|
||||
205
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
205
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
|
||||
|
||||
# -- Table rendering tests --
|
||||
|
||||
|
||||
def test_table_renders_as_vertical_cards() -> None:
|
||||
message = (
|
||||
"| Feature | Status | Owner |\n"
|
||||
"|---------|--------|-------|\n"
|
||||
"| Auth | Done | Alice |\n"
|
||||
"| Search | In Progress | Bob |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
|
||||
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
|
||||
# Cards separated by blank line
|
||||
assert "Owner: Alice\n\n*Search*" in formatted
|
||||
# No raw pipe-and-dash table syntax
|
||||
assert "---|" not in formatted
|
||||
|
||||
|
||||
def test_table_single_column() -> None:
|
||||
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Alice*" in formatted
|
||||
assert "*Bob*" in formatted
|
||||
|
||||
|
||||
def test_table_embedded_in_text() -> None:
|
||||
message = (
|
||||
"Here are the results:\n\n"
|
||||
"| Item | Count |\n"
|
||||
"|------|-------|\n"
|
||||
"| Apples | 5 |\n"
|
||||
"\n"
|
||||
"That's all."
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Here are the results:" in formatted
|
||||
assert "*Apples*\n • Count: 5" in formatted
|
||||
assert "That's all." in formatted
|
||||
|
||||
|
||||
def test_table_with_formatted_cells() -> None:
|
||||
message = (
|
||||
"| Name | Link |\n"
|
||||
"|------|------|\n"
|
||||
"| **Alice** | [profile](https://example.com) |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Bold cell should not double-wrap: *Alice* not **Alice**
|
||||
assert "*Alice*" in formatted
|
||||
assert "**Alice**" not in formatted
|
||||
assert "<https://example.com|profile>" in formatted
|
||||
|
||||
|
||||
def test_table_with_alignment_specifiers() -> None:
|
||||
message = (
|
||||
"| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*a*\n • Center: b\n • Right: c" in formatted
|
||||
|
||||
|
||||
def test_two_tables_in_same_message_use_independent_headers() -> None:
|
||||
message = (
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"| X | Y | Z |\n"
|
||||
"|---|---|---|\n"
|
||||
"| p | q | r |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*1*\n • B: 2" in formatted
|
||||
assert "*p*\n • Y: q\n • Z: r" in formatted
|
||||
|
||||
|
||||
def test_table_empty_first_column_no_bare_asterisks() -> None:
|
||||
message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Empty title should not produce "**" (bare asterisks)
|
||||
assert "**" not in formatted
|
||||
assert " • Status: Done" in formatted
|
||||
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.utils import telemetry as telemetry_utils
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
|
||||
fetch_impl = Mock()
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_not_called()
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
event_telemetry = Mock()
|
||||
fetch_impl = Mock(return_value=event_telemetry)
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_called_once_with(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=telemetry_utils.noop_fallback,
|
||||
)
|
||||
event_telemetry.assert_called_once_with(
|
||||
"user@example.com",
|
||||
MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
{"origin": "web", "tenant_id": "tenant-1"},
|
||||
)
|
||||
@@ -582,29 +582,33 @@ else
|
||||
fi
|
||||
|
||||
# Ask for authentication schema
|
||||
echo ""
|
||||
print_info "Which authentication schema would you like to set up?"
|
||||
echo ""
|
||||
echo "1) Basic - Username/password authentication"
|
||||
echo "2) No Auth - Open access (development/testing)"
|
||||
echo ""
|
||||
read -p "Choose an option (1-2) [default 1]: " -r AUTH_CHOICE
|
||||
echo ""
|
||||
# echo ""
|
||||
# print_info "Which authentication schema would you like to set up?"
|
||||
# echo ""
|
||||
# echo "1) Basic - Username/password authentication"
|
||||
# echo "2) No Auth - Open access (development/testing)"
|
||||
# echo ""
|
||||
# read -p "Choose an option (1) [default 1]: " -r AUTH_CHOICE
|
||||
# echo ""
|
||||
|
||||
case "${AUTH_CHOICE:-1}" in
|
||||
1)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Selected: Basic authentication"
|
||||
;;
|
||||
2)
|
||||
AUTH_SCHEMA="disabled"
|
||||
print_info "Selected: No authentication"
|
||||
;;
|
||||
*)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Invalid choice, using basic authentication"
|
||||
;;
|
||||
esac
|
||||
# case "${AUTH_CHOICE:-1}" in
|
||||
# 1)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Selected: Basic authentication"
|
||||
# ;;
|
||||
# # 2)
|
||||
# # AUTH_SCHEMA="disabled"
|
||||
# # print_info "Selected: No authentication"
|
||||
# # ;;
|
||||
# *)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Invalid choice, using basic authentication"
|
||||
# ;;
|
||||
# esac
|
||||
|
||||
# TODO (jessica): Uncomment this once no auth users still have an account
|
||||
# Use basic auth by default
|
||||
AUTH_SCHEMA="basic"
|
||||
|
||||
# Create .env file from template
|
||||
print_info "Creating .env file with your selections..."
|
||||
|
||||
3
desktop/.gitignore
vendored
3
desktop/.gitignore
vendored
@@ -22,3 +22,6 @@ npm-debug.log*
|
||||
# Local env files
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Generated files
|
||||
src-tauri/gen/schemas/acl-manifests.json
|
||||
|
||||
96
desktop/src-tauri/Cargo.lock
generated
96
desktop/src-tauri/Cargo.lock
generated
@@ -706,16 +706,6 @@ dependencies = [
|
||||
"typeid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
@@ -993,16 +983,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gethostname"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
@@ -1122,24 +1102,6 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "global-hotkey"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"keyboard-types",
|
||||
"objc2 0.6.3",
|
||||
"objc2-app-kit 0.3.2",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"thiserror 2.0.17",
|
||||
"windows-sys 0.59.0",
|
||||
"x11rb",
|
||||
"xkeysym",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gobject-sys"
|
||||
version = "0.18.0"
|
||||
@@ -1713,12 +1675,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.1"
|
||||
@@ -2248,7 +2204,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
"tauri-plugin-global-shortcut",
|
||||
"tauri-plugin-shell",
|
||||
"tauri-plugin-window-state",
|
||||
"tokio",
|
||||
@@ -2878,19 +2833,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -3605,21 +3547,6 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-global-shortcut"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
|
||||
dependencies = [
|
||||
"global-hotkey",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-plugin",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-shell"
|
||||
version = "2.3.3"
|
||||
@@ -5021,29 +4948,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
|
||||
dependencies = [
|
||||
"gethostname",
|
||||
"rustix",
|
||||
"x11rb-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb-protocol"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
|
||||
|
||||
[[package]]
|
||||
name = "xkeysym"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
|
||||
@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
|
||||
[dependencies]
|
||||
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.0"
|
||||
tauri-plugin-global-shortcut = "2.0"
|
||||
tauri-plugin-window-state = "2.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -20,7 +20,6 @@ use tauri::Wry;
|
||||
use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Shortcuts Setup
|
||||
// ============================================================================
|
||||
|
||||
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
|
||||
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
|
||||
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
|
||||
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
|
||||
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
|
||||
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
|
||||
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
|
||||
let app_handle = app.clone();
|
||||
|
||||
// Avoid hijacking the system-wide Cmd+R on macOS.
|
||||
#[cfg(target_os = "macos")]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
reload,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
app.global_shortcut().on_shortcuts(
|
||||
shortcuts,
|
||||
move |_app, shortcut, _event| {
|
||||
if shortcut == &new_chat {
|
||||
trigger_new_chat(&app_handle);
|
||||
}
|
||||
|
||||
if let Some(window) = app_handle.get_webview_window("main") {
|
||||
if shortcut == &reload {
|
||||
let _ = window.eval("window.location.reload()");
|
||||
} else if shortcut == &back {
|
||||
let _ = window.eval("window.history.back()");
|
||||
} else if shortcut == &forward {
|
||||
let _ = window.eval("window.history.forward()");
|
||||
} else if shortcut == &open_settings_shortcut {
|
||||
open_settings(&app_handle);
|
||||
}
|
||||
}
|
||||
|
||||
if shortcut == &new_window_shortcut {
|
||||
trigger_new_window(&app_handle);
|
||||
} else if shortcut == &show_app {
|
||||
focus_main_window(&app_handle);
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
|
||||
TRAY_MENU_OPEN_APP_ID,
|
||||
"Open Onyx",
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+Space"),
|
||||
None::<&str>,
|
||||
)?;
|
||||
let open_chat = MenuItem::with_id(
|
||||
app,
|
||||
@@ -666,7 +598,6 @@ fn main() {
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
@@ -698,11 +629,6 @@ fn main() {
|
||||
.setup(move |app| {
|
||||
let app_handle = app.handle();
|
||||
|
||||
// Setup global shortcuts
|
||||
if let Err(e) = setup_shortcuts(&app_handle) {
|
||||
eprintln!("Failed to setup shortcuts: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = setup_app_menu(&app_handle) {
|
||||
eprintln!("Failed to setup menu: {}", e);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@
|
||||
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background-900: #1a1a1a;
|
||||
--background-800: #262626;
|
||||
--text-light-05: rgba(255, 255, 255, 0.95);
|
||||
--text-light-03: rgba(255, 255, 255, 0.6);
|
||||
--white-10: rgba(255, 255, 255, 0.08);
|
||||
--white-15: rgba(255, 255, 255, 0.12);
|
||||
--white-20: rgba(255, 255, 255, 0.15);
|
||||
--white-30: rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
@@ -30,7 +41,11 @@
|
||||
|
||||
body {
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--background-900) 0%,
|
||||
var(--background-800) 100%
|
||||
);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
display: flex;
|
||||
@@ -39,6 +54,9 @@
|
||||
padding: 20px;
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
color 0.3s ease;
|
||||
}
|
||||
|
||||
.titlebar {
|
||||
@@ -69,16 +87,19 @@
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
background: var(--background-800);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
border 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .settings-panel {
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
@@ -93,17 +114,19 @@
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
background: var(--background-900);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
overflow: hidden;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-icon svg {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
color: #000;
|
||||
color: var(--text-light-05);
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
@@ -134,9 +157,10 @@
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
background: var(--background-900);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
@@ -176,7 +200,7 @@
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-800);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
@@ -186,8 +210,8 @@
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-900);
|
||||
box-shadow: 0 0 0 2px var(--white-10);
|
||||
}
|
||||
|
||||
.input-field::placeholder {
|
||||
@@ -231,7 +255,7 @@
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
background-color: var(--white-15);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
@@ -243,14 +267,18 @@
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
background-color: var(--background-800);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.dark .toggle-slider:before {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
background-color: var(--white-30);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
@@ -288,14 +316,15 @@
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
background: var(--white-10);
|
||||
border: 1px solid var(--white-15);
|
||||
border-radius: 4px;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
@@ -372,10 +401,34 @@
|
||||
const errorMessage = document.getElementById("errorMessage");
|
||||
const saveBtn = document.getElementById("saveBtn");
|
||||
|
||||
// Theme detection based on system preferences
|
||||
function applySystemTheme() {
|
||||
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
|
||||
|
||||
function updateTheme(e) {
|
||||
if (e.matches) {
|
||||
document.documentElement.classList.add("dark");
|
||||
document.body.classList.add("dark");
|
||||
} else {
|
||||
document.documentElement.classList.remove("dark");
|
||||
document.body.classList.remove("dark");
|
||||
}
|
||||
}
|
||||
|
||||
// Apply initial theme
|
||||
updateTheme(darkModeQuery);
|
||||
|
||||
// Listen for changes
|
||||
darkModeQuery.addEventListener("change", updateTheme);
|
||||
}
|
||||
|
||||
function showSettings() {
|
||||
document.body.classList.add("show-settings");
|
||||
}
|
||||
|
||||
// Apply system theme immediately
|
||||
applySystemTheme();
|
||||
|
||||
// Initialize the app
|
||||
async function init() {
|
||||
try {
|
||||
|
||||
@@ -113,6 +113,23 @@
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
function updateTitleBarTheme(isDark) {
|
||||
const titleBar = document.getElementById(TITLEBAR_ID);
|
||||
if (!titleBar) return;
|
||||
|
||||
if (isDark) {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
|
||||
} else {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
|
||||
}
|
||||
}
|
||||
|
||||
function buildTitleBar() {
|
||||
const titleBar = document.createElement("div");
|
||||
titleBar.id = TITLEBAR_ID;
|
||||
@@ -134,6 +151,11 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Apply initial styles matching current theme
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
// Apply styles matching Onyx design system with translucent glass effect
|
||||
titleBar.style.cssText = `
|
||||
position: fixed;
|
||||
@@ -156,8 +178,12 @@
|
||||
-webkit-backdrop-filter: blur(18px) saturate(180%);
|
||||
-webkit-app-region: drag;
|
||||
padding: 0 12px;
|
||||
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
|
||||
`;
|
||||
|
||||
// Apply correct theme
|
||||
updateTitleBarTheme(isDark);
|
||||
|
||||
return titleBar;
|
||||
}
|
||||
|
||||
@@ -168,6 +194,11 @@
|
||||
|
||||
const existing = document.getElementById(TITLEBAR_ID);
|
||||
if (existing?.parentElement === document.body) {
|
||||
// Update theme on existing titlebar
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -178,6 +209,14 @@
|
||||
const titleBar = buildTitleBar();
|
||||
document.body.insertBefore(titleBar, document.body.firstChild);
|
||||
injectStyles();
|
||||
|
||||
// Ensure theme is applied immediately after mount
|
||||
setTimeout(() => {
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function syncViewportHeight() {
|
||||
@@ -194,9 +233,66 @@
|
||||
}
|
||||
}
|
||||
|
||||
function observeThemeChanges() {
|
||||
let lastKnownTheme = null;
|
||||
|
||||
function checkAndUpdateTheme() {
|
||||
// Check both html and body for dark class (some apps use body)
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
if (lastKnownTheme !== isDark) {
|
||||
lastKnownTheme = isDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}
|
||||
}
|
||||
|
||||
// Immediate check on setup
|
||||
checkAndUpdateTheme();
|
||||
|
||||
// Watch for theme changes on the HTML element
|
||||
const themeObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
|
||||
themeObserver.observe(document.documentElement, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
|
||||
// Also observe body if it exists
|
||||
if (document.body) {
|
||||
const bodyObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
bodyObserver.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
}
|
||||
|
||||
// Also check periodically in case classList is manipulated directly
|
||||
// or the theme loads asynchronously after page load
|
||||
const intervalId = setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 300);
|
||||
|
||||
// Clean up after 30 seconds once theme should be stable
|
||||
setTimeout(() => {
|
||||
clearInterval(intervalId);
|
||||
// But keep checking every 2 seconds for manual theme changes
|
||||
setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 2000);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
function init() {
|
||||
mountTitleBar();
|
||||
syncViewportHeight();
|
||||
observeThemeChanges();
|
||||
|
||||
window.addEventListener("resize", syncViewportHeight, { passive: true });
|
||||
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
|
||||
passive: true,
|
||||
|
||||
@@ -119,7 +119,7 @@ backend = [
|
||||
"shapely==2.0.6",
|
||||
"stripe==10.12.0",
|
||||
"urllib3==2.6.3",
|
||||
"mistune==0.8.4",
|
||||
"mistune==3.2.0",
|
||||
"sendgrid==6.12.5",
|
||||
"exa_py==1.15.4",
|
||||
"braintrust==0.3.9",
|
||||
@@ -142,7 +142,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.4.0",
|
||||
"onyx-devtools==0.6.2",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
26
uv.lock
generated
26
uv.lock
generated
@@ -3897,11 +3897,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mistune"
|
||||
version = "0.8.4"
|
||||
version = "3.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2d/a4/509f6e7783ddd35482feda27bc7f72e65b5e7dc910eca4ab2164daf9c577/mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", size = 58322, upload-time = "2018-10-11T06:59:27.908Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/09/ec/4b43dae793655b7d8a25f76119624350b4d65eb663459eb9603d7f1f0345/mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4", size = 16220, upload-time = "2018-10-11T06:59:26.044Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4766,7 +4766,7 @@ requires-dist = [
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
|
||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
|
||||
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.25.0" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==0.8.4" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
|
||||
@@ -4775,7 +4775,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.5.9" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.4.0" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.2" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4878,20 +4878,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.4.0"
|
||||
version = "0.6.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/d8/f68d15c12d27d4525d10697ac7e2d67d6122fb59ccab219afb2973bc33ad/onyx_devtools-0.4.0-py3-none-any.whl", hash = "sha256:3eb821bce7ec8651d57e937d4d8483e1c2c4bc51df8cbab2dbcc05e3740ec96c", size = 2870841, upload-time = "2026-01-23T04:44:32.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/04/6376342389494b51fd89e554dfdaf0d3809b8d1473bc9b72abd2d7dba21e/onyx_devtools-0.4.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:144e518abad3031ffef189445a69356fca1da2a4fb40c7b8431550133bfc4eef", size = 2890308, upload-time = "2026-01-23T04:44:37.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/c1/859b32fb3eff7e67179d971ace36313ae64e7fc9a242b45e606138b0041f/onyx_devtools-0.4.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0cc74d561f08a9c894adf8de79855b4fc72eb70e823a75e29db7f625ad366bd7", size = 2696160, upload-time = "2026-01-23T04:44:30.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/1b/f1e3f574e9917779d22e3fcb28f8ac1888c250e7452a523f64a6ab8a1759/onyx_devtools-0.4.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:d69de76a97d7f9ff8c473afffbf544a65265645d726f3d70cc12dbbd7e364222", size = 2602134, upload-time = "2026-01-23T04:44:31.716Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/4a/a5d11640fdc23c9bf0e8617ce13793a587e49a64be2d20badf7e9b045e0a/onyx_devtools-0.4.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:fa84980ce8830e35432831aadc19ff465dbc723605aa80c50e0debc58457b70f", size = 2870864, upload-time = "2026-01-23T04:44:31.5Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/9f/6a7e02fbf47bcaea4d02b0ed92bea6e2c09408be7654fb3b57a1ba9863f2/onyx_devtools-0.4.0-py3-none-win_amd64.whl", hash = "sha256:8451efe3e137157696decf8b60a19fb3f0c52ae9f2d9b7c5bc6e667900e7c61e", size = 2953545, upload-time = "2026-01-23T04:44:38.11Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/42/f7a5b99ade06d215fb99de41181d51a9a984f83afb15afa15ce79ecab635/onyx_devtools-0.4.0-py3-none-win_arm64.whl", hash = "sha256:53a5942c922d7049650e934c43f9c057d046f8d53bc68935ebf7e93baa29afc3", size = 2665984, upload-time = "2026-01-23T04:44:29.399Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/20/d9f6089616044b0fb6e097cbae82122de24f3acd97820be4868d5c28ee3f/onyx_devtools-0.6.2-py3-none-any.whl", hash = "sha256:e48d14695d39d62ec3247a4c76ea56604bc5fb635af84c4ff3e9628bcc67b4fb", size = 3785941, upload-time = "2026-02-25T22:33:43.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/f5/f754a717f6b011050eb52ef09895cfa2f048f567f4aa3d5e0f773657dea4/onyx_devtools-0.6.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:505f9910a04868ab62d99bb483dc37c9f4ad94fa80e6ac0e6a10b86351c31420", size = 3832182, upload-time = "2026-02-25T22:33:43.283Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/35/6e653398c62078e87ebb0d03dc944df6691d92ca427c92867309d2d803b7/onyx_devtools-0.6.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:edec98e3acc0fa22cf9102c2070409ea7bcf99d7ded72bd8cb184ece8171c36a", size = 3576948, upload-time = "2026-02-25T22:33:42.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/97/cff707c5c3d2acd714365b1023f0100676abc99816a29558319e8ef01d5f/onyx_devtools-0.6.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97abab61216866cdccd8c0a7e27af328776083756ce4fb57c4bd723030449e3b", size = 3439359, upload-time = "2026-02-25T22:33:44.684Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/98/3b768d18e5599178834b966b447075626d224e048d6eb264d89d19abacb4/onyx_devtools-0.6.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:681b038ab6f1457409d14b2490782c7a8014fc0f0f1b9cd69bb2b7199f99aef1", size = 3785959, upload-time = "2026-02-25T22:33:44.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/38/9b047f9e61c14ccf22b8f386c7a57da3965f90737453f3a577a97da45cdf/onyx_devtools-0.6.2-py3-none-win_amd64.whl", hash = "sha256:a2063be6be104b50a7538cf0d26c7f7ab9159d53327dd6f3e91db05d793c95f3", size = 3878776, upload-time = "2026-02-25T22:33:45.229Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/0f/742f644bae84f5f8f7b500094a2f58da3ff8027fc739944622577e2e2850/onyx_devtools-0.6.2-py3-none-win_arm64.whl", hash = "sha256:00fb90a49a15c932b5cacf818b1b4918e5b5c574bde243dc1828b57690dd5046", size = 3501112, upload-time = "2026-02-25T22:33:41.512Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowDownDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowDownDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowLeftDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowLeftDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowRightDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowRightDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import type { SVGProps } from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBracketCurly = (props: SVGProps<SVGSVGElement>) => (
|
||||
const SvgBracketCurly = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 15 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
21
web/lib/opal/src/icons/branch.tsx
Normal file
21
web/lib/opal/src/icons/branch.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBranch = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M4.75001 5C5.71651 5 6.50001 4.2165 6.50001 3.25C6.50001 2.2835 5.7165 1.5 4.75 1.5C3.78351 1.5 3.00001 2.2835 3.00001 3.25C3.00001 4.2165 3.78351 5 4.75001 5ZM4.75001 5L4.75001 6.24999M4.75 11C3.7835 11 3 11.7835 3 12.75C3 13.7165 3.7835 14.5 4.75 14.5C5.7165 14.5 6.5 13.7165 6.5 12.75C6.5 11.7835 5.71649 11 4.75 11ZM4.75 11L4.75001 6.24999M10.5 8.74997C10.5 9.71646 11.2835 10.5 12.25 10.5C13.2165 10.5 14 9.71646 14 8.74997C14 7.78347 13.2165 7 12.25 7C11.2835 7 10.5 7.78347 10.5 8.74997ZM10.5 8.74997L7.25001 8.74999C5.8693 8.74999 4.75001 7.6307 4.75001 6.24999"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgBranch;
|
||||
16
web/lib/opal/src/icons/circle.tsx
Normal file
16
web/lib/opal/src/icons/circle.tsx
Normal file
@@ -0,0 +1,16 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgCircle = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<circle cx="8" cy="8" r="6" strokeWidth={1.5} />
|
||||
</svg>
|
||||
);
|
||||
export default SvgCircle;
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgClaude = (props: IconProps) => {
|
||||
const SvgClaude = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgClipboard = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgClipboard = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgCornerRightUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgCornerRightUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
21
web/lib/opal/src/icons/download.tsx
Normal file
21
web/lib/opal/src/icons/download.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgDownload = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M14 10V12.6667C14 13.3929 13.3929 14 12.6667 14H3.33333C2.60711 14 2 13.3929 2 12.6667V10M4.66667 6.66667L8 10M8 10L11.3333 6.66667M8 10L8 2"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgDownload;
|
||||
@@ -24,6 +24,7 @@ export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBooksLineSmall } from "@opal/icons/books-line-small";
|
||||
export { default as SvgBooksStackSmall } from "@opal/icons/books-stack-small";
|
||||
export { default as SvgBracketCurly } from "@opal/icons/bracket-curly";
|
||||
export { default as SvgBranch } from "@opal/icons/branch";
|
||||
export { default as SvgBubbleText } from "@opal/icons/bubble-text";
|
||||
export { default as SvgCalendar } from "@opal/icons/calendar";
|
||||
export { default as SvgCheck } from "@opal/icons/check";
|
||||
@@ -36,6 +37,7 @@ export { default as SvgChevronLeft } from "@opal/icons/chevron-left";
|
||||
export { default as SvgChevronRight } from "@opal/icons/chevron-right";
|
||||
export { default as SvgChevronUp } from "@opal/icons/chevron-up";
|
||||
export { default as SvgChevronUpSmall } from "@opal/icons/chevron-up-small";
|
||||
export { default as SvgCircle } from "@opal/icons/circle";
|
||||
export { default as SvgClaude } from "@opal/icons/claude";
|
||||
export { default as SvgClipboard } from "@opal/icons/clipboard";
|
||||
export { default as SvgClock } from "@opal/icons/clock";
|
||||
@@ -46,6 +48,7 @@ export { default as SvgCopy } from "@opal/icons/copy";
|
||||
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
|
||||
export { default as SvgCpu } from "@opal/icons/cpu";
|
||||
export { default as SvgDevKit } from "@opal/icons/dev-kit";
|
||||
export { default as SvgDownload } from "@opal/icons/download";
|
||||
export { default as SvgDiscordMono } from "@opal/icons/DiscordMono";
|
||||
export { default as SvgDownloadCloud } from "@opal/icons/download-cloud";
|
||||
export { default as SvgEdit } from "@opal/icons/edit";
|
||||
@@ -135,6 +138,7 @@ export { default as SvgStep3End } from "@opal/icons/step3-end";
|
||||
export { default as SvgStop } from "@opal/icons/stop";
|
||||
export { default as SvgStopCircle } from "@opal/icons/stop-circle";
|
||||
export { default as SvgSun } from "@opal/icons/sun";
|
||||
export { default as SvgTerminal } from "@opal/icons/terminal";
|
||||
export { default as SvgTerminalSmall } from "@opal/icons/terminal-small";
|
||||
export { default as SvgTextLinesSmall } from "@opal/icons/text-lines-small";
|
||||
export { default as SvgThumbsDown } from "@opal/icons/thumbs-down";
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const OnyxLogo = ({
|
||||
width = 24,
|
||||
height = 24,
|
||||
className,
|
||||
...props
|
||||
}: IconProps) => (
|
||||
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={width}
|
||||
height={height}
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 56 56"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={className}
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
@@ -23,4 +17,4 @@ const OnyxLogo = ({
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default OnyxLogo;
|
||||
export default SvgOnyxLogo;
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgOpenAI = (props: IconProps) => {
|
||||
const SvgOpenAI = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
22
web/lib/opal/src/icons/terminal.tsx
Normal file
22
web/lib/opal/src/icons/terminal.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgTerminal = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2.66667 11.3333L6.66667 7.33331L2.66667 3.33331M8.00001 12.6666H13.3333"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgTerminal;
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgTwoLineSmall = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgTwoLineSmall = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgUserPlus = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgWallet = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgWallet = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -300,13 +300,7 @@ export default function Page() {
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Default Assistant"
|
||||
icon={
|
||||
<SvgOnyxLogo
|
||||
width={32}
|
||||
height={32}
|
||||
className="my-auto stroke-text-04"
|
||||
/>
|
||||
}
|
||||
icon={<SvgOnyxLogo size={32} className="my-auto stroke-text-04" />}
|
||||
/>
|
||||
<DefaultAssistantConfig />
|
||||
</>
|
||||
|
||||
@@ -31,6 +31,7 @@ import { fetchBedrockModels } from "../utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
|
||||
@@ -135,7 +136,7 @@ function BedrockFormInternals({
|
||||
!formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
<Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
|
||||
<SelectorFormField
|
||||
@@ -176,7 +177,7 @@ function BedrockFormInternals({
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4 w-full">
|
||||
<TextFormField
|
||||
name={FIELD_AWS_ACCESS_KEY_ID}
|
||||
label="AWS Access Key ID"
|
||||
@@ -191,7 +192,7 @@ function BedrockFormInternals({
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4 w-full">
|
||||
<PasswordInputTypeInField
|
||||
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
|
||||
label="AWS Bedrock Long-term API Key"
|
||||
|
||||
@@ -131,10 +131,15 @@ export function CustomForm({
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(
|
||||
values.custom_config_list
|
||||
),
|
||||
|
||||
@@ -39,6 +39,8 @@ interface OllamaFormValues extends BaseLLMFormValues {
|
||||
interface OllamaFormContentProps {
|
||||
formikProps: FormikProps<OllamaFormValues>;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: () => void;
|
||||
@@ -49,15 +51,14 @@ interface OllamaFormContentProps {
|
||||
function OllamaFormContent({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
testError,
|
||||
mutate,
|
||||
onClose,
|
||||
isFormValid,
|
||||
}: OllamaFormContentProps) {
|
||||
const [availableModels, setAvailableModels] = useState<ModelConfiguration[]>(
|
||||
existingLlmProvider?.model_configurations || []
|
||||
);
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -70,16 +71,25 @@ function OllamaFormContent({
|
||||
.then((data) => {
|
||||
if (data.error) {
|
||||
console.error("Error fetching models:", data.error);
|
||||
setAvailableModels([]);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setAvailableModels(data.models);
|
||||
setFetchedModels(data.models);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingModels(false);
|
||||
});
|
||||
}
|
||||
}, [formikProps.values.api_base]);
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
existingLlmProvider?.name,
|
||||
setFetchedModels,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
@@ -99,7 +109,7 @@ function OllamaFormContent({
|
||||
/>
|
||||
|
||||
<DisplayModels
|
||||
modelConfigurations={availableModels}
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
|
||||
isLoading={isLoadingModels}
|
||||
@@ -125,6 +135,8 @@ export function OllamaForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Ollama"
|
||||
@@ -189,7 +201,10 @@ export function OllamaForm({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
@@ -205,6 +220,8 @@ export function OllamaForm({
|
||||
<OllamaFormContent
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
testError={testError}
|
||||
mutate={mutate}
|
||||
|
||||
@@ -68,11 +68,7 @@ export const WebProviderSetupModal = memo(
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo
|
||||
width={24}
|
||||
height={24}
|
||||
className="text-text-04 shrink-0"
|
||||
/>
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1168,7 +1168,7 @@ export default function Page() {
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={16} height={16} />
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
@@ -1381,7 +1381,7 @@ export default function Page() {
|
||||
} logo`,
|
||||
fallback:
|
||||
selectedContentProviderType === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={24} height={24} className="text-text-05" />
|
||||
<SvgOnyxLogo size={24} className="text-text-05" />
|
||||
) : undefined,
|
||||
size: 24,
|
||||
containerSize: 28,
|
||||
|
||||
@@ -455,9 +455,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
/>
|
||||
)}
|
||||
|
||||
<BackButton
|
||||
behaviorOverride={() => router.push("/admin/indexing/status")}
|
||||
/>
|
||||
<BackButton />
|
||||
<div
|
||||
className="flex
|
||||
items-center
|
||||
|
||||
@@ -25,7 +25,6 @@ import { useDocumentSets } from "@/lib/hooks/useDocumentSets";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { ChatPopup } from "@/app/chat/components/ChatPopup";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_ID } from "@/app/chat/components/tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import NoAssistantModal from "@/components/modals/NoAssistantModal";
|
||||
import TextView from "@/components/chat/TextView";
|
||||
@@ -382,9 +381,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
if (liveAssistant) {
|
||||
return liveAssistant.tools.some(
|
||||
(tool) => tool.in_code_tool_id === SEARCH_TOOL_ID
|
||||
);
|
||||
return personaIncludesRetrieval(liveAssistant);
|
||||
}
|
||||
return false;
|
||||
}, [liveAssistant]);
|
||||
|
||||
@@ -860,6 +860,7 @@ export function useChatController({
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetsVersion: packetsVersion,
|
||||
packetCount: packets.length,
|
||||
},
|
||||
],
|
||||
// Pass the latest map state
|
||||
@@ -886,6 +887,7 @@ export function useChatController({
|
||||
toolCall: null,
|
||||
parentNodeId: parentMessage?.nodeId || SYSTEM_NODE_ID,
|
||||
packets: [],
|
||||
packetCount: 0,
|
||||
},
|
||||
{
|
||||
nodeId: initialAssistantNode.nodeId,
|
||||
@@ -895,6 +897,7 @@ export function useChatController({
|
||||
toolCall: null,
|
||||
parentNodeId: initialUserNode.nodeId,
|
||||
packets: [],
|
||||
packetCount: 0,
|
||||
stackTrace: stackTrace,
|
||||
errorCode: errorCode,
|
||||
isRetryable: isRetryable,
|
||||
|
||||
@@ -141,6 +141,7 @@ export interface Message {
|
||||
packets: Packet[];
|
||||
// Version counter for efficient memo comparison (increments with each packet)
|
||||
packetsVersion?: number;
|
||||
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
|
||||
|
||||
// cached values for easy access
|
||||
documents?: OnyxDocument[] | null;
|
||||
|
||||
@@ -11,6 +11,7 @@ import { CitationMap } from "../../interfaces";
|
||||
export enum RenderType {
|
||||
HIGHLIGHT = "highlight",
|
||||
FULL = "full",
|
||||
COMPACT = "compact",
|
||||
}
|
||||
|
||||
export interface FullChatState {
|
||||
@@ -35,6 +36,9 @@ export interface RendererResult {
|
||||
// used for things that should just show text w/o an icon or header
|
||||
// e.g. ReasoningRenderer
|
||||
expandedText?: JSX.Element;
|
||||
|
||||
// Whether this renderer supports compact mode (collapse button shown only when true)
|
||||
supportsCompact?: boolean;
|
||||
}
|
||||
|
||||
export type MessageRenderer<
|
||||
@@ -48,5 +52,7 @@ export type MessageRenderer<
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
/** Whether this is the last step in the timeline (for connector line decisions) */
|
||||
isLastStep?: boolean;
|
||||
children: (result: RendererResult) => JSX.Element;
|
||||
}>;
|
||||
|
||||
@@ -68,10 +68,11 @@ export const CustomToolRenderer: MessageRenderer<CustomToolPacket, {}> = ({
|
||||
|
||||
const icon = FiTool;
|
||||
|
||||
if (renderType === RenderType.HIGHLIGHT) {
|
||||
if (renderType === RenderType.COMPACT) {
|
||||
return children({
|
||||
icon,
|
||||
status: status,
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{isRunning && `${toolName} running...`}
|
||||
@@ -84,6 +85,7 @@ export const CustomToolRenderer: MessageRenderer<CustomToolPacket, {}> = ({
|
||||
return children({
|
||||
icon,
|
||||
status,
|
||||
supportsCompact: true,
|
||||
content: (
|
||||
<div className="flex flex-col gap-3">
|
||||
{/* File responses */}
|
||||
|
||||
@@ -72,6 +72,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Generating images...",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="flex flex-col">
|
||||
<div>
|
||||
@@ -89,6 +90,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
status: `Generated ${images.length} image${
|
||||
images.length !== 1 ? "s" : ""
|
||||
}`,
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="flex flex-col my-1">
|
||||
{images.length > 0 ? (
|
||||
@@ -122,6 +124,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: status,
|
||||
supportsCompact: false,
|
||||
content: <div></div>,
|
||||
});
|
||||
}
|
||||
@@ -131,6 +134,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Generating image...",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<div className="flex gap-0.5">
|
||||
@@ -154,6 +158,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Image generation failed",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="text-sm text-red-600 dark:text-red-400">
|
||||
Image generation failed
|
||||
@@ -166,6 +171,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: `Generated ${images.length} image${images.length > 1 ? "s" : ""}`,
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="text-sm text-muted-foreground">
|
||||
Generated {images.length} image
|
||||
@@ -178,6 +184,7 @@ export const ImageToolRenderer: MessageRenderer<
|
||||
return children({
|
||||
icon: FiImage,
|
||||
status: "Image generation",
|
||||
supportsCompact: false,
|
||||
content: (
|
||||
<div className="text-sm text-muted-foreground">Image generation</div>
|
||||
),
|
||||
|
||||
@@ -0,0 +1,441 @@
|
||||
"use client";
|
||||
|
||||
import React, { FunctionComponent, useMemo, useCallback } from "react";
|
||||
import { StopReason } from "@/app/chat/services/streamingModels";
|
||||
import { FullChatState } from "../interfaces";
|
||||
import { TurnGroup, TransformedStep } from "./transformers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
|
||||
import { SvgCheckCircle, SvgStopCircle } from "@opal/icons";
|
||||
import { IconProps } from "@opal/types";
|
||||
import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererResult,
|
||||
} from "./TimelineRendererComponent";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { ParallelTimelineTabs } from "./ParallelTimelineTabs";
|
||||
import { StepContainer } from "./StepContainer";
|
||||
import {
|
||||
useTimelineExpansion,
|
||||
useTimelineMetrics,
|
||||
useTimelineHeader,
|
||||
} from "@/app/chat/message/messageComponents/timeline/hooks";
|
||||
import {
|
||||
isResearchAgentPackets,
|
||||
stepSupportsCompact,
|
||||
} from "@/app/chat/message/messageComponents/timeline/packetHelpers";
|
||||
import {
|
||||
StreamingHeader,
|
||||
CollapsedHeader,
|
||||
ExpandedHeader,
|
||||
StoppedHeader,
|
||||
ParallelStreamingHeader,
|
||||
} from "@/app/chat/message/messageComponents/timeline/headers";
|
||||
|
||||
// =============================================================================
|
||||
// TimelineStep Component - Memoized to prevent re-renders
|
||||
// =============================================================================
|
||||
|
||||
interface TimelineStepProps {
|
||||
step: TransformedStep;
|
||||
chatState: FullChatState;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
isLastStep: boolean;
|
||||
isFirstStep: boolean;
|
||||
isSingleStep: boolean;
|
||||
}
|
||||
|
||||
//will be removed on cleanup
|
||||
const noopCallback = () => {};
|
||||
|
||||
const TimelineStep = React.memo(function TimelineStep({
|
||||
step,
|
||||
chatState,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
isLastStep,
|
||||
isFirstStep,
|
||||
isSingleStep,
|
||||
}: TimelineStepProps) {
|
||||
// Stable render callback - doesn't need to change between renders
|
||||
const renderStep = useCallback(
|
||||
({
|
||||
icon,
|
||||
status,
|
||||
content,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
isLastStep: rendererIsLastStep,
|
||||
supportsCompact,
|
||||
}: TimelineRendererResult) =>
|
||||
isResearchAgentPackets(step.packets) ? (
|
||||
content
|
||||
) : (
|
||||
<StepContainer
|
||||
stepIcon={icon as FunctionComponent<IconProps> | undefined}
|
||||
header={status}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={onToggle}
|
||||
collapsible={true}
|
||||
supportsCompact={supportsCompact}
|
||||
isLastStep={rendererIsLastStep}
|
||||
isFirstStep={isFirstStep}
|
||||
hideHeader={isSingleStep}
|
||||
>
|
||||
{content}
|
||||
</StepContainer>
|
||||
),
|
||||
[step.packets, isFirstStep, isSingleStep]
|
||||
);
|
||||
|
||||
return (
|
||||
<TimelineRendererComponent
|
||||
packets={step.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopCallback}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={true}
|
||||
isLastStep={isLastStep}
|
||||
>
|
||||
{renderStep}
|
||||
</TimelineRendererComponent>
|
||||
);
|
||||
});
|
||||
|
||||
// =============================================================================
|
||||
// Main Component
|
||||
// =============================================================================
|
||||
|
||||
export interface AgentTimelineProps {
|
||||
/** Turn groups from usePacketProcessor */
|
||||
turnGroups: TurnGroup[];
|
||||
/** Chat state for rendering content */
|
||||
chatState: FullChatState;
|
||||
/** Whether the stop packet has been seen */
|
||||
stopPacketSeen?: boolean;
|
||||
/** Reason for stopping (if stopped) */
|
||||
stopReason?: StopReason;
|
||||
/** Whether final answer is coming (affects last connector) */
|
||||
finalAnswerComing?: boolean;
|
||||
/** Whether there is display content after timeline */
|
||||
hasDisplayContent?: boolean;
|
||||
/** Content to render after timeline (final message + toolbar) - slot pattern */
|
||||
children?: React.ReactNode;
|
||||
/** Whether the timeline is collapsible */
|
||||
collapsible?: boolean;
|
||||
/** Title of the button to toggle the timeline */
|
||||
buttonTitle?: string;
|
||||
/** Additional class names */
|
||||
className?: string;
|
||||
/** Test ID for e2e testing */
|
||||
"data-testid"?: string;
|
||||
/** Unique tool names (pre-computed for performance) */
|
||||
uniqueToolNames?: string[];
|
||||
}
|
||||
|
||||
export function AgentTimeline({
|
||||
turnGroups,
|
||||
chatState,
|
||||
stopPacketSeen = false,
|
||||
stopReason,
|
||||
finalAnswerComing = false,
|
||||
hasDisplayContent = false,
|
||||
collapsible = true,
|
||||
buttonTitle,
|
||||
className,
|
||||
"data-testid": testId,
|
||||
uniqueToolNames = [],
|
||||
}: AgentTimelineProps) {
|
||||
// Header text and state flags
|
||||
const { headerText, hasPackets, userStopped } = useTimelineHeader(
|
||||
turnGroups,
|
||||
stopReason
|
||||
);
|
||||
|
||||
// Memoized metrics derived from turn groups
|
||||
const {
|
||||
totalSteps,
|
||||
isSingleStep,
|
||||
uniqueTools,
|
||||
lastTurnGroup,
|
||||
lastStep,
|
||||
lastStepIsResearchAgent,
|
||||
lastStepSupportsCompact,
|
||||
} = useTimelineMetrics(turnGroups, uniqueToolNames, userStopped);
|
||||
|
||||
// Expansion state management
|
||||
const { isExpanded, handleToggle, parallelActiveTab, setParallelActiveTab } =
|
||||
useTimelineExpansion(stopPacketSeen, lastTurnGroup);
|
||||
|
||||
// Stable callbacks to avoid creating new functions on every render
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
const renderContentOnly = useCallback(
|
||||
({ content }: TimelineRendererResult) => content,
|
||||
[]
|
||||
);
|
||||
|
||||
// Parallel step analysis for collapsed streaming view
|
||||
const parallelActiveStep = useMemo(() => {
|
||||
if (!lastTurnGroup?.isParallel) return null;
|
||||
return (
|
||||
lastTurnGroup.steps.find((s) => s.key === parallelActiveTab) ??
|
||||
lastTurnGroup.steps[0]
|
||||
);
|
||||
}, [lastTurnGroup, parallelActiveTab]);
|
||||
|
||||
const parallelActiveStepSupportsCompact = useMemo(() => {
|
||||
if (!parallelActiveStep) return false;
|
||||
return (
|
||||
stepSupportsCompact(parallelActiveStep.packets) &&
|
||||
!isResearchAgentPackets(parallelActiveStep.packets)
|
||||
);
|
||||
}, [parallelActiveStep]);
|
||||
|
||||
// Collapsed streaming: show compact content below header
|
||||
const showCollapsedCompact =
|
||||
!stopPacketSeen &&
|
||||
!isExpanded &&
|
||||
lastStep &&
|
||||
!lastTurnGroup?.isParallel &&
|
||||
!lastStepIsResearchAgent &&
|
||||
lastStepSupportsCompact;
|
||||
|
||||
// Parallel tabs in header only when collapsed (expanded view has tabs in content)
|
||||
const showParallelTabs =
|
||||
!stopPacketSeen &&
|
||||
!isExpanded &&
|
||||
lastTurnGroup?.isParallel &&
|
||||
lastTurnGroup.steps.length > 0;
|
||||
|
||||
// Collapsed parallel compact content
|
||||
const showCollapsedParallel =
|
||||
showParallelTabs && !isExpanded && parallelActiveStepSupportsCompact;
|
||||
|
||||
// Done indicator conditions
|
||||
const showDoneIndicator =
|
||||
stopPacketSeen && isExpanded && !userStopped && !lastStepIsResearchAgent;
|
||||
|
||||
// Header selection based on state
|
||||
const renderHeader = () => {
|
||||
if (!stopPacketSeen) {
|
||||
if (showParallelTabs && lastTurnGroup) {
|
||||
return (
|
||||
<ParallelStreamingHeader
|
||||
steps={lastTurnGroup.steps}
|
||||
activeTab={parallelActiveTab}
|
||||
onTabChange={setParallelActiveTab}
|
||||
collapsible={collapsible}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<StreamingHeader
|
||||
headerText={headerText}
|
||||
collapsible={collapsible}
|
||||
buttonTitle={buttonTitle}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (userStopped) {
|
||||
return (
|
||||
<StoppedHeader
|
||||
totalSteps={totalSteps}
|
||||
collapsible={collapsible}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isExpanded) {
|
||||
return (
|
||||
<CollapsedHeader
|
||||
uniqueTools={uniqueTools}
|
||||
totalSteps={totalSteps}
|
||||
collapsible={collapsible}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <ExpandedHeader collapsible={collapsible} onToggle={handleToggle} />;
|
||||
};
|
||||
|
||||
// Empty state: no packets, still streaming
|
||||
if (!hasPackets && !hasDisplayContent) {
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
<div className="flex w-full h-9">
|
||||
<div className="flex justify-center items-center size-9">
|
||||
<AgentAvatar agent={chatState.assistant} size={24} />
|
||||
</div>
|
||||
<div className="flex w-full h-full items-center px-2">
|
||||
<Text
|
||||
as="p"
|
||||
mainUiAction
|
||||
text03
|
||||
className="animate-shimmer bg-[length:200%_100%] bg-[linear-gradient(90deg,var(--shimmer-base)_10%,var(--shimmer-highlight)_40%,var(--shimmer-base)_70%)] bg-clip-text text-transparent"
|
||||
>
|
||||
{headerText}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Display content only (no timeline steps)
|
||||
if (hasDisplayContent && !hasPackets) {
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
<div className="flex w-full h-9">
|
||||
<div className="flex justify-center items-center size-9">
|
||||
<AgentAvatar agent={chatState.assistant} size={24} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
{/* Header row */}
|
||||
<div className="flex w-full h-9">
|
||||
<div className="flex justify-center items-center size-9">
|
||||
<AgentAvatar agent={chatState.assistant} size={24} />
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-full h-full items-center justify-between px-2",
|
||||
(!stopPacketSeen || userStopped || isExpanded) &&
|
||||
"bg-background-tint-00 rounded-t-12",
|
||||
!isExpanded &&
|
||||
!showCollapsedCompact &&
|
||||
!showCollapsedParallel &&
|
||||
"rounded-b-12"
|
||||
)}
|
||||
>
|
||||
{renderHeader()}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Collapsed streaming view - single step compact mode */}
|
||||
{showCollapsedCompact && lastStep && (
|
||||
<div className="flex w-full">
|
||||
<div className="w-9" />
|
||||
<div className="w-full bg-background-tint-00 rounded-b-12 px-2 pb-2">
|
||||
<TimelineRendererComponent
|
||||
key={`${lastStep.key}-compact`}
|
||||
packets={lastStep.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={true}
|
||||
stopPacketSeen={false}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={false}
|
||||
isLastStep={true}
|
||||
>
|
||||
{renderContentOnly}
|
||||
</TimelineRendererComponent>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Collapsed streaming view - parallel tools compact mode */}
|
||||
{showCollapsedParallel && parallelActiveStep && (
|
||||
<div className="flex w-full">
|
||||
<div className="w-9" />
|
||||
<div className="w-full bg-background-tint-00 rounded-b-12 px-2 pb-2">
|
||||
<TimelineRendererComponent
|
||||
key={`${parallelActiveStep.key}-compact`}
|
||||
packets={parallelActiveStep.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={true}
|
||||
stopPacketSeen={false}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={false}
|
||||
isLastStep={true}
|
||||
>
|
||||
{renderContentOnly}
|
||||
</TimelineRendererComponent>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Expanded timeline view */}
|
||||
{isExpanded && (
|
||||
<div className="w-full">
|
||||
{turnGroups.map((turnGroup, turnIdx) =>
|
||||
turnGroup.isParallel ? (
|
||||
<ParallelTimelineTabs
|
||||
key={turnGroup.turnIndex}
|
||||
turnGroup={turnGroup}
|
||||
chatState={chatState}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
isLastTurnGroup={turnIdx === turnGroups.length - 1}
|
||||
/>
|
||||
) : (
|
||||
turnGroup.steps.map((step, stepIdx) => {
|
||||
const stepIsLast =
|
||||
turnIdx === turnGroups.length - 1 &&
|
||||
stepIdx === turnGroup.steps.length - 1 &&
|
||||
!showDoneIndicator &&
|
||||
!userStopped;
|
||||
const stepIsFirst = turnIdx === 0 && stepIdx === 0;
|
||||
|
||||
return (
|
||||
<TimelineStep
|
||||
key={step.key}
|
||||
step={step}
|
||||
chatState={chatState}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
isLastStep={stepIsLast}
|
||||
isFirstStep={stepIsFirst}
|
||||
isSingleStep={isSingleStep}
|
||||
/>
|
||||
);
|
||||
})
|
||||
)
|
||||
)}
|
||||
|
||||
{/* Done indicator */}
|
||||
{stopPacketSeen && isExpanded && !userStopped && (
|
||||
<StepContainer
|
||||
stepIcon={SvgCheckCircle}
|
||||
header="Done"
|
||||
isLastStep={true}
|
||||
isFirstStep={false}
|
||||
>
|
||||
{null}
|
||||
</StepContainer>
|
||||
)}
|
||||
|
||||
{/* Stopped indicator */}
|
||||
{stopPacketSeen && isExpanded && userStopped && (
|
||||
<StepContainer
|
||||
stepIcon={SvgStopCircle}
|
||||
header="Stopped"
|
||||
isLastStep={true}
|
||||
isFirstStep={false}
|
||||
>
|
||||
{null}
|
||||
</StepContainer>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default AgentTimeline;
|
||||
@@ -0,0 +1,130 @@
|
||||
"use client";
|
||||
|
||||
import React, {
|
||||
useState,
|
||||
useMemo,
|
||||
useCallback,
|
||||
FunctionComponent,
|
||||
} from "react";
|
||||
import { StopReason } from "@/app/chat/services/streamingModels";
|
||||
import { FullChatState } from "../interfaces";
|
||||
import { TurnGroup } from "./transformers";
|
||||
import { getToolName, getToolIcon } from "../toolDisplayHelpers";
|
||||
import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererResult,
|
||||
} from "./TimelineRendererComponent";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { SvgBranch } from "@opal/icons";
|
||||
import { StepContainer } from "./StepContainer";
|
||||
import { isResearchAgentPackets } from "@/app/chat/message/messageComponents/timeline/packetHelpers";
|
||||
import { IconProps } from "@/components/icons/icons";
|
||||
|
||||
export interface ParallelTimelineTabsProps {
|
||||
/** Turn group containing parallel steps */
|
||||
turnGroup: TurnGroup;
|
||||
/** Chat state for rendering content */
|
||||
chatState: FullChatState;
|
||||
/** Whether the stop packet has been seen */
|
||||
stopPacketSeen: boolean;
|
||||
/** Reason for stopping (if stopped) */
|
||||
stopReason?: StopReason;
|
||||
/** Whether this is the last turn group (affects connector line) */
|
||||
isLastTurnGroup: boolean;
|
||||
/** Additional class names */
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ParallelTimelineTabs({
|
||||
turnGroup,
|
||||
chatState,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
isLastTurnGroup,
|
||||
className,
|
||||
}: ParallelTimelineTabsProps) {
|
||||
const [activeTab, setActiveTab] = useState(turnGroup.steps[0]?.key ?? "");
|
||||
|
||||
// Find the active step based on selected tab
|
||||
const activeStep = useMemo(
|
||||
() => turnGroup.steps.find((step) => step.key === activeTab),
|
||||
[turnGroup.steps, activeTab]
|
||||
);
|
||||
//will be removed on cleanup
|
||||
// Stable callbacks to avoid creating new functions on every render
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
const renderTabContent = useCallback(
|
||||
({
|
||||
icon,
|
||||
status,
|
||||
content,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
isLastStep,
|
||||
}: TimelineRendererResult) =>
|
||||
isResearchAgentPackets(activeStep?.packets ?? []) ? (
|
||||
content
|
||||
) : (
|
||||
<StepContainer
|
||||
stepIcon={icon as FunctionComponent<IconProps> | undefined}
|
||||
header={status}
|
||||
isExpanded={isExpanded}
|
||||
onToggle={onToggle}
|
||||
collapsible={true}
|
||||
isLastStep={isLastStep}
|
||||
isFirstStep={false}
|
||||
>
|
||||
{content}
|
||||
</StepContainer>
|
||||
),
|
||||
[activeStep?.packets]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tabs value={activeTab} onValueChange={setActiveTab}>
|
||||
<div className="flex flex-col w-full gap-1">
|
||||
<div className="flex w-full">
|
||||
{/* Left column: Icon + connector line */}
|
||||
<div className="flex flex-col items-center w-9 pt-2">
|
||||
<div className="size-4 flex items-center justify-center stroke-text-02">
|
||||
<SvgBranch className="w-4 h-4" />
|
||||
</div>
|
||||
{/* Connector line */}
|
||||
<div className="w-px flex-1 bg-border-01" />
|
||||
</div>
|
||||
|
||||
{/* Right column: Tabs */}
|
||||
<div className="flex-1">
|
||||
<Tabs.List variant="pill">
|
||||
{turnGroup.steps.map((step) => (
|
||||
<Tabs.Trigger key={step.key} value={step.key} variant="pill">
|
||||
<span className="flex items-center gap-1.5">
|
||||
{getToolIcon(step.packets)}
|
||||
{getToolName(step.packets)}
|
||||
</span>
|
||||
</Tabs.Trigger>
|
||||
))}
|
||||
</Tabs.List>
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<TimelineRendererComponent
|
||||
key={activeTab}
|
||||
packets={activeStep?.packets ?? []}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
defaultExpanded={true}
|
||||
isLastStep={isLastTurnGroup}
|
||||
>
|
||||
{renderTabContent}
|
||||
</TimelineRendererComponent>
|
||||
</div>
|
||||
</div>
|
||||
</Tabs>
|
||||
);
|
||||
}
|
||||
|
||||
export default ParallelTimelineTabs;
|
||||
@@ -0,0 +1,108 @@
|
||||
import React, { FunctionComponent } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { IconProps } from "@opal/types";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface StepContainerProps {
|
||||
/** Main content */
|
||||
children?: React.ReactNode;
|
||||
/** Step icon component */
|
||||
stepIcon?: FunctionComponent<IconProps>;
|
||||
/** Header left slot */
|
||||
header?: React.ReactNode;
|
||||
/** Button title for toggle */
|
||||
buttonTitle?: string;
|
||||
/** Controlled expanded state */
|
||||
isExpanded?: boolean;
|
||||
/** Toggle callback */
|
||||
onToggle?: () => void;
|
||||
/** Whether collapse control is shown */
|
||||
collapsible?: boolean;
|
||||
/** Collapse button shown only when renderer supports compact mode */
|
||||
supportsCompact?: boolean;
|
||||
/** Additional class names */
|
||||
className?: string;
|
||||
/** Last step (no bottom connector) */
|
||||
isLastStep?: boolean;
|
||||
/** First step (top padding instead of connector) */
|
||||
isFirstStep?: boolean;
|
||||
/** Hide header (single-step timelines) */
|
||||
hideHeader?: boolean;
|
||||
}
|
||||
|
||||
/** Visual wrapper for timeline steps - icon, connector line, header, and content */
|
||||
export function StepContainer({
|
||||
children,
|
||||
stepIcon: StepIconComponent,
|
||||
header,
|
||||
buttonTitle,
|
||||
isExpanded = true,
|
||||
onToggle,
|
||||
collapsible = true,
|
||||
supportsCompact = false,
|
||||
isLastStep = false,
|
||||
isFirstStep = false,
|
||||
className,
|
||||
hideHeader = false,
|
||||
}: StepContainerProps) {
|
||||
const showCollapseControls = collapsible && supportsCompact && onToggle;
|
||||
|
||||
return (
|
||||
<div className={cn("flex w-full", className)}>
|
||||
<div
|
||||
className={cn("flex flex-col items-center w-9", isFirstStep && "pt-2")}
|
||||
>
|
||||
{/* Icon */}
|
||||
{!hideHeader && StepIconComponent && (
|
||||
<div className="py-1">
|
||||
<StepIconComponent className="size-4 stroke-text-02" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Connector line */}
|
||||
{!isLastStep && <div className="w-px flex-1 bg-border-01" />}
|
||||
</div>
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
"w-full bg-background-tint-00",
|
||||
isLastStep && "rounded-b-12"
|
||||
)}
|
||||
>
|
||||
{!hideHeader && (
|
||||
<div className="flex items-center justify-between px-2">
|
||||
{header && (
|
||||
<Text as="p" mainUiMuted text03>
|
||||
{header}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{showCollapseControls &&
|
||||
(buttonTitle ? (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={isExpanded ? SvgFold : SvgExpand}
|
||||
>
|
||||
{buttonTitle}
|
||||
</Button>
|
||||
) : (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={isExpanded ? SvgFold : SvgExpand}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="px-2 pb-2">{children}</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default StepContainer;
|
||||
@@ -0,0 +1,116 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, JSX } from "react";
|
||||
import { Packet, StopReason } from "@/app/chat/services/streamingModels";
|
||||
import { FullChatState, RenderType, RendererResult } from "../interfaces";
|
||||
import { findRenderer } from "../renderMessageComponent";
|
||||
|
||||
/** Extended result that includes collapse state */
|
||||
export interface TimelineRendererResult extends RendererResult {
|
||||
/** Current expanded state */
|
||||
isExpanded: boolean;
|
||||
/** Toggle callback */
|
||||
onToggle: () => void;
|
||||
/** Current render type */
|
||||
renderType: RenderType;
|
||||
/** Whether this is the last step (passed through from props) */
|
||||
isLastStep: boolean;
|
||||
}
|
||||
|
||||
export interface TimelineRendererComponentProps {
|
||||
/** Packets to render */
|
||||
packets: Packet[];
|
||||
/** Chat state for rendering */
|
||||
chatState: FullChatState;
|
||||
/** Completion callback */
|
||||
onComplete: () => void;
|
||||
/** Whether to animate streaming */
|
||||
animate: boolean;
|
||||
/** Whether stop packet has been seen */
|
||||
stopPacketSeen: boolean;
|
||||
/** Reason for stopping */
|
||||
stopReason?: StopReason;
|
||||
/** Initial expanded state */
|
||||
defaultExpanded?: boolean;
|
||||
/** Whether this is the last step in the timeline (for connector line decisions) */
|
||||
isLastStep?: boolean;
|
||||
/** Children render function - receives extended result with collapse state */
|
||||
children: (result: TimelineRendererResult) => JSX.Element;
|
||||
}
|
||||
|
||||
// Custom comparison function to prevent unnecessary re-renders
|
||||
// Only re-render if meaningful changes occur
|
||||
function arePropsEqual(
|
||||
prev: TimelineRendererComponentProps,
|
||||
next: TimelineRendererComponentProps
|
||||
): boolean {
|
||||
return (
|
||||
prev.packets.length === next.packets.length &&
|
||||
prev.stopPacketSeen === next.stopPacketSeen &&
|
||||
prev.stopReason === next.stopReason &&
|
||||
prev.animate === next.animate &&
|
||||
prev.isLastStep === next.isLastStep &&
|
||||
prev.defaultExpanded === next.defaultExpanded
|
||||
// Skipping chatState (memoized upstream)
|
||||
);
|
||||
}
|
||||
|
||||
export const TimelineRendererComponent = React.memo(
|
||||
function TimelineRendererComponent({
|
||||
packets,
|
||||
chatState,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
defaultExpanded = true,
|
||||
isLastStep,
|
||||
children,
|
||||
}: TimelineRendererComponentProps) {
|
||||
const [isExpanded, setIsExpanded] = useState(defaultExpanded);
|
||||
const handleToggle = () => setIsExpanded((prev) => !prev);
|
||||
const RendererFn = findRenderer({ packets });
|
||||
const renderType = isExpanded ? RenderType.FULL : RenderType.COMPACT;
|
||||
|
||||
if (!RendererFn) {
|
||||
return children({
|
||||
icon: null,
|
||||
status: null,
|
||||
content: <></>,
|
||||
supportsCompact: false,
|
||||
isExpanded,
|
||||
onToggle: handleToggle,
|
||||
renderType,
|
||||
isLastStep: isLastStep ?? true,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<RendererFn
|
||||
packets={packets as any}
|
||||
state={chatState}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
renderType={renderType}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
isLastStep={isLastStep}
|
||||
>
|
||||
{({ icon, status, content, expandedText, supportsCompact }) =>
|
||||
children({
|
||||
icon,
|
||||
status,
|
||||
content,
|
||||
expandedText,
|
||||
supportsCompact,
|
||||
isExpanded,
|
||||
onToggle: handleToggle,
|
||||
renderType,
|
||||
isLastStep: isLastStep ?? true,
|
||||
})
|
||||
}
|
||||
</RendererFn>
|
||||
);
|
||||
},
|
||||
arePropsEqual
|
||||
);
|
||||
@@ -0,0 +1,49 @@
|
||||
import React from "react";
|
||||
import { SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import type { UniqueTool } from "@/app/chat/message/messageComponents/timeline/hooks";
|
||||
|
||||
export interface CollapsedHeaderProps {
|
||||
uniqueTools: UniqueTool[];
|
||||
totalSteps: number;
|
||||
collapsible: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header when completed + collapsed - tools summary + step count */
|
||||
export const CollapsedHeader = React.memo(function CollapsedHeader({
|
||||
uniqueTools,
|
||||
totalSteps,
|
||||
collapsible,
|
||||
onToggle,
|
||||
}: CollapsedHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<div className="flex items-center gap-2">
|
||||
{uniqueTools.map((tool) => (
|
||||
<div
|
||||
key={tool.key}
|
||||
className="inline-flex items-center gap-1 rounded-08 p-1 bg-background-tint-02"
|
||||
>
|
||||
{tool.icon}
|
||||
<Text as="span" secondaryBody text04>
|
||||
{tool.name}
|
||||
</Text>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
{collapsible && (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={SvgExpand}
|
||||
aria-label="Expand timeline"
|
||||
aria-expanded={false}
|
||||
>
|
||||
{totalSteps} {totalSteps === 1 ? "step" : "steps"}
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,32 @@
|
||||
import React from "react";
|
||||
import { SvgFold } from "@opal/icons";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface ExpandedHeaderProps {
|
||||
collapsible: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header when completed + expanded */
|
||||
export const ExpandedHeader = React.memo(function ExpandedHeader({
|
||||
collapsible,
|
||||
onToggle,
|
||||
}: ExpandedHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<Text as="p" mainUiAction text03>
|
||||
Thought for some time
|
||||
</Text>
|
||||
{collapsible && (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={SvgFold}
|
||||
aria-label="Collapse timeline"
|
||||
aria-expanded={true}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,53 @@
|
||||
import React from "react";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { TurnGroup } from "../transformers";
|
||||
import { getToolIcon, getToolName } from "../../toolDisplayHelpers";
|
||||
|
||||
export interface ParallelStreamingHeaderProps {
|
||||
steps: TurnGroup["steps"];
|
||||
activeTab: string;
|
||||
onTabChange: (tab: string) => void;
|
||||
collapsible: boolean;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header during streaming with parallel tools - tabs only */
|
||||
export const ParallelStreamingHeader = React.memo(
|
||||
function ParallelStreamingHeader({
|
||||
steps,
|
||||
activeTab,
|
||||
onTabChange,
|
||||
collapsible,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: ParallelStreamingHeaderProps) {
|
||||
return (
|
||||
<Tabs value={activeTab} onValueChange={onTabChange}>
|
||||
<div className="flex items-center justify-between w-full gap-2">
|
||||
<Tabs.List variant="pill">
|
||||
{steps.map((step) => (
|
||||
<Tabs.Trigger key={step.key} value={step.key} variant="pill">
|
||||
<span className="flex items-center gap-1.5">
|
||||
{getToolIcon(step.packets)}
|
||||
{getToolName(step.packets)}
|
||||
</span>
|
||||
</Tabs.Trigger>
|
||||
))}
|
||||
</Tabs.List>
|
||||
{collapsible && (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
|
||||
aria-expanded={isExpanded}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Tabs>
|
||||
);
|
||||
}
|
||||
);
|
||||
@@ -0,0 +1,38 @@
|
||||
import React from "react";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface StoppedHeaderProps {
|
||||
totalSteps: number;
|
||||
collapsible: boolean;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header when user stopped/cancelled */
|
||||
export const StoppedHeader = React.memo(function StoppedHeader({
|
||||
totalSteps,
|
||||
collapsible,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: StoppedHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<Text as="p" mainUiAction text03>
|
||||
Stopped Thinking
|
||||
</Text>
|
||||
{collapsible && (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
|
||||
aria-expanded={isExpanded}
|
||||
>
|
||||
{totalSteps} {totalSteps === 1 ? "step" : "steps"}
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,54 @@
|
||||
import React from "react";
|
||||
import { SvgFold, SvgExpand } from "@opal/icons";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
export interface StreamingHeaderProps {
|
||||
headerText: string;
|
||||
collapsible: boolean;
|
||||
buttonTitle?: string;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
/** Header during streaming - shimmer text with current activity */
|
||||
export const StreamingHeader = React.memo(function StreamingHeader({
|
||||
headerText,
|
||||
collapsible,
|
||||
buttonTitle,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
}: StreamingHeaderProps) {
|
||||
return (
|
||||
<>
|
||||
<Text
|
||||
as="p"
|
||||
mainUiAction
|
||||
text03
|
||||
className="animate-shimmer bg-[length:200%_100%] bg-[linear-gradient(90deg,var(--shimmer-base)_10%,var(--shimmer-highlight)_40%,var(--shimmer-base)_70%)] bg-clip-text text-transparent"
|
||||
>
|
||||
{headerText}
|
||||
</Text>
|
||||
{collapsible &&
|
||||
(buttonTitle ? (
|
||||
<Button
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
rightIcon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-expanded={isExpanded}
|
||||
>
|
||||
{buttonTitle}
|
||||
</Button>
|
||||
) : (
|
||||
<IconButton
|
||||
tertiary
|
||||
onClick={onToggle}
|
||||
icon={isExpanded ? SvgFold : SvgExpand}
|
||||
aria-label={isExpanded ? "Collapse timeline" : "Expand timeline"}
|
||||
aria-expanded={isExpanded}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,14 @@
|
||||
export { StreamingHeader } from "./StreamingHeader";
|
||||
export type { StreamingHeaderProps } from "./StreamingHeader";
|
||||
|
||||
export { CollapsedHeader } from "./CollapsedHeader";
|
||||
export type { CollapsedHeaderProps } from "./CollapsedHeader";
|
||||
|
||||
export { ExpandedHeader } from "./ExpandedHeader";
|
||||
export type { ExpandedHeaderProps } from "./ExpandedHeader";
|
||||
|
||||
export { StoppedHeader } from "./StoppedHeader";
|
||||
export type { StoppedHeaderProps } from "./StoppedHeader";
|
||||
|
||||
export { ParallelStreamingHeader } from "./ParallelStreamingHeader";
|
||||
export type { ParallelStreamingHeaderProps } from "./ParallelStreamingHeader";
|
||||
@@ -0,0 +1,11 @@
|
||||
export { useTimelineExpansion } from "./useTimelineExpansion";
|
||||
export type { TimelineExpansionState } from "./useTimelineExpansion";
|
||||
|
||||
export { useTimelineMetrics } from "./useTimelineMetrics";
|
||||
export type { TimelineMetrics, UniqueTool } from "./useTimelineMetrics";
|
||||
|
||||
export { usePacketProcessor } from "./usePacketProcessor";
|
||||
export type { UsePacketProcessorResult } from "./usePacketProcessor";
|
||||
|
||||
export { useTimelineHeader } from "./useTimelineHeader";
|
||||
export type { TimelineHeaderResult } from "./useTimelineHeader";
|
||||
@@ -0,0 +1,439 @@
|
||||
import {
|
||||
Packet,
|
||||
PacketType,
|
||||
StreamingCitation,
|
||||
StopReason,
|
||||
CitationInfo,
|
||||
SearchToolDocumentsDelta,
|
||||
FetchToolDocuments,
|
||||
TopLevelBranching,
|
||||
Stop,
|
||||
SearchToolStart,
|
||||
CustomToolStart,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { CitationMap } from "@/app/chat/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import {
|
||||
isActualToolCallPacket,
|
||||
isToolPacket,
|
||||
isDisplayPacket,
|
||||
} from "@/app/chat/services/packetUtils";
|
||||
import { parseToolKey } from "@/app/chat/message/messageComponents/toolDisplayHelpers";
|
||||
|
||||
// Re-export parseToolKey for consumers that import from this module
|
||||
export { parseToolKey };
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface ProcessorState {
|
||||
nodeId: number;
|
||||
lastProcessedIndex: number;
|
||||
|
||||
// Citations
|
||||
citations: StreamingCitation[];
|
||||
seenCitationDocIds: Set<string>;
|
||||
citationMap: CitationMap;
|
||||
|
||||
// Documents
|
||||
documentMap: Map<string, OnyxDocument>;
|
||||
|
||||
// Packet grouping
|
||||
groupedPacketsMap: Map<string, Packet[]>;
|
||||
seenGroupKeys: Set<string>;
|
||||
groupKeysWithSectionEnd: Set<string>;
|
||||
expectedBranches: Map<number, number>;
|
||||
|
||||
// Pre-categorized groups (populated during packet processing)
|
||||
toolGroupKeys: Set<string>;
|
||||
displayGroupKeys: Set<string>;
|
||||
|
||||
// Unique tool names tracking (populated during packet processing)
|
||||
uniqueToolNames: Set<string>;
|
||||
|
||||
// Streaming status
|
||||
finalAnswerComing: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason: StopReason | undefined;
|
||||
|
||||
// Result arrays (built at end of processPackets)
|
||||
toolGroups: GroupedPacket[];
|
||||
potentialDisplayGroups: GroupedPacket[];
|
||||
uniqueToolNamesArray: string[];
|
||||
}
|
||||
|
||||
export interface GroupedPacket {
|
||||
turn_index: number;
|
||||
tab_index: number;
|
||||
packets: Packet[];
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// State Creation
|
||||
// ============================================================================
|
||||
|
||||
export function createInitialState(nodeId: number): ProcessorState {
|
||||
return {
|
||||
nodeId,
|
||||
lastProcessedIndex: 0,
|
||||
citations: [],
|
||||
seenCitationDocIds: new Set(),
|
||||
citationMap: {},
|
||||
documentMap: new Map(),
|
||||
groupedPacketsMap: new Map(),
|
||||
seenGroupKeys: new Set(),
|
||||
groupKeysWithSectionEnd: new Set(),
|
||||
expectedBranches: new Map(),
|
||||
toolGroupKeys: new Set(),
|
||||
displayGroupKeys: new Set(),
|
||||
uniqueToolNames: new Set(),
|
||||
finalAnswerComing: false,
|
||||
stopPacketSeen: false,
|
||||
stopReason: undefined,
|
||||
toolGroups: [],
|
||||
potentialDisplayGroups: [],
|
||||
uniqueToolNamesArray: [],
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
function getGroupKey(packet: Packet): string {
|
||||
const turnIndex = packet.placement.turn_index;
|
||||
const tabIndex = packet.placement.tab_index ?? 0;
|
||||
return `${turnIndex}-${tabIndex}`;
|
||||
}
|
||||
|
||||
function injectSectionEnd(state: ProcessorState, groupKey: string): void {
|
||||
if (state.groupKeysWithSectionEnd.has(groupKey)) {
|
||||
return; // Already has SECTION_END
|
||||
}
|
||||
|
||||
const { turn_index, tab_index } = parseToolKey(groupKey);
|
||||
|
||||
const syntheticPacket: Packet = {
|
||||
placement: { turn_index, tab_index },
|
||||
obj: { type: PacketType.SECTION_END },
|
||||
};
|
||||
|
||||
const existingGroup = state.groupedPacketsMap.get(groupKey);
|
||||
if (existingGroup) {
|
||||
existingGroup.push(syntheticPacket);
|
||||
}
|
||||
state.groupKeysWithSectionEnd.add(groupKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Content packet types that indicate a group has meaningful content to display
|
||||
*/
|
||||
const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
PacketType.MESSAGE_START,
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.IMAGE_GENERATION_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
PacketType.REASONING_START,
|
||||
PacketType.DEEP_RESEARCH_PLAN_START,
|
||||
PacketType.RESEARCH_AGENT_START,
|
||||
]);
|
||||
|
||||
function hasContentPackets(packets: Packet[]): boolean {
|
||||
return packets.some((packet) =>
|
||||
CONTENT_PACKET_TYPES_SET.has(packet.obj.type as PacketType)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract tool name from a packet for unique tool tracking.
|
||||
* Returns null for non-tool packets.
|
||||
*/
|
||||
function getToolNameFromPacket(packet: Packet): string | null {
|
||||
switch (packet.obj.type) {
|
||||
case PacketType.SEARCH_TOOL_START: {
|
||||
const searchPacket = packet.obj as SearchToolStart;
|
||||
return searchPacket.is_internet_search ? "Web Search" : "Internal Search";
|
||||
}
|
||||
case PacketType.PYTHON_TOOL_START:
|
||||
return "Code Interpreter";
|
||||
case PacketType.FETCH_TOOL_START:
|
||||
return "Open URLs";
|
||||
case PacketType.CUSTOM_TOOL_START: {
|
||||
const customPacket = packet.obj as CustomToolStart;
|
||||
return customPacket.tool_name || "Custom Tool";
|
||||
}
|
||||
case PacketType.IMAGE_GENERATION_TOOL_START:
|
||||
return "Generate Image";
|
||||
case PacketType.DEEP_RESEARCH_PLAN_START:
|
||||
return "Generate plan";
|
||||
case PacketType.RESEARCH_AGENT_START:
|
||||
return "Research agent";
|
||||
case PacketType.REASONING_START:
|
||||
return "Thinking";
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Packet types that indicate final answer content is coming
|
||||
*/
|
||||
const FINAL_ANSWER_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
PacketType.MESSAGE_START,
|
||||
PacketType.MESSAGE_DELTA,
|
||||
PacketType.IMAGE_GENERATION_TOOL_START,
|
||||
PacketType.IMAGE_GENERATION_TOOL_DELTA,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_DELTA,
|
||||
]);
|
||||
|
||||
// ============================================================================
|
||||
// Packet Handlers
|
||||
// ============================================================================
|
||||
|
||||
function handleTopLevelBranching(state: ProcessorState, packet: Packet): void {
|
||||
const branchingPacket = packet.obj as TopLevelBranching;
|
||||
state.expectedBranches.set(
|
||||
packet.placement.turn_index,
|
||||
branchingPacket.num_parallel_branches
|
||||
);
|
||||
}
|
||||
|
||||
function handleTurnTransition(state: ProcessorState, packet: Packet): void {
|
||||
const currentTurnIndex = packet.placement.turn_index;
|
||||
|
||||
// Get all previous turn indices from seen group keys
|
||||
const previousTurnIndices = new Set(
|
||||
Array.from(state.seenGroupKeys).map((key) => parseToolKey(key).turn_index)
|
||||
);
|
||||
|
||||
const isNewTurnIndex = !previousTurnIndices.has(currentTurnIndex);
|
||||
|
||||
// If we see a new turn_index (not just tab_index), inject SECTION_END for previous groups
|
||||
if (isNewTurnIndex && state.seenGroupKeys.size > 0) {
|
||||
state.seenGroupKeys.forEach((prevGroupKey) => {
|
||||
if (!state.groupKeysWithSectionEnd.has(prevGroupKey)) {
|
||||
injectSectionEnd(state, prevGroupKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleCitationPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (packet.obj.type !== PacketType.CITATION_INFO) {
|
||||
return;
|
||||
}
|
||||
|
||||
const citationInfo = packet.obj as CitationInfo;
|
||||
|
||||
// Add to citation map immediately for rendering
|
||||
state.citationMap[citationInfo.citation_number] = citationInfo.document_id;
|
||||
|
||||
// Also add to citations array for CitedSourcesToggle (deduplicated)
|
||||
if (!state.seenCitationDocIds.has(citationInfo.document_id)) {
|
||||
state.seenCitationDocIds.add(citationInfo.document_id);
|
||||
state.citations.push({
|
||||
citation_num: citationInfo.citation_number,
|
||||
document_id: citationInfo.document_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleDocumentPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (packet.obj.type === PacketType.SEARCH_TOOL_DOCUMENTS_DELTA) {
|
||||
const docDelta = packet.obj as SearchToolDocumentsDelta;
|
||||
if (docDelta.documents) {
|
||||
for (const doc of docDelta.documents) {
|
||||
if (doc.document_id) {
|
||||
state.documentMap.set(doc.document_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (packet.obj.type === PacketType.FETCH_TOOL_DOCUMENTS) {
|
||||
const fetchDocuments = packet.obj as FetchToolDocuments;
|
||||
if (fetchDocuments.documents) {
|
||||
for (const doc of fetchDocuments.documents) {
|
||||
if (doc.document_id) {
|
||||
state.documentMap.set(doc.document_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamingStatusPacket(
|
||||
state: ProcessorState,
|
||||
packet: Packet
|
||||
): void {
|
||||
// Check if final answer is coming
|
||||
if (FINAL_ANSWER_PACKET_TYPES_SET.has(packet.obj.type as PacketType)) {
|
||||
state.finalAnswerComing = true;
|
||||
}
|
||||
}
|
||||
|
||||
function handleStopPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (packet.obj.type !== PacketType.STOP || state.stopPacketSeen) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.stopPacketSeen = true;
|
||||
|
||||
// Extract and store the stop reason
|
||||
const stopPacket = packet.obj as Stop;
|
||||
state.stopReason = stopPacket.stop_reason;
|
||||
|
||||
// Inject SECTION_END for all group keys that don't have one
|
||||
state.seenGroupKeys.forEach((groupKey) => {
|
||||
if (!state.groupKeysWithSectionEnd.has(groupKey)) {
|
||||
injectSectionEnd(state, groupKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleToolAfterMessagePacket(
|
||||
state: ProcessorState,
|
||||
packet: Packet
|
||||
): void {
|
||||
// Handles case where we get a Message packet from Claude, and then tool
|
||||
// calling packets. We use isActualToolCallPacket instead of isToolPacket
|
||||
// to exclude reasoning packets - reasoning is just the model thinking,
|
||||
// not an actual tool call that would produce new content.
|
||||
if (
|
||||
state.finalAnswerComing &&
|
||||
!state.stopPacketSeen &&
|
||||
isActualToolCallPacket(packet)
|
||||
) {
|
||||
state.finalAnswerComing = false;
|
||||
}
|
||||
}
|
||||
|
||||
function addPacketToGroup(
|
||||
state: ProcessorState,
|
||||
packet: Packet,
|
||||
groupKey: string
|
||||
): void {
|
||||
const existingGroup = state.groupedPacketsMap.get(groupKey);
|
||||
if (existingGroup) {
|
||||
existingGroup.push(packet);
|
||||
} else {
|
||||
state.groupedPacketsMap.set(groupKey, [packet]);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Processing Function
|
||||
// ============================================================================
|
||||
|
||||
function processPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (!packet) return;
|
||||
|
||||
// Handle TopLevelBranching packets - these tell us how many parallel branches to expect
|
||||
if (packet.obj.type === PacketType.TOP_LEVEL_BRANCHING) {
|
||||
handleTopLevelBranching(state, packet);
|
||||
// Don't add this packet to any group, it's just metadata
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle turn transitions (inject SECTION_END for previous groups)
|
||||
handleTurnTransition(state, packet);
|
||||
|
||||
// Track group key
|
||||
const groupKey = getGroupKey(packet);
|
||||
state.seenGroupKeys.add(groupKey);
|
||||
|
||||
// Track SECTION_END and ERROR packets (both indicate completion)
|
||||
if (
|
||||
packet.obj.type === PacketType.SECTION_END ||
|
||||
packet.obj.type === PacketType.ERROR
|
||||
) {
|
||||
state.groupKeysWithSectionEnd.add(groupKey);
|
||||
}
|
||||
|
||||
// Check if this is the first packet in the group (before adding)
|
||||
const existingGroup = state.groupedPacketsMap.get(groupKey);
|
||||
const isFirstPacket = !existingGroup;
|
||||
|
||||
// Add packet to group
|
||||
addPacketToGroup(state, packet, groupKey);
|
||||
|
||||
// Categorize on first packet of each group
|
||||
if (isFirstPacket) {
|
||||
if (isToolPacket(packet, false)) {
|
||||
state.toolGroupKeys.add(groupKey);
|
||||
// Track unique tool name
|
||||
const toolName = getToolNameFromPacket(packet);
|
||||
if (toolName) {
|
||||
state.uniqueToolNames.add(toolName);
|
||||
}
|
||||
}
|
||||
if (isDisplayPacket(packet)) {
|
||||
state.displayGroupKeys.add(groupKey);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle specific packet types
|
||||
handleCitationPacket(state, packet);
|
||||
handleDocumentPacket(state, packet);
|
||||
handleStreamingStatusPacket(state, packet);
|
||||
handleStopPacket(state, packet);
|
||||
handleToolAfterMessagePacket(state, packet);
|
||||
}
|
||||
|
||||
export function processPackets(
|
||||
state: ProcessorState,
|
||||
rawPackets: Packet[]
|
||||
): ProcessorState {
|
||||
// Handle reset (packets array shrunk - upstream replaced with shorter list)
|
||||
if (state.lastProcessedIndex > rawPackets.length) {
|
||||
state = createInitialState(state.nodeId);
|
||||
}
|
||||
|
||||
// Process only new packets
|
||||
for (let i = state.lastProcessedIndex; i < rawPackets.length; i++) {
|
||||
const packet = rawPackets[i];
|
||||
if (packet) {
|
||||
processPacket(state, packet);
|
||||
}
|
||||
}
|
||||
|
||||
state.lastProcessedIndex = rawPackets.length;
|
||||
|
||||
// Build result arrays after processing
|
||||
state.toolGroups = buildGroupsFromKeys(state, state.toolGroupKeys);
|
||||
state.potentialDisplayGroups = buildGroupsFromKeys(
|
||||
state,
|
||||
state.displayGroupKeys
|
||||
);
|
||||
state.uniqueToolNamesArray = Array.from(state.uniqueToolNames);
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build GroupedPacket array from a set of group keys.
|
||||
* Filters to only include groups with meaningful content and sorts by turn/tab index.
|
||||
*/
|
||||
function buildGroupsFromKeys(
|
||||
state: ProcessorState,
|
||||
keys: Set<string>
|
||||
): GroupedPacket[] {
|
||||
return Array.from(keys)
|
||||
.map((key) => {
|
||||
const { turn_index, tab_index } = parseToolKey(key);
|
||||
const packets = state.groupedPacketsMap.get(key);
|
||||
// Spread to create new array reference - ensures React detects changes for re-renders
|
||||
return packets ? { turn_index, tab_index, packets: [...packets] } : null;
|
||||
})
|
||||
.filter(
|
||||
(g): g is GroupedPacket => g !== null && hasContentPackets(g.packets)
|
||||
)
|
||||
.sort((a, b) => {
|
||||
if (a.turn_index !== b.turn_index) {
|
||||
return a.turn_index - b.turn_index;
|
||||
}
|
||||
return a.tab_index - b.tab_index;
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
import { useRef, useState, useMemo, useCallback } from "react";
|
||||
import {
|
||||
Packet,
|
||||
StreamingCitation,
|
||||
StopReason,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { CitationMap } from "@/app/chat/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import {
|
||||
ProcessorState,
|
||||
GroupedPacket,
|
||||
createInitialState,
|
||||
processPackets,
|
||||
} from "@/app/chat/message/messageComponents/timeline/hooks/packetProcessor";
|
||||
import {
|
||||
transformPacketGroups,
|
||||
groupStepsByTurn,
|
||||
TurnGroup,
|
||||
} from "@/app/chat/message/messageComponents/timeline/transformers";
|
||||
|
||||
export interface UsePacketProcessorResult {
|
||||
// Data
|
||||
toolGroups: GroupedPacket[];
|
||||
displayGroups: GroupedPacket[];
|
||||
toolTurnGroups: TurnGroup[];
|
||||
citations: StreamingCitation[];
|
||||
citationMap: CitationMap;
|
||||
documentMap: Map<string, OnyxDocument>;
|
||||
|
||||
// Status (derived from packets)
|
||||
stopPacketSeen: boolean;
|
||||
stopReason: StopReason | undefined;
|
||||
hasSteps: boolean;
|
||||
expectedBranchesPerTurn: Map<number, number>;
|
||||
uniqueToolNames: string[];
|
||||
|
||||
// Completion: stopPacketSeen && renderComplete
|
||||
isComplete: boolean;
|
||||
|
||||
// Callbacks
|
||||
onRenderComplete: () => void;
|
||||
markAllToolsDisplayed: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for processing streaming packets in AgentMessage.
|
||||
*
|
||||
* Architecture:
|
||||
* - Processor state in ref: incremental processing, synchronous, no double render
|
||||
* - Only true UI state: renderComplete (set by callback), forceShowAnswer (override)
|
||||
* - Everything else derived from packets
|
||||
*
|
||||
* Key insight: finalAnswerComing and stopPacketSeen are DERIVED from packets,
|
||||
* not independent state. Only renderComplete needs useState.
|
||||
*/
|
||||
export function usePacketProcessor(
|
||||
rawPackets: Packet[],
|
||||
nodeId: number
|
||||
): UsePacketProcessorResult {
|
||||
// Processor in ref: incremental, synchronous, no double render
|
||||
const stateRef = useRef<ProcessorState>(createInitialState(nodeId));
|
||||
|
||||
// Only TRUE UI state: "has renderer finished?"
|
||||
const [renderComplete, setRenderComplete] = useState(false);
|
||||
|
||||
// Optional override to force showing answer
|
||||
const [forceShowAnswer, setForceShowAnswer] = useState(false);
|
||||
|
||||
// Reset on nodeId change
|
||||
if (stateRef.current.nodeId !== nodeId) {
|
||||
stateRef.current = createInitialState(nodeId);
|
||||
setRenderComplete(false);
|
||||
setForceShowAnswer(false);
|
||||
}
|
||||
|
||||
// Track for transition detection
|
||||
const prevLastProcessed = stateRef.current.lastProcessedIndex;
|
||||
const prevFinalAnswerComing = stateRef.current.finalAnswerComing;
|
||||
|
||||
// Detect stream reset (packets shrunk)
|
||||
if (prevLastProcessed > rawPackets.length) {
|
||||
stateRef.current = createInitialState(nodeId);
|
||||
setRenderComplete(false);
|
||||
setForceShowAnswer(false);
|
||||
}
|
||||
|
||||
// Process packets synchronously (incremental) - only if new packets arrived
|
||||
if (rawPackets.length > stateRef.current.lastProcessedIndex) {
|
||||
stateRef.current = processPackets(stateRef.current, rawPackets);
|
||||
}
|
||||
|
||||
// Reset renderComplete on tool-after-message transition
|
||||
if (prevFinalAnswerComing && !stateRef.current.finalAnswerComing) {
|
||||
setRenderComplete(false);
|
||||
}
|
||||
|
||||
// Access state directly (result arrays are built in processPackets)
|
||||
const state = stateRef.current;
|
||||
|
||||
// Derive displayGroups (not state!)
|
||||
const effectiveFinalAnswerComing = state.finalAnswerComing || forceShowAnswer;
|
||||
const displayGroups = useMemo(() => {
|
||||
if (effectiveFinalAnswerComing || state.toolGroups.length === 0) {
|
||||
return state.potentialDisplayGroups;
|
||||
}
|
||||
return [];
|
||||
}, [
|
||||
effectiveFinalAnswerComing,
|
||||
state.toolGroups.length,
|
||||
state.potentialDisplayGroups,
|
||||
]);
|
||||
|
||||
// Transform toolGroups to timeline format
|
||||
const toolTurnGroups = useMemo(() => {
|
||||
const allSteps = transformPacketGroups(state.toolGroups);
|
||||
return groupStepsByTurn(allSteps);
|
||||
}, [state.toolGroups]);
|
||||
|
||||
// Callback reads from ref: always current value, no ref needed in component
|
||||
const onRenderComplete = useCallback(() => {
|
||||
if (stateRef.current.finalAnswerComing) {
|
||||
setRenderComplete(true);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const markAllToolsDisplayed = useCallback(() => {
|
||||
setForceShowAnswer(true);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
// Data
|
||||
toolGroups: state.toolGroups,
|
||||
displayGroups,
|
||||
toolTurnGroups,
|
||||
citations: state.citations,
|
||||
citationMap: state.citationMap,
|
||||
documentMap: state.documentMap,
|
||||
|
||||
// Status (derived from packets)
|
||||
stopPacketSeen: state.stopPacketSeen,
|
||||
stopReason: state.stopReason,
|
||||
hasSteps: toolTurnGroups.length > 0,
|
||||
expectedBranchesPerTurn: state.expectedBranches,
|
||||
uniqueToolNames: state.uniqueToolNamesArray,
|
||||
|
||||
// Completion: stopPacketSeen && renderComplete
|
||||
isComplete: state.stopPacketSeen && renderComplete,
|
||||
|
||||
// Callbacks
|
||||
onRenderComplete,
|
||||
markAllToolsDisplayed,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { TurnGroup } from "../transformers";
|
||||
|
||||
export interface TimelineExpansionState {
|
||||
isExpanded: boolean;
|
||||
handleToggle: () => void;
|
||||
parallelActiveTab: string;
|
||||
setParallelActiveTab: (tab: string) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages expansion state for the timeline.
|
||||
* Auto-collapses when streaming completes and syncs parallel tab selection.
|
||||
*/
|
||||
export function useTimelineExpansion(
|
||||
stopPacketSeen: boolean,
|
||||
lastTurnGroup: TurnGroup | undefined
|
||||
): TimelineExpansionState {
|
||||
const [isExpanded, setIsExpanded] = useState(!stopPacketSeen);
|
||||
const [parallelActiveTab, setParallelActiveTab] = useState<string>("");
|
||||
|
||||
const handleToggle = useCallback(() => {
|
||||
setIsExpanded((prev) => !prev);
|
||||
}, []);
|
||||
|
||||
// Auto-collapse when streaming completes
|
||||
useEffect(() => {
|
||||
if (stopPacketSeen) {
|
||||
setIsExpanded(false);
|
||||
}
|
||||
}, [stopPacketSeen]);
|
||||
|
||||
// Sync active tab when parallel turn group changes
|
||||
useEffect(() => {
|
||||
if (lastTurnGroup?.isParallel && lastTurnGroup.steps.length > 0) {
|
||||
const validTabs = lastTurnGroup.steps.map((s) => s.key);
|
||||
const firstStep = lastTurnGroup.steps[0];
|
||||
if (firstStep && !validTabs.includes(parallelActiveTab)) {
|
||||
setParallelActiveTab(firstStep.key);
|
||||
}
|
||||
}
|
||||
}, [lastTurnGroup, parallelActiveTab]);
|
||||
|
||||
return {
|
||||
isExpanded,
|
||||
handleToggle,
|
||||
parallelActiveTab,
|
||||
setParallelActiveTab,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
import { useMemo } from "react";
|
||||
import { TurnGroup } from "../transformers";
|
||||
import {
|
||||
PacketType,
|
||||
SearchToolPacket,
|
||||
StopReason,
|
||||
CustomToolStart,
|
||||
} from "@/app/chat/services/streamingModels";
|
||||
import { constructCurrentSearchState } from "@/app/chat/message/messageComponents/timeline/renderers/search/searchStateUtils";
|
||||
|
||||
export interface TimelineHeaderResult {
|
||||
headerText: string;
|
||||
hasPackets: boolean;
|
||||
userStopped: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook that determines timeline header state based on current activity.
|
||||
* Returns header text, whether there are packets, and whether user stopped.
|
||||
*/
|
||||
export function useTimelineHeader(
|
||||
turnGroups: TurnGroup[],
|
||||
stopReason?: StopReason
|
||||
): TimelineHeaderResult {
|
||||
return useMemo(() => {
|
||||
const hasPackets = turnGroups.length > 0;
|
||||
const userStopped = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
if (!hasPackets) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
// Get the last (current) turn group
|
||||
const currentTurn = turnGroups[turnGroups.length - 1];
|
||||
if (!currentTurn) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
const currentStep = currentTurn.steps[0];
|
||||
if (!currentStep?.packets?.length) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
const firstPacket = currentStep.packets[0];
|
||||
if (!firstPacket) {
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
const packetType = firstPacket.obj.type;
|
||||
|
||||
// Determine header based on packet type
|
||||
if (packetType === PacketType.SEARCH_TOOL_START) {
|
||||
const searchState = constructCurrentSearchState(
|
||||
currentStep.packets as SearchToolPacket[]
|
||||
);
|
||||
const headerText = searchState.isInternetSearch
|
||||
? "Searching web"
|
||||
: "Searching internal documents";
|
||||
return { headerText, hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.FETCH_TOOL_START) {
|
||||
return { headerText: "Opening URLs", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.PYTHON_TOOL_START) {
|
||||
return { headerText: "Executing code", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.IMAGE_GENERATION_TOOL_START) {
|
||||
return { headerText: "Generating images", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.CUSTOM_TOOL_START) {
|
||||
const toolName = (firstPacket.obj as CustomToolStart).tool_name;
|
||||
return {
|
||||
headerText: toolName ? `Executing ${toolName}` : "Executing tool",
|
||||
hasPackets,
|
||||
userStopped,
|
||||
};
|
||||
}
|
||||
|
||||
if (packetType === PacketType.REASONING_START) {
|
||||
return { headerText: "Thinking", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.DEEP_RESEARCH_PLAN_START) {
|
||||
return { headerText: "Generating plan", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
if (packetType === PacketType.RESEARCH_AGENT_START) {
|
||||
return { headerText: "Researching", hasPackets, userStopped };
|
||||
}
|
||||
|
||||
return { headerText: "Thinking...", hasPackets, userStopped };
|
||||
}, [turnGroups, stopReason]);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user