mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-06 07:22:42 +00:00
Compare commits
49 Commits
add-featur
...
v2.11.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afc163d7e5 | ||
|
|
0f214ca190 | ||
|
|
422ca91edc | ||
|
|
fe287eebb6 | ||
|
|
3f8ef8b465 | ||
|
|
ed46504a1a | ||
|
|
7a24b34516 | ||
|
|
7a7ffa9051 | ||
|
|
3053ab518c | ||
|
|
be38d3500f | ||
|
|
753a3bc093 | ||
|
|
2ba8fafe78 | ||
|
|
b77b580ebd | ||
|
|
3eee98b932 | ||
|
|
a97eb02fef | ||
|
|
c5061495a2 | ||
|
|
c20b0789ae | ||
|
|
d99848717b | ||
|
|
aaca55c415 | ||
|
|
9d7ffd1e4a | ||
|
|
a249161827 | ||
|
|
e126346a91 | ||
|
|
a96682fa73 | ||
|
|
3920371d56 | ||
|
|
e5a257345c | ||
|
|
a49df511e2 | ||
|
|
d5d2a8a1a6 | ||
|
|
b2f46b264c | ||
|
|
c6ad363fbd | ||
|
|
e313119f9a | ||
|
|
3a2a542a03 | ||
|
|
413aeba4a1 | ||
|
|
46028aa2bb | ||
|
|
454943c4a6 | ||
|
|
87946266de | ||
|
|
144030c5ca | ||
|
|
a557d76041 | ||
|
|
605e808158 | ||
|
|
8fec88c90d | ||
|
|
e54969a693 | ||
|
|
1da2b2f28f | ||
|
|
eb7b91e08e | ||
|
|
3339000968 | ||
|
|
d9db849e94 | ||
|
|
046408359c | ||
|
|
4b8cca190f | ||
|
|
52a312a63b | ||
|
|
0594fd17de | ||
|
|
fded81dc28 |
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
39
.vscode/launch.json
vendored
39
.vscode/launch.json
vendored
@@ -149,6 +149,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
@@ -587,6 +605,27 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
|
||||
@@ -42,7 +42,9 @@ RUN apt-get update && \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
vim \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -124,6 +126,13 @@ ENV PYTHONPATH=/app
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
|
||||
# in long-running Python processes (API server, Celery workers).
|
||||
# The soname is architecture-independent; the dynamic linker resolves
|
||||
# the correct path from standard library directories.
|
||||
# Placed after all RUN steps so build-time processes are unaffected.
|
||||
ENV LD_PRELOAD=libjemalloc.so.2
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -97,10 +97,14 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
@@ -12,6 +12,7 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
DOCUMENT_SYNC_PREFIX = "documentsync"
|
||||
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
|
||||
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
|
||||
r.delete(DOCUMENT_SYNC_FENCE_KEY)
|
||||
return
|
||||
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
|
||||
|
||||
# Add to the tracking taskset in Redis BEFORE creating the celery task
|
||||
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
|
||||
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
|
||||
|
||||
# Create the Celery task
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -153,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -341,6 +352,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
@@ -423,6 +435,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
NOTE: Whatever formatting strategy is used here to generate a key-value
|
||||
string must be replicated when constructing query filters.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_to_metadata_value(value: Any) -> str | list[str]:
|
||||
"""Convert a Salesforce field value to a valid metadata value.
|
||||
|
||||
Document metadata expects str | list[str], but Salesforce returns
|
||||
various types (bool, float, int, etc.). This function ensures all
|
||||
values are properly converted to strings.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
return str(value)
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
# # gc.collect()
|
||||
# return all_types
|
||||
|
||||
def _yield_doc_batches(
|
||||
self,
|
||||
sf_db: OnyxSalesforceSQLite,
|
||||
type_to_processed: dict[str, int],
|
||||
changed_ids_to_type: dict[str, str],
|
||||
parent_types: set[str],
|
||||
increment_parents_changed: Callable[[], None],
|
||||
) -> GenerateDocumentsOutput:
|
||||
""" """
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
parent_object.data[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
increment_parents_changed()
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def _full_sync(
|
||||
self,
|
||||
temp_dir: str,
|
||||
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
examined_ids = 0
|
||||
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
f"records={num_records}"
|
||||
)
|
||||
|
||||
# yield an empty list to keep the connector alive
|
||||
yield docs_to_yield
|
||||
|
||||
new_ids = sf_db.update_from_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=ctx.parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = (
|
||||
type_to_processed.get(parent_type, 0) + 1
|
||||
)
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
def increment_parents_changed() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
yield from self._yield_doc_batches(
|
||||
sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
ctx.parent_types,
|
||||
increment_parents_changed,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
|
||||
if sf_attribute in record:
|
||||
doc.metadata[canonical_attribute] = record[sf_attribute]
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
record[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
return return_context
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if MULTI_TENANT:
|
||||
# if multi tenant, we cannot expect the sqlite db to be cached/present
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._full_sync(temp_dir)
|
||||
|
||||
# nuke the db since we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
|
||||
os.remove(sqlite_db_path)
|
||||
return self._full_sync(BASE_DATA_PATH)
|
||||
# Always use a temp directory for SQLite - the database is rebuilt
|
||||
# from scratch each time via CSV downloads, so there's no caching benefit
|
||||
# from persisting it. Using temp dirs also avoids collisions between
|
||||
# multiple CC pairs and eliminates stale WAL/SHM file issues.
|
||||
# TODO(evan): make this thing checkpointed and persist/load db from filestore
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield from self._full_sync(temp_dir)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source will synchronize updated parent objects one by one."""
|
||||
|
||||
if start == 0:
|
||||
# nuke the db if we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(
|
||||
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
|
||||
)
|
||||
os.remove(sqlite_db_path)
|
||||
|
||||
return self._delta_sync(BASE_DATA_PATH, start, end)
|
||||
|
||||
# Always use a temp directory - see comment in load_from_state()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
yield from self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
SQLITE_DISK_IO_ERROR = "disk I/O error"
|
||||
|
||||
|
||||
class OnyxSalesforceSQLite:
|
||||
"""Notes on context management using 'with self.conn':
|
||||
|
||||
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
|
||||
def apply_schema(self) -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist.
|
||||
|
||||
Non-destructive operation.
|
||||
Non-destructive operation. If a disk I/O error is encountered (often due
|
||||
to stale WAL/SHM files from a previous crash), this method will attempt
|
||||
to recover by removing the corrupted files and recreating the database.
|
||||
"""
|
||||
try:
|
||||
self._apply_schema_impl()
|
||||
except sqlite3.OperationalError as e:
|
||||
if SQLITE_DISK_IO_ERROR not in str(e):
|
||||
raise
|
||||
|
||||
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
|
||||
self._recover_from_corruption()
|
||||
self._apply_schema_impl()
|
||||
|
||||
def _recover_from_corruption(self) -> None:
|
||||
"""Recover from SQLite corruption by removing all database files and reconnecting."""
|
||||
logger.info(f"Removing corrupted SQLite files: {self.filename}")
|
||||
|
||||
# Close existing connection
|
||||
self.close()
|
||||
|
||||
# Remove all SQLite files (main db, WAL, SHM)
|
||||
remove_sqlite_db_files(self.filename)
|
||||
|
||||
# Reconnect - this will create a fresh database
|
||||
self.connect()
|
||||
|
||||
logger.info("SQLite recovery complete, fresh database created")
|
||||
|
||||
def _apply_schema_impl(self) -> None:
|
||||
"""Internal implementation of apply_schema."""
|
||||
if self._conn is None:
|
||||
raise RuntimeError("Database connection is closed")
|
||||
|
||||
|
||||
@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
|
||||
return os.path.join(directory, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def remove_sqlite_db_files(db_path: str) -> None:
|
||||
"""Remove SQLite database and all associated files (WAL, SHM).
|
||||
|
||||
SQLite in WAL mode creates additional files:
|
||||
- .sqlite-wal: Write-ahead log
|
||||
- .sqlite-shm: Shared memory file
|
||||
|
||||
If these files become stale (e.g., after a crash), they can cause
|
||||
'disk I/O error' when trying to open the database. This function
|
||||
ensures all related files are removed.
|
||||
"""
|
||||
files_to_remove = [
|
||||
db_path,
|
||||
f"{db_path}-wal",
|
||||
f"{db_path}-shm",
|
||||
]
|
||||
for file_path in files_to_remove:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
# NOTE: only used with shelves, deprecated at this point
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
|
||||
@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
# NOTE: These strings must be formatted in the same way as the output of
|
||||
# DocumentAccess::to_acl.
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
@@ -327,6 +329,15 @@ def delete_user_from_db(
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
# Null out ownership on document sets and personas so they're
|
||||
# preserved for other users instead of being cascade-deleted
|
||||
db_session.query(DocumentSet).filter(
|
||||
DocumentSet.user_id == user_to_delete.id
|
||||
).update({DocumentSet.user_id: None})
|
||||
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
|
||||
@@ -28,8 +28,8 @@ of "minimum value clipping".
|
||||
## On time decay and boosting
|
||||
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
|
||||
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
|
||||
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
|
||||
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
|
||||
Same logic applies to additive boosting.
|
||||
|
||||
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
|
||||
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
|
||||
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
|
||||
being fetched and returned to the user. But there are other issues of including these:
|
||||
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
|
||||
contents. If there are lots of updates, this may miss
|
||||
contents. If there are lots of updates, this may miss.
|
||||
- There is not a good way to normalize this field, the best is to clip it on the bottom.
|
||||
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
|
||||
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@@ -68,6 +70,18 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def generate_opensearch_filtered_access_control_list(
|
||||
access: DocumentAccess,
|
||||
) -> list[str]:
|
||||
"""Generates an access control list with PUBLIC_DOC_PAT removed.
|
||||
|
||||
In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
|
||||
"""
|
||||
access_control_list = access.to_acl()
|
||||
access_control_list.discard(PUBLIC_DOC_PAT)
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -152,10 +166,9 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
metadata_suffix=chunk.metadata_suffix_keyword,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
chunk.access
|
||||
),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
image_file_id=chunk.image_file_id,
|
||||
@@ -578,8 +591,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here so we don't have to think about passing in the
|
||||
# appropriate types into this dict.
|
||||
if update_request.access is not None:
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
|
||||
update_request.access.to_acl()
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
|
||||
generate_opensearch_filtered_access_control_list(
|
||||
update_request.access
|
||||
)
|
||||
)
|
||||
if update_request.document_sets is not None:
|
||||
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
|
||||
@@ -625,13 +640,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
# TODO(andrei): Remove this from the new interface at some point; we
|
||||
# should not be exposing this.
|
||||
batch_retrieval: bool = False,
|
||||
# TODO(andrei): Add a param for whether to retrieve hidden docs.
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
TODO(andrei): Consider implementing this method to retrieve on document
|
||||
@@ -646,6 +659,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
max_chunk_size=chunk_request.max_chunk_size,
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
@@ -672,9 +687,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
@@ -688,6 +700,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_candidates=1000, # TODO(andrei): Magic number.
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
|
||||
@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
def serialize_datetime_fields_to_epoch_seconds(
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
Serializes datetime fields to seconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
return int(value.timestamp())
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses seconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
|
||||
@@ -354,11 +353,9 @@ class DocumentSchema:
|
||||
},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
"format": "epoch_second",
|
||||
# For some reason date defaults to False, even though it
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
@@ -366,14 +363,21 @@ class DocumentSchema:
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
# is its own field.
|
||||
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
# should have no effect on queries.
|
||||
PUBLIC_FIELD_NAME: {"type": "boolean"},
|
||||
# Access control list for the doc, excluding public access,
|
||||
# which is covered above.
|
||||
# If a user's access set contains at least one entry from this
|
||||
# set, the user should be able to retrieve this document. This
|
||||
# only applies if public is set to false; public non-hidden
|
||||
# documents are always visible to anyone in a given tenancy
|
||||
# regardless of this field.
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# Whether the doc is hidden from search results. Should clobber
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
# Whether the doc is hidden from search results.
|
||||
# Should clobber all other access search filters, namely
|
||||
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
|
||||
# search implementations to guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
@@ -447,7 +451,6 @@ class DocumentSchema:
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
@@ -91,6 +106,11 @@ assert (
|
||||
# given search. This value is configurable in the index settings.
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
|
||||
|
||||
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
|
||||
# that the document was last updated this many days ago for the purpose of time
|
||||
# cutoff filtering during retrieval.
|
||||
ASSUMED_DOCUMENT_AGE_DAYS = 90
|
||||
|
||||
|
||||
class DocumentQuery:
|
||||
"""
|
||||
@@ -103,6 +123,8 @@ class DocumentQuery:
|
||||
def get_from_document_id_query(
|
||||
document_id: str,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
max_chunk_size: int,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -120,6 +142,8 @@ class DocumentQuery:
|
||||
document_id: Onyx document ID. Notably not an OpenSearch document
|
||||
ID, which points to what Onyx would refer to as a chunk.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the document retrieval query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
max_chunk_size: Document chunks are categorized by the maximum
|
||||
number of tokens they can hold. This parameter specifies the
|
||||
maximum size category of document chunks to retrieve.
|
||||
@@ -136,28 +160,21 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final ID search query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
filter_clauses.append(range_clause)
|
||||
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
max_chunk_size=max_chunk_size,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
final_get_ids_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
# We include this to make sure OpenSearch does not revert to
|
||||
@@ -195,15 +212,22 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final delete query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
# Delete hidden docs too.
|
||||
include_hidden=True,
|
||||
access_control_list=None,
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=document_id,
|
||||
)
|
||||
final_delete_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
}
|
||||
@@ -217,19 +241,25 @@ class DocumentQuery:
|
||||
num_candidates: int,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final hybrid search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
NOTE: This query can be directly supplied to the OpenSearch client, but
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
num_candidates: The number of candidates to consider for vector
|
||||
num_candidates: The number of neighbors to consider for vector
|
||||
similarity search. Generally more candidates improves search
|
||||
quality at the cost of performance.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the hybrid search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final hybrid search query.
|
||||
@@ -243,31 +273,47 @@ class DocumentQuery:
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, num_candidates
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
}
|
||||
}
|
||||
],
|
||||
# TODO(andrei): When revisiting our hybrid query logic see if
|
||||
# this needs to be nested one level down.
|
||||
"filter": hybrid_search_filters,
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Applied to all the sub-queries. Source:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# Does AND for each filter in the list.
|
||||
"filter": {"bool": {"filter": hybrid_search_filters}},
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: By default, hybrid search retrieves "size"-many results from
|
||||
# each OpenSearch shard before aggregation. Source:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
}
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
@@ -294,7 +340,8 @@ class DocumentQuery:
|
||||
pipeline.
|
||||
|
||||
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
|
||||
in a single hybrid query.
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -305,6 +352,7 @@ class DocumentQuery:
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the title.
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -313,6 +361,7 @@ class DocumentQuery:
|
||||
},
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the content.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -322,36 +371,273 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
# Either fuzzy match on the analyzed title (boosted 2x), or
|
||||
# exact match on exact title keywords (no OpenSearch
|
||||
# analysis done on the title). See
|
||||
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
# Returns the score of the best match of the fields above.
|
||||
# See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
|
||||
"type": "best_fields",
|
||||
}
|
||||
},
|
||||
# Fuzzy match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
|
||||
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
|
||||
# Exact match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
|
||||
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
|
||||
]
|
||||
|
||||
return hybrid_search_queries
|
||||
|
||||
@staticmethod
|
||||
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
|
||||
"""Returns filters for hybrid search.
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
include_hidden: bool,
|
||||
access_control_list: list[str] | None,
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
max_chunk_size: int | None = None,
|
||||
document_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns filters to be passed into the "filter" key of a search query.
|
||||
|
||||
For now only fetches public and not hidden documents.
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
|
||||
The return of this function is not sufficient to be directly supplied to
|
||||
the OpenSearch client. See get_hybrid_search_query.
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
access_control_list: Access control list for the documents to
|
||||
retrieve. If None, there is no restriction on the documents that
|
||||
can be retrieved. If not None, only public documents can be
|
||||
retrieved, or non-public documents where at least one acl
|
||||
provided here is present in the document's acl list.
|
||||
source_types: If supplied, only documents of one of these source
|
||||
types will be retrieved.
|
||||
tags: If supplied, only documents with an entry in their metadata
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
updated time, we assume some default age of
|
||||
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
|
||||
updated.
|
||||
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
|
||||
None, no minimum chunk index will be applied.
|
||||
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
|
||||
None, no maximum chunk index will be applied.
|
||||
max_chunk_size: The type of chunk to retrieve, specified by the
|
||||
maximum number of tokens it can hold. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
|
||||
TODO(andrei): Add ACL filters and stuff.
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
"""
|
||||
hybrid_search_filters: list[dict[str, Any]] = [
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
)
|
||||
return tag_filter
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
|
||||
}
|
||||
}
|
||||
)
|
||||
if time_cutoff < datetime.now(timezone.utc) - timedelta(
|
||||
days=ASSUMED_DOCUMENT_AGE_DAYS
|
||||
):
|
||||
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
|
||||
# ago, we include documents which have no
|
||||
# LAST_UPDATED_FIELD_NAME value.
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
|
||||
}
|
||||
}
|
||||
)
|
||||
return time_cutoff_filter
|
||||
|
||||
def _get_chunk_index_filter(
|
||||
min_chunk_index: int | None, max_chunk_index: int | None
|
||||
) -> dict[str, Any]:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
return range_clause
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
|
||||
|
||||
if access_control_list is not None:
|
||||
# If an access control list is provided, the caller can only
|
||||
# retrieve public documents, and non-public documents where at least
|
||||
# one acl provided here is present in the document's acl list. If
|
||||
# there is explicitly no list provided, we make no restrictions on
|
||||
# the documents that can be retrieved.
|
||||
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
|
||||
|
||||
if source_types:
|
||||
# If at least one source type is provided, the caller will only
|
||||
# retrieve documents whose source type is present in this input
|
||||
# list.
|
||||
filter_clauses.append(_get_source_type_filter(source_types))
|
||||
|
||||
if tags:
|
||||
# If at least one tag is provided, the caller will only retrieve
|
||||
# documents where at least one tag provided here is present in the
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
if document_sets:
|
||||
# If at least one document set is provided, the caller will only
|
||||
# retrieve documents where at least one document set provided here
|
||||
# is present in the document's document sets list.
|
||||
filter_clauses.append(_get_document_set_filter(document_sets))
|
||||
|
||||
if user_file_ids:
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs. Note that these IDs correspond to Onyx documents whereas
|
||||
# the entries retrieved from the document index correspond to Onyx
|
||||
# document chunks.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
# cutoff. For documents which do not have a value for
|
||||
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
|
||||
# purposes of time cutoff.
|
||||
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
filter_clauses.append(
|
||||
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
if max_chunk_size is not None:
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
)
|
||||
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
return hybrid_search_filters
|
||||
|
||||
return filter_clauses
|
||||
|
||||
@staticmethod
|
||||
def _get_match_highlights_configuration() -> dict[str, Any]:
|
||||
@@ -378,4 +664,5 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return match_highlights_configuration
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# Discord Bot Multitenant Architecture
|
||||
|
||||
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
|
||||
|
||||
## Overview
|
||||
|
||||
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
|
||||
|
||||
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
|
||||
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ OnyxDiscordClient │
|
||||
│ │
|
||||
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
|
||||
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
|
||||
│ │ tenant_id → api_key │ │ message, │ │
|
||||
│ │ │ │ api_key=<per-tenant>, │ │
|
||||
│ └─────────────────────────┘ │ persona_id=... │ │
|
||||
│ │ ) │ │
|
||||
│ └─────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Component Details
|
||||
|
||||
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
|
||||
|
||||
The `DiscordCacheManager` maintains two critical in-memory mappings:
|
||||
|
||||
```python
|
||||
class DiscordCacheManager:
|
||||
_guild_tenants: dict[int, str] # guild_id → tenant_id
|
||||
_api_keys: dict[str, str] # tenant_id → api_key
|
||||
_lock: asyncio.Lock # Concurrency control
|
||||
```
|
||||
|
||||
#### Key Responsibilities
|
||||
|
||||
| Function | Purpose |
|
||||
|----------|---------|
|
||||
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
|
||||
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
|
||||
| `refresh_all()` | Full cache rebuild from database |
|
||||
| `refresh_guild()` | Incremental update for single guild |
|
||||
|
||||
#### API Key Provisioning Strategy
|
||||
|
||||
API keys are **lazily provisioned** - only created when first needed:
|
||||
|
||||
```python
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
needs_key = tenant_id not in self._api_keys
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# Load guild configs
|
||||
configs = get_discord_bot_configs(db)
|
||||
guild_ids = [c.guild_id for c in configs if c.enabled]
|
||||
|
||||
# Only provision API key if not already cached
|
||||
api_key = None
|
||||
if needs_key:
|
||||
api_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
|
||||
return guild_ids, api_key
|
||||
```
|
||||
|
||||
This optimization avoids repeated database calls for API key generation.
|
||||
|
||||
#### Concurrency Control
|
||||
|
||||
All write operations acquire an async lock to prevent race conditions:
|
||||
|
||||
```python
|
||||
async def refresh_all(self) -> None:
|
||||
async with self._lock:
|
||||
# Safe to modify _guild_tenants and _api_keys
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
# Update mappings...
|
||||
```
|
||||
|
||||
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
|
||||
|
||||
---
|
||||
|
||||
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
|
||||
|
||||
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
|
||||
|
||||
#### Key Design: Per-Request API Key Injection
|
||||
|
||||
```python
|
||||
class OnyxAPIClient:
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str, # Injected per-request
|
||||
persona_id: int | None,
|
||||
...
|
||||
) -> ChatFullResponse:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
|
||||
}
|
||||
# Make request...
|
||||
```
|
||||
|
||||
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
|
||||
|
||||
```python
|
||||
# Same client, different tenants
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coordination Flow
|
||||
|
||||
### Message Processing Pipeline
|
||||
|
||||
When a Discord message arrives, the client coordinates cache and API client:
|
||||
|
||||
```python
|
||||
async def on_message(self, message: Message) -> None:
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Step 1: Cache lookup - guild → tenant
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
if not tenant_id:
|
||||
return # Guild not registered
|
||||
|
||||
# Step 2: Cache lookup - tenant → API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Step 3: API call with tenant-specific credentials
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key, # Tenant-specific
|
||||
persona_id=persona_id, # Tenant-specific
|
||||
api_client=self.api_client,
|
||||
)
|
||||
```
|
||||
|
||||
### Startup Sequence
|
||||
|
||||
```python
|
||||
async def setup_hook(self) -> None:
|
||||
# 1. Initialize API client (create aiohttp session)
|
||||
await self.api_client.initialize()
|
||||
|
||||
# 2. Populate cache with all tenants
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# 3. Start background refresh task
|
||||
self._cache_refresh_task = self.loop.create_task(
|
||||
self._periodic_cache_refresh() # Every 60 seconds
|
||||
)
|
||||
```
|
||||
|
||||
### Shutdown Sequence
|
||||
|
||||
```python
|
||||
async def close(self) -> None:
|
||||
# 1. Cancel background refresh
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
|
||||
# 2. Close Discord connection
|
||||
await super().close()
|
||||
|
||||
# 3. Close API client session
|
||||
await self.api_client.close()
|
||||
|
||||
# 4. Clear cache
|
||||
self.cache.clear()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tenant Isolation Mechanisms
|
||||
|
||||
### 1. Per-Tenant API Keys
|
||||
|
||||
Each tenant has a dedicated service API key:
|
||||
|
||||
```python
|
||||
# backend/onyx/db/discord_bot.py
|
||||
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
return regenerate_key(existing)
|
||||
|
||||
# Create LIMITED role key (chat-only permissions)
|
||||
return insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Minimal permissions
|
||||
),
|
||||
user_id=None, # Service account (system-owned)
|
||||
).api_key
|
||||
```
|
||||
|
||||
### 2. Database Context Variables
|
||||
|
||||
The cache uses context variables for proper tenant-scoped DB sessions:
|
||||
|
||||
```python
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# All DB operations scoped to this tenant
|
||||
...
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
```
|
||||
|
||||
### 3. Enterprise Gating Support
|
||||
|
||||
Gated tenants are filtered during cache refresh:
|
||||
|
||||
```python
|
||||
gated_tenants = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
if tenant_id in gated_tenants:
|
||||
continue # Skip gated tenants
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Refresh Strategy
|
||||
|
||||
| Trigger | Method | Scope |
|
||||
|---------|--------|-------|
|
||||
| Startup | `refresh_all()` | All tenants |
|
||||
| Periodic (60s) | `refresh_all()` | All tenants |
|
||||
| Guild registration | `refresh_guild()` | Single tenant |
|
||||
|
||||
### Error Handling
|
||||
|
||||
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
|
||||
- **Missing API key**: Bot silently ignores messages from that guild
|
||||
- **Network errors**: Logged, cache continues with stale data until next refresh
|
||||
|
||||
---
|
||||
|
||||
## Key Design Insights
|
||||
|
||||
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
|
||||
|
||||
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
|
||||
|
||||
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
|
||||
|
||||
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
|
||||
|
||||
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
|
||||
|
||||
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
|
||||
|
||||
---
|
||||
|
||||
## File References
|
||||
|
||||
| Component | Path |
|
||||
|-----------|------|
|
||||
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
|
||||
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
|
||||
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
|
||||
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
|
||||
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
|
||||
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,65 +1,270 @@
|
||||
from mistune import Markdown # type: ignore[import-untyped]
|
||||
from mistune import Renderer
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
return Markdown(renderer=SlackRenderer()).render(message)
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(Renderer):
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def header(self, text: str, level: int, raw: str | None = None) -> str:
|
||||
return f"*{text}*\n"
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
|
||||
def double_emphasis(self, text: str) -> str:
|
||||
def strong(self, text: str) -> str:
|
||||
return f"*{text}*"
|
||||
|
||||
def strikethrough(self, text: str) -> str:
|
||||
return f"~{text}~"
|
||||
|
||||
def list(self, body: str, ordered: bool = True) -> str:
|
||||
lines = body.split("\n")
|
||||
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
|
||||
lines = text.split("\n")
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("li: "):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
|
||||
def link(self, link: str, title: str | None, content: str | None) -> str:
|
||||
escaped_link = self.escape_special(link)
|
||||
if content:
|
||||
return f"<{escaped_link}|{content}>"
|
||||
def link(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
if text:
|
||||
return f"<{escaped_url}|{text}>"
|
||||
if title:
|
||||
return f"<{escaped_link}|{title}>"
|
||||
return f"<{escaped_link}>"
|
||||
return f"<{escaped_url}|{title}>"
|
||||
return f"<{escaped_url}>"
|
||||
|
||||
def image(self, src: str, title: str | None, text: str | None) -> str:
|
||||
escaped_src = self.escape_special(src)
|
||||
def image(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
display_text = title or text
|
||||
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
|
||||
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
|
||||
|
||||
def codespan(self, text: str) -> str:
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, text: str, lang: str | None) -> str:
|
||||
return f"```\n{text}\n```\n"
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
|
||||
def autolink(self, link: str, is_email: bool) -> str:
|
||||
return link if is_email else self.link(link, None, None)
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -32,6 +32,7 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -45,6 +45,7 @@ class RedisConnectorPrune:
|
||||
) # connectorpruning_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
|
||||
@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -84,7 +84,8 @@ def patch_document_set(
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
object_is_public=document_set_update_request.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
try:
|
||||
update_document_set(
|
||||
@@ -125,7 +126,8 @@ def delete_document_set(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
object_is_public=document_set.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
@router.get("/", tags=PUBLIC_API_TAGS)
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import user_needs_to_be_verified
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
|
||||
|
||||
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
|
||||
async def get_auth_type() -> AuthTypeResponse:
|
||||
user_count = await get_user_count()
|
||||
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
|
||||
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
|
||||
# associated with until they auth.
|
||||
has_users = True
|
||||
if AUTH_TYPE != AuthType.CLOUD:
|
||||
user_count = await get_user_count()
|
||||
has_users = user_count > 0
|
||||
|
||||
return AuthTypeResponse(
|
||||
auth_type=AUTH_TYPE,
|
||||
requires_verification=user_needs_to_be_verified(),
|
||||
anonymous_user_enabled=anonymous_user_enabled(),
|
||||
password_min_length=PASSWORD_MIN_LENGTH,
|
||||
has_users=user_count > 0,
|
||||
has_users=has_users,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import remove_chat_message_feedback
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -266,7 +267,35 @@ def get_chat_session(
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
try:
|
||||
# If we failed to get a chat session, try to retrieve the session with
|
||||
# less restrictive filters in order to identify what exactly mismatched
|
||||
# so we can bubble up an accurate error code andmessage.
|
||||
existing_chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
is_shared=False,
|
||||
include_deleted=True,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
if not include_deleted and existing_chat_session.deleted:
|
||||
raise HTTPException(status_code=404, detail="Chat session has been deleted")
|
||||
|
||||
if is_shared:
|
||||
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Chat session is not shared"
|
||||
)
|
||||
elif user_id is not None and existing_chat_session.user_id not in (
|
||||
user_id,
|
||||
None,
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
# for chat-seeding: if the session is unassigned, assign it now. This is done here
|
||||
# to avoid another back and forth between FE -> BE before starting the first
|
||||
|
||||
@@ -580,7 +580,7 @@ def translate_assistant_message_to_packets(
|
||||
# Determine stop reason - check if message indicates user cancelled
|
||||
stop_reason: str | None = None
|
||||
if chat_message.message:
|
||||
if "Generation was stopped" in chat_message.message:
|
||||
if "generation was stopped" in chat_message.message.lower():
|
||||
stop_reason = "user_cancelled"
|
||||
|
||||
# Add overall stop packet at the end
|
||||
|
||||
@@ -573,7 +573,7 @@ mcp==1.25.0
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==0.8.4
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
|
||||
@@ -298,7 +298,7 @@ numpy==2.4.1
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.4.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -191,6 +191,18 @@ autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Listens for Discord messages and responds with answers
|
||||
# for all guilds/channels that the OnyxBot has been added to.
|
||||
# If not configured, will continue to probe every 3 minutes for a Discord bot token.
|
||||
[program:discord_bot]
|
||||
command=python onyx/onyxbot/discord/client.py
|
||||
stdout_logfile=/var/log/discord_bot.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Pushes all logs from the above programs to stdout
|
||||
# No log rotation here, since it's stdout it's handled by the Docker container logging
|
||||
[program:log-redirect-handler]
|
||||
@@ -206,6 +218,7 @@ command=tail -qF
|
||||
/var/log/celery_worker_user_file_processing.log
|
||||
/var/log/celery_worker_docfetching.log
|
||||
/var/log/slack_bot.log
|
||||
/var/log/discord_bot.log
|
||||
/var/log/supervisord_watchdog_celery_beat.log
|
||||
/var/log/mcp_server.log
|
||||
/var/log/mcp_server.err.log
|
||||
|
||||
@@ -14,6 +14,10 @@ from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
# Counter for generating unique file IDs in mock file store
|
||||
_mock_file_id_counter = 0
|
||||
|
||||
|
||||
def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
"""Ensure a default LLM provider exists for tests that exercise chat flows."""
|
||||
|
||||
@@ -80,11 +84,34 @@ def mock_vespa_query() -> Iterator[None]:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store() -> Iterator[None]:
|
||||
"""Mock the file store to avoid S3/storage dependencies in tests."""
|
||||
global _mock_file_id_counter
|
||||
|
||||
def _mock_save_file(*args: Any, **kwargs: Any) -> str:
|
||||
global _mock_file_id_counter
|
||||
_mock_file_id_counter += 1
|
||||
# Return a predictable file ID for tests
|
||||
return "123"
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_file.side_effect = _mock_save_file
|
||||
mock_store.initialize.return_value = None
|
||||
|
||||
with patch(
|
||||
"onyx.file_store.utils.get_default_file_store",
|
||||
return_value=mock_store,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_deps(
|
||||
mock_nlp_embeddings_post: None,
|
||||
mock_gpu_status: None,
|
||||
mock_vespa_query: None,
|
||||
mock_file_store: None,
|
||||
) -> Iterator[None]:
|
||||
"""Convenience fixture to enable all common external dependency mocks."""
|
||||
yield
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
|
||||
|
||||
|
||||
def assert_answer_stream_part_correct(
|
||||
received: AnswerStreamPart, expected: AnswerStreamPart
|
||||
) -> None:
|
||||
assert isinstance(received, type(expected))
|
||||
|
||||
if isinstance(received, Packet):
|
||||
r_packet = cast(Packet, received)
|
||||
e_packet = cast(Packet, expected)
|
||||
|
||||
assert r_packet.placement == e_packet.placement
|
||||
|
||||
if isinstance(r_packet.obj, SearchToolDocumentsDelta):
|
||||
assert isinstance(e_packet.obj, SearchToolDocumentsDelta)
|
||||
assert is_search_tool_document_delta_equal(r_packet.obj, e_packet.obj)
|
||||
return
|
||||
elif isinstance(r_packet.obj, OpenUrlDocuments):
|
||||
assert isinstance(e_packet.obj, OpenUrlDocuments)
|
||||
assert is_open_url_documents_equal(r_packet.obj, e_packet.obj)
|
||||
return
|
||||
elif isinstance(r_packet.obj, AgentResponseStart):
|
||||
assert isinstance(e_packet.obj, AgentResponseStart)
|
||||
assert is_agent_response_start_equal(r_packet.obj, e_packet.obj)
|
||||
return
|
||||
elif isinstance(r_packet.obj, ImageGenerationFinal):
|
||||
assert isinstance(e_packet.obj, ImageGenerationFinal)
|
||||
assert is_image_generation_final_equal(r_packet.obj, e_packet.obj)
|
||||
return
|
||||
|
||||
assert r_packet.obj == e_packet.obj
|
||||
elif isinstance(received, MessageResponseIDInfo):
|
||||
# We're not going to make assumptions about what the user id / assistant id should be
|
||||
# So just return
|
||||
return
|
||||
elif isinstance(received, CreateChatSessionID):
|
||||
# Don't worry about same session ids
|
||||
return
|
||||
else:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
|
||||
def _are_search_docs_equal(
|
||||
received: list[SearchDoc],
|
||||
expected: list[SearchDoc],
|
||||
) -> bool:
|
||||
"""
|
||||
What we care about:
|
||||
- All documents are present (order does not)
|
||||
- Expected document_id, link, blurb, source_type and hidden
|
||||
"""
|
||||
if len(received) != len(expected):
|
||||
return False
|
||||
|
||||
received.sort(key=lambda x: x.document_id)
|
||||
expected.sort(key=lambda x: x.document_id)
|
||||
|
||||
for received_document, expected_document in zip(received, expected):
|
||||
if received_document.document_id != expected_document.document_id:
|
||||
return False
|
||||
if received_document.link != expected_document.link:
|
||||
return False
|
||||
if received_document.blurb != expected_document.blurb:
|
||||
return False
|
||||
if received_document.source_type != expected_document.source_type:
|
||||
return False
|
||||
if received_document.hidden != expected_document.hidden:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_search_tool_document_delta_equal(
|
||||
received: SearchToolDocumentsDelta,
|
||||
expected: SearchToolDocumentsDelta,
|
||||
) -> bool:
|
||||
"""
|
||||
What we care about:
|
||||
- All documents are present (order does not)
|
||||
- Expected document_id, link, blurb, source_type and hidden
|
||||
"""
|
||||
received_documents = received.documents
|
||||
expected_documents = expected.documents
|
||||
|
||||
return _are_search_docs_equal(received_documents, expected_documents)
|
||||
|
||||
|
||||
def is_open_url_documents_equal(
|
||||
received: OpenUrlDocuments,
|
||||
expected: OpenUrlDocuments,
|
||||
) -> bool:
|
||||
"""
|
||||
What we care about:
|
||||
- All documents are present (order does not)
|
||||
- Expected document_id, link, blurb, source_type and hidden
|
||||
"""
|
||||
received_documents = received.documents
|
||||
expected_documents = expected.documents
|
||||
|
||||
return _are_search_docs_equal(received_documents, expected_documents)
|
||||
|
||||
|
||||
def is_agent_response_start_equal(
|
||||
received: AgentResponseStart,
|
||||
expected: AgentResponseStart,
|
||||
) -> bool:
|
||||
"""
|
||||
What we care about:
|
||||
- All documents are present (order does not)
|
||||
- Expected document_id, link, blurb, source_type and hidden
|
||||
"""
|
||||
received_documents = received.final_documents
|
||||
expected_documents = expected.final_documents
|
||||
|
||||
if received_documents is None and expected_documents is None:
|
||||
return True
|
||||
if not received_documents or not expected_documents:
|
||||
return False
|
||||
|
||||
return _are_search_docs_equal(received_documents, expected_documents)
|
||||
|
||||
|
||||
def is_image_generation_final_equal(
|
||||
received: ImageGenerationFinal,
|
||||
expected: ImageGenerationFinal,
|
||||
) -> bool:
|
||||
"""
|
||||
What we care about:
|
||||
- Number of images are the same
|
||||
- On each image, url and file_id are aligned such that url=/api/chat/file/{file_id}
|
||||
- Revised prompt is expected
|
||||
- Shape is expected
|
||||
"""
|
||||
if len(received.images) != len(expected.images):
|
||||
return False
|
||||
|
||||
for received_image, expected_image in zip(received.images, expected.images):
|
||||
if received_image.url != f"/api/chat/file/{received_image.file_id}":
|
||||
return False
|
||||
if received_image.revised_prompt != expected_image.revised_prompt:
|
||||
return False
|
||||
if received_image.shape != expected_image.shape:
|
||||
return False
|
||||
return True
|
||||
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDone
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from tests.external_dependency_unit.answer.stream_test_assertions import (
|
||||
assert_answer_stream_part_correct,
|
||||
)
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import (
|
||||
create_packet_with_agent_response_delta,
|
||||
)
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import (
|
||||
create_packet_with_reasoning_delta,
|
||||
)
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_placement
|
||||
from tests.external_dependency_unit.mock_llm import LLMResponse
|
||||
from tests.external_dependency_unit.mock_llm import MockLLMController
|
||||
|
||||
|
||||
class StreamTestBuilder:
|
||||
def __init__(self, llm_controller: MockLLMController) -> None:
|
||||
self._llm_controller = llm_controller
|
||||
|
||||
# List of (expected_packet, forward_count) tuples
|
||||
self._expected_packets_queue: list[tuple[Packet, int]] = []
|
||||
|
||||
def add_response(self, response: LLMResponse) -> StreamTestBuilder:
|
||||
self._llm_controller.add_response(response)
|
||||
|
||||
return self
|
||||
|
||||
def add_responses_together(self, *responses: LLMResponse) -> StreamTestBuilder:
|
||||
"""Add multiple responses that should be emitted together in the same tick."""
|
||||
self._llm_controller.add_responses_together(*responses)
|
||||
|
||||
return self
|
||||
|
||||
def expect(
|
||||
self, expected_pkt: Packet, forward: int | bool = True
|
||||
) -> StreamTestBuilder:
|
||||
"""
|
||||
Add an expected packet to the queue.
|
||||
|
||||
Args:
|
||||
expected_pkt: The packet to expect
|
||||
forward: Number of tokens to forward before expecting this packet.
|
||||
True = 1 token, False = 0 tokens, int = that many tokens.
|
||||
"""
|
||||
forward_count = 1 if forward is True else (0 if forward is False else forward)
|
||||
self._expected_packets_queue.append((expected_pkt, forward_count))
|
||||
|
||||
return self
|
||||
|
||||
def expect_packets(
|
||||
self, packets: list[Packet], forward: int | bool = True
|
||||
) -> StreamTestBuilder:
|
||||
"""
|
||||
Add multiple expected packets to the queue.
|
||||
|
||||
Args:
|
||||
packets: List of packets to expect
|
||||
forward: Number of tokens to forward before expecting EACH packet.
|
||||
True = 1 token per packet, False = 0 tokens, int = that many tokens per packet.
|
||||
"""
|
||||
forward_count = 1 if forward is True else (0 if forward is False else forward)
|
||||
for pkt in packets:
|
||||
self._expected_packets_queue.append((pkt, forward_count))
|
||||
|
||||
return self
|
||||
|
||||
def expect_reasoning(
|
||||
self,
|
||||
reasoning_tokens: list[str],
|
||||
turn_index: int,
|
||||
) -> StreamTestBuilder:
|
||||
return (
|
||||
self.expect(
|
||||
Packet(
|
||||
placement=create_placement(turn_index),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
)
|
||||
.expect_packets(
|
||||
[
|
||||
create_packet_with_reasoning_delta(token, turn_index)
|
||||
for token in reasoning_tokens
|
||||
]
|
||||
)
|
||||
.expect(
|
||||
Packet(
|
||||
placement=create_placement(turn_index),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def expect_agent_response(
|
||||
self,
|
||||
answer_tokens: list[str],
|
||||
turn_index: int,
|
||||
final_documents: list[SearchDoc] | None = None,
|
||||
) -> StreamTestBuilder:
|
||||
return (
|
||||
self.expect(
|
||||
Packet(
|
||||
placement=create_placement(turn_index),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
),
|
||||
)
|
||||
)
|
||||
.expect_packets(
|
||||
[
|
||||
create_packet_with_agent_response_delta(token, turn_index)
|
||||
for token in answer_tokens
|
||||
]
|
||||
)
|
||||
.expect(
|
||||
Packet(
|
||||
placement=create_placement(turn_index),
|
||||
obj=OverallStop(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def run_and_validate(self, stream: Iterator[AnswerStreamPart]) -> None:
|
||||
while self._expected_packets_queue:
|
||||
expected_pkt, forward_count = self._expected_packets_queue.pop(0)
|
||||
if forward_count > 0:
|
||||
self._llm_controller.forward(forward_count)
|
||||
received_pkt = next(stream)
|
||||
|
||||
assert_answer_stream_part_correct(received_pkt, expected_pkt)
|
||||
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_utils import create_chat_session_from_request
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import User
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from tests.external_dependency_unit.mock_content_provider import MockWebContent
|
||||
from tests.external_dependency_unit.mock_search_provider import MockWebSearchResult
|
||||
|
||||
|
||||
def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def submit_query(
|
||||
query: str, chat_session_id: UUID | None, db_session: Session, user: User
|
||||
) -> Iterator[AnswerStreamPart]:
|
||||
request = SendMessageRequest(
|
||||
message=query,
|
||||
chat_session_id=chat_session_id,
|
||||
stream=True,
|
||||
chat_session_info=(
|
||||
ChatSessionCreationRequest() if chat_session_id is None else None
|
||||
),
|
||||
)
|
||||
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> ChatSession:
|
||||
return create_chat_session_from_request(
|
||||
chat_session_request=ChatSessionCreationRequest(),
|
||||
user_id=user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def create_packet_with_agent_response_delta(token: str, turn_index: int) -> Packet:
|
||||
return Packet(
|
||||
placement=create_placement(turn_index),
|
||||
obj=AgentResponseDelta(
|
||||
content=token,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_packet_with_reasoning_delta(token: str, turn_index: int) -> Packet:
|
||||
return Packet(
|
||||
placement=create_placement(turn_index),
|
||||
obj=ReasoningDelta(
|
||||
reasoning=token,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_web_search_doc(
|
||||
semantic_identifier: str,
|
||||
link: str,
|
||||
blurb: str,
|
||||
) -> SearchDoc:
|
||||
return SearchDoc(
|
||||
document_id=f"WEB_SEARCH_DOC_{link}",
|
||||
chunk_ind=0,
|
||||
semantic_identifier=semantic_identifier,
|
||||
link=link,
|
||||
blurb=blurb,
|
||||
source_type=DocumentSource.WEB,
|
||||
boost=1,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
)
|
||||
|
||||
|
||||
def mock_web_search_result_to_search_doc(result: MockWebSearchResult) -> SearchDoc:
|
||||
return create_web_search_doc(
|
||||
semantic_identifier=result.title,
|
||||
link=result.link,
|
||||
blurb=result.snippet,
|
||||
)
|
||||
|
||||
|
||||
def mock_web_content_to_search_doc(content: MockWebContent) -> SearchDoc:
|
||||
return create_web_search_doc(
|
||||
semantic_identifier=content.title,
|
||||
link=content.url,
|
||||
blurb=content.title,
|
||||
)
|
||||
|
||||
|
||||
def tokenise(text: str) -> list[str]:
|
||||
return [(token + " ") for token in text.split(" ")]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
External dependency unit tests for user file processing queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_user_file_processing work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in PROCESSING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that process_single_user_file clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_user_file_processing,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in PROCESSING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.PROCESSING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "bp_user")
|
||||
_create_processing_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(
|
||||
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
|
||||
),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestPerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "guard_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "guard_set_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
|
||||
f"Guard key TTL {ttl}s is outside the expected range "
|
||||
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "expires_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires")
|
||||
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), (
|
||||
"Task must be submitted with the correct expires value to prevent "
|
||||
"stale task accumulation"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestWorkerClearsGuardKey:
|
||||
"""process_single_user_file removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so process_single_user_file returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file processing lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_lock_key(user_file_id)
|
||||
processing_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = processing_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the processing lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if processing_lock.owned():
|
||||
processing_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -0,0 +1,59 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.tools.tool_implementations.open_url.models import WebContent
|
||||
from onyx.tools.tool_implementations.open_url.models import WebContentProvider
|
||||
|
||||
|
||||
class MockWebContent(BaseModel):
|
||||
title: str
|
||||
url: str
|
||||
content: str
|
||||
|
||||
def to_web_content(self) -> WebContent:
|
||||
return WebContent(
|
||||
title=self.title,
|
||||
link=self.url,
|
||||
full_content=self.content,
|
||||
published_date=None,
|
||||
scrape_successful=True,
|
||||
)
|
||||
|
||||
|
||||
class ContentProviderController(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def add_content(self, content: MockWebContent) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MockContentProvider(WebContentProvider, ContentProviderController):
|
||||
def __init__(self) -> None:
|
||||
self._contents: list[MockWebContent] = []
|
||||
|
||||
def add_content(self, web_content: MockWebContent) -> None:
|
||||
self._contents.append(web_content)
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
filtered_contents = list(
|
||||
filter(lambda web_content: web_content.url in urls, self._contents)
|
||||
)
|
||||
|
||||
return list(
|
||||
map(lambda web_content: web_content.to_web_content(), filtered_contents)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_mock_content_provider() -> Generator[ContentProviderController, None, None]:
|
||||
content_provider = MockContentProvider()
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.open_url.open_url_tool.get_default_content_provider",
|
||||
return_value=content_provider,
|
||||
):
|
||||
yield content_provider
|
||||
130
backend/tests/external_dependency_unit/mock_image_provider.py
Normal file
130
backend/tests/external_dependency_unit/mock_image_provider.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from litellm.types.utils import ImageObject
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
|
||||
|
||||
class ImageGenerationProviderController(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def add_image(
|
||||
self,
|
||||
data: str,
|
||||
delay: float = 0.0,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MockImageGenerationProvider(
|
||||
ImageGenerationProvider, ImageGenerationProviderController
|
||||
):
|
||||
def __init__(self) -> None:
|
||||
self._images: list[str] = []
|
||||
self._delays: list[float] = []
|
||||
|
||||
def add_image(
|
||||
self,
|
||||
data: str,
|
||||
delay: float = 0.0,
|
||||
) -> None:
|
||||
self._images.append(data)
|
||||
self._delays.append(delay)
|
||||
|
||||
@classmethod
|
||||
def validate_credentials(
|
||||
cls,
|
||||
credentials: ImageGenerationProviderCredentials,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _build_from_credentials(
|
||||
cls,
|
||||
_: ImageGenerationProviderCredentials,
|
||||
) -> ImageGenerationProvider:
|
||||
return cls()
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ImageResponse:
|
||||
image_data = self._images.pop(0)
|
||||
delay = self._delays.pop(0)
|
||||
|
||||
if delay > 0.0:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
# Event loop is running - run sleep in executor to avoid blocking the event loop
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(time.sleep, delay)
|
||||
future.result()
|
||||
except RuntimeError:
|
||||
# No running event loop, use regular thread sleep
|
||||
time.sleep(delay)
|
||||
|
||||
return ImageResponse(
|
||||
created=int(datetime.now().timestamp()),
|
||||
data=[
|
||||
ImageObject(
|
||||
b64_json=image_data,
|
||||
revised_prompt=prompt,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _create_mock_image_generation_llm_config() -> LLMConfig:
|
||||
"""Create a mock LLMConfig for image generation."""
|
||||
return LLMConfig(
|
||||
model_provider="openai",
|
||||
model_name="gpt-image-1",
|
||||
temperature=0.0,
|
||||
api_key="mock-api-key",
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
deployment_name=None,
|
||||
max_input_tokens=100000,
|
||||
custom_config=None,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_mock_image_generation_provider() -> (
|
||||
Generator[ImageGenerationProviderController, None, None]
|
||||
):
|
||||
image_gen_provider = MockImageGenerationProvider()
|
||||
|
||||
with (
|
||||
# Mock the image generation provider factory
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.images.image_generation_tool.get_image_generation_provider",
|
||||
return_value=image_gen_provider,
|
||||
),
|
||||
# Mock is_available to return True so the tool is registered
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.images.image_generation_tool.ImageGenerationTool.is_available",
|
||||
return_value=True,
|
||||
),
|
||||
# Mock the config lookup in tool_constructor to return a valid LLMConfig
|
||||
patch(
|
||||
"onyx.tools.tool_constructor._get_image_generation_config",
|
||||
return_value=_create_mock_image_generation_llm_config(),
|
||||
),
|
||||
):
|
||||
yield image_gen_provider
|
||||
@@ -6,40 +6,275 @@ import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Generic
|
||||
from typing import Literal
|
||||
from typing import TypeVar
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ReasoningEffort
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.llm.model_response import Delta
|
||||
from onyx.llm.model_response import FunctionCall
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.model_response import StreamingChoice
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class LLMResponseType(str, Enum):
|
||||
REASONING = "reasoning"
|
||||
ANSWER = "answer"
|
||||
TOOL_CALL = "tool_call"
|
||||
|
||||
|
||||
class LLMResponse(abc.ABC, BaseModel):
|
||||
type: str = ""
|
||||
|
||||
@abc.abstractmethod
|
||||
def num_tokens(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LLMReasoningResponse(LLMResponse):
|
||||
type: Literal["reasoning"] = LLMResponseType.REASONING.value
|
||||
reasoning_tokens: list[str]
|
||||
|
||||
def num_tokens(self) -> int:
|
||||
return len(self.reasoning_tokens)
|
||||
|
||||
|
||||
class LLMAnswerResponse(LLMResponse):
|
||||
type: Literal["answer"] = LLMResponseType.ANSWER.value
|
||||
answer_tokens: list[str]
|
||||
|
||||
def num_tokens(self) -> int:
|
||||
return len(self.answer_tokens)
|
||||
|
||||
|
||||
class LLMToolCallResponse(LLMResponse):
|
||||
type: Literal["tool_call"] = LLMResponseType.TOOL_CALL.value
|
||||
tool_name: str
|
||||
tool_call_id: str
|
||||
tool_call_argument_tokens: list[str]
|
||||
|
||||
def num_tokens(self) -> int:
|
||||
return (
|
||||
len(self.tool_call_argument_tokens) + 1
|
||||
) # +1 for the tool_call_id and tool_name
|
||||
|
||||
|
||||
class StreamItem(BaseModel):
|
||||
"""Represents a single item in the mock LLM stream with its type."""
|
||||
|
||||
response_type: LLMResponseType
|
||||
data: Any
|
||||
|
||||
|
||||
def _response_to_stream_items(response: LLMResponse) -> list[StreamItem]:
|
||||
match LLMResponseType(response.type):
|
||||
case LLMResponseType.REASONING:
|
||||
response = cast(LLMReasoningResponse, response)
|
||||
return [
|
||||
StreamItem(
|
||||
response_type=LLMResponseType.REASONING,
|
||||
data=token,
|
||||
)
|
||||
for token in response.reasoning_tokens
|
||||
]
|
||||
case LLMResponseType.ANSWER:
|
||||
response = cast(LLMAnswerResponse, response)
|
||||
return [
|
||||
StreamItem(
|
||||
response_type=LLMResponseType.ANSWER,
|
||||
data=token,
|
||||
)
|
||||
for token in response.answer_tokens
|
||||
]
|
||||
case LLMResponseType.TOOL_CALL:
|
||||
response = cast(LLMToolCallResponse, response)
|
||||
return [
|
||||
StreamItem(
|
||||
response_type=LLMResponseType.TOOL_CALL,
|
||||
data={
|
||||
"tool_call_id": response.tool_call_id,
|
||||
"tool_name": response.tool_name,
|
||||
"arguments": None,
|
||||
},
|
||||
)
|
||||
] + [
|
||||
StreamItem(
|
||||
response_type=LLMResponseType.TOOL_CALL,
|
||||
data={
|
||||
"tool_call_id": None,
|
||||
"tool_name": None,
|
||||
"arguments": token,
|
||||
},
|
||||
)
|
||||
for token in response.tool_call_argument_tokens
|
||||
]
|
||||
case _:
|
||||
raise ValueError(f"Unknown response type: {response.type}")
|
||||
|
||||
|
||||
def create_delta_from_stream_item(item: StreamItem) -> Delta:
|
||||
response_type = item.response_type
|
||||
data = item.data
|
||||
if response_type == LLMResponseType.REASONING:
|
||||
return Delta(reasoning_content=data)
|
||||
elif response_type == LLMResponseType.ANSWER:
|
||||
return Delta(content=data)
|
||||
elif response_type == LLMResponseType.TOOL_CALL:
|
||||
# Handle grouped tool calls (list) vs single tool call (dict)
|
||||
if isinstance(data, list):
|
||||
# Multiple tool calls emitted together in the same tick
|
||||
tool_calls = []
|
||||
for tc_data in data:
|
||||
if tc_data["tool_call_id"] is not None:
|
||||
tool_calls.append(
|
||||
ChatCompletionDeltaToolCall(
|
||||
id=tc_data["tool_call_id"],
|
||||
index=tc_data["index"],
|
||||
function=FunctionCall(
|
||||
arguments="",
|
||||
name=tc_data["tool_name"],
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
tool_calls.append(
|
||||
ChatCompletionDeltaToolCall(
|
||||
index=tc_data["index"],
|
||||
id=None,
|
||||
function=FunctionCall(
|
||||
arguments=tc_data["arguments"],
|
||||
name=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
return Delta(tool_calls=tool_calls)
|
||||
else:
|
||||
# Single tool call (original behavior)
|
||||
# First tick has tool_call_id and tool_name, subsequent ticks have arguments
|
||||
if data["tool_call_id"] is not None:
|
||||
return Delta(
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
id=data["tool_call_id"],
|
||||
function=FunctionCall(
|
||||
name=data["tool_name"],
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return Delta(
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
id=None,
|
||||
function=FunctionCall(
|
||||
name=None,
|
||||
arguments=data["arguments"],
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown response type: {response_type}")
|
||||
|
||||
|
||||
class MockLLMController(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def set_response(self, response_tokens: list[str]) -> None:
|
||||
def add_response(self, response: LLMResponse) -> None:
|
||||
"""Add a response to the current stream."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_responses_together(self, *responses: LLMResponse) -> None:
|
||||
"""Add multiple responses that should be emitted together in the same tick."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, n: int) -> None:
|
||||
"""Forward the stream by n tokens."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward_till_end(self) -> None:
|
||||
"""Forward the stream until the end."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_max_timeout(self, timeout: float = 5.0) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MockLLM(LLM, MockLLMController):
|
||||
def __init__(self) -> None:
|
||||
self.stream_controller: SyncStreamController | None = None
|
||||
self.stream_controller = SyncStreamController[StreamItem]()
|
||||
|
||||
def set_response(self, response_tokens: list[str]) -> None:
|
||||
self.stream_controller = SyncStreamController(response_tokens)
|
||||
def add_response(self, response: LLMResponse) -> None:
|
||||
items = _response_to_stream_items(response)
|
||||
self.stream_controller.queue_items(items)
|
||||
|
||||
def add_responses_together(self, *responses: LLMResponse) -> None:
|
||||
"""Add multiple responses that should be emitted together in the same tick.
|
||||
|
||||
Currently only supports multiple tool call responses being grouped together.
|
||||
The initial tool call info (id, name) for all tool calls will be emitted
|
||||
in a single delta, followed by argument tokens for each tool call.
|
||||
"""
|
||||
tool_calls = [r for r in responses if r.type == LLMResponseType.TOOL_CALL]
|
||||
|
||||
if len(tool_calls) != len(responses):
|
||||
raise ValueError(
|
||||
"add_responses_together currently only supports "
|
||||
"multiple tool call responses"
|
||||
)
|
||||
|
||||
# Create combined first item with all tool call initial info
|
||||
combined_data = [
|
||||
{
|
||||
"index": idx,
|
||||
"tool_call_id": cast(LLMToolCallResponse, tc).tool_call_id,
|
||||
"tool_name": cast(LLMToolCallResponse, tc).tool_name,
|
||||
"arguments": None,
|
||||
}
|
||||
for idx, tc in enumerate(tool_calls)
|
||||
]
|
||||
combined_item = StreamItem(
|
||||
response_type=LLMResponseType.TOOL_CALL,
|
||||
data=combined_data,
|
||||
)
|
||||
self.stream_controller.queue_items([combined_item])
|
||||
|
||||
# Add argument tokens for each tool call with their index
|
||||
for idx, tc in enumerate(tool_calls):
|
||||
tc = cast(LLMToolCallResponse, tc)
|
||||
for token in tc.tool_call_argument_tokens:
|
||||
item = StreamItem(
|
||||
response_type=LLMResponseType.TOOL_CALL,
|
||||
data=[
|
||||
{
|
||||
"index": idx,
|
||||
"tool_call_id": None,
|
||||
"tool_name": None,
|
||||
"arguments": token,
|
||||
}
|
||||
],
|
||||
)
|
||||
self.stream_controller.queue_items([item])
|
||||
|
||||
def forward(self, n: int) -> None:
|
||||
if self.stream_controller:
|
||||
@@ -53,6 +288,9 @@ class MockLLM(LLM, MockLLMController):
|
||||
else:
|
||||
raise ValueError("No response set")
|
||||
|
||||
def set_max_timeout(self, timeout: float = 5.0) -> None:
|
||||
self.stream_controller.timeout = timeout
|
||||
|
||||
@property
|
||||
def config(self) -> LLMConfig:
|
||||
return LLMConfig(
|
||||
@@ -89,16 +327,14 @@ class MockLLM(LLM, MockLLMController):
|
||||
if not self.stream_controller:
|
||||
return
|
||||
|
||||
for idx, token in enumerate(self.stream_controller):
|
||||
for idx, item in enumerate(self.stream_controller):
|
||||
yield ModelResponseStream(
|
||||
id="chatcmp-123",
|
||||
created="1",
|
||||
choice=StreamingChoice(
|
||||
finish_reason=None,
|
||||
index=idx,
|
||||
delta=Delta(
|
||||
content=token,
|
||||
),
|
||||
index=0, # Choice index should stay at 0 for all items in the same stream
|
||||
delta=create_delta_from_stream_item(item),
|
||||
),
|
||||
usage=None,
|
||||
)
|
||||
@@ -108,18 +344,22 @@ class StreamTimeoutError(Exception):
|
||||
"""Raised when the stream controller times out waiting for tokens."""
|
||||
|
||||
|
||||
class SyncStreamController:
|
||||
def __init__(self, tokens: list[str], timeout: float = 5.0) -> None:
|
||||
self.tokens = tokens
|
||||
class SyncStreamController(Generic[T]):
|
||||
def __init__(self, items: list[T] | None = None, timeout: float = 5.0) -> None:
|
||||
self.items = items if items is not None else []
|
||||
self.position = 0
|
||||
self.pending: list[int] = [] # The indices of the tokens that are pending
|
||||
self.timeout = timeout # Maximum time to wait for tokens before failing
|
||||
|
||||
self._has_pending = threading.Event()
|
||||
|
||||
def queue_items(self, new_items: list[T]) -> None:
|
||||
"""Queue additional tokens to the stream (for chaining responses like reasoning + tool calls)."""
|
||||
self.items.extend(new_items)
|
||||
|
||||
def forward(self, n: int) -> None:
|
||||
"""Queue the next n tokens to be yielded"""
|
||||
end = min(self.position + n, len(self.tokens))
|
||||
end = min(self.position + n, len(self.items))
|
||||
self.pending.extend(range(self.position, end))
|
||||
self.position = end
|
||||
|
||||
@@ -127,29 +367,29 @@ class SyncStreamController:
|
||||
self._has_pending.set()
|
||||
|
||||
def forward_till_end(self) -> None:
|
||||
self.forward(len(self.tokens) - self.position)
|
||||
self.forward(len(self.items) - self.position)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self.position >= len(self.tokens) and not self.pending
|
||||
return self.position >= len(self.items) and not self.pending
|
||||
|
||||
def __iter__(self) -> SyncStreamController:
|
||||
def __iter__(self) -> SyncStreamController[T]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> str:
|
||||
def __next__(self) -> T:
|
||||
start_time = time.monotonic()
|
||||
while not self.is_done:
|
||||
if self.pending:
|
||||
token_idx = self.pending.pop(0)
|
||||
item_idx = self.pending.pop(0)
|
||||
if not self.pending:
|
||||
self._has_pending.clear()
|
||||
return self.tokens[token_idx]
|
||||
return self.items[item_idx]
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed >= self.timeout:
|
||||
raise StreamTimeoutError(
|
||||
f"Stream controller timed out after {self.timeout}s waiting for tokens. "
|
||||
f"Position: {self.position}/{len(self.tokens)}, Pending: {len(self.pending)}"
|
||||
f"Position: {self.position}/{len(self.items)}, Pending: {len(self.pending)}"
|
||||
)
|
||||
|
||||
self._has_pending.wait(timeout=0.1)
|
||||
|
||||
183
backend/tests/external_dependency_unit/mock_search_pipeline.py
Normal file
183
backend/tests/external_dependency_unit/mock_search_pipeline.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.interfaces import LLM
|
||||
|
||||
|
||||
def run_functions_tuples_sequential(
|
||||
functions_with_args: list[tuple[Callable, tuple]],
|
||||
allow_failures: bool = False,
|
||||
max_workers: int | None = None,
|
||||
timeout: float | None = None,
|
||||
timeout_callback: Callable | None = None,
|
||||
) -> list[Any]:
|
||||
"""
|
||||
A sequential replacement for run_functions_tuples_in_parallel.
|
||||
Useful in tests to make parallel tool calls deterministic.
|
||||
"""
|
||||
results = []
|
||||
for func, args in functions_with_args:
|
||||
try:
|
||||
results.append(func(*args))
|
||||
except Exception:
|
||||
if allow_failures:
|
||||
results.append(None)
|
||||
else:
|
||||
raise
|
||||
return results
|
||||
|
||||
|
||||
class MockInternalSearchResult(BaseModel):
|
||||
document_id: str
|
||||
source_type: DocumentSource
|
||||
semantic_identifier: str
|
||||
chunk_ind: int
|
||||
|
||||
def to_inference_chunk(self) -> InferenceChunk:
|
||||
return InferenceChunk(
|
||||
document_id=f"{self.source_type.value.upper()}_{self.document_id}",
|
||||
source_type=self.source_type,
|
||||
semantic_identifier=self.semantic_identifier,
|
||||
title=self.semantic_identifier,
|
||||
chunk_id=self.chunk_ind,
|
||||
blurb="",
|
||||
content="",
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
boost=0,
|
||||
score=1.0,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
def to_search_doc(self) -> SearchDoc:
|
||||
return SearchDoc(
|
||||
document_id=f"{self.source_type.value.upper()}_{self.document_id}",
|
||||
chunk_ind=self.chunk_ind,
|
||||
semantic_identifier=self.semantic_identifier,
|
||||
link=None,
|
||||
blurb="",
|
||||
source_type=self.source_type,
|
||||
boost=0,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
score=1.0,
|
||||
match_highlights=[],
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
|
||||
class SearchPipelineController:
|
||||
def __init__(self) -> None:
|
||||
self.search_results: dict[str, list[MockInternalSearchResult]] = {}
|
||||
|
||||
def add_search_results(
|
||||
self, query: str, results: list[MockInternalSearchResult]
|
||||
) -> None:
|
||||
self.search_results[query] = results
|
||||
|
||||
def get_search_results(self, query: str) -> list[InferenceChunk]:
|
||||
return [
|
||||
result.to_inference_chunk() for result in self.search_results.get(query, [])
|
||||
]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_mock_search_pipeline(
|
||||
connectors: list[DocumentSource],
|
||||
) -> Generator[SearchPipelineController, None, None]:
|
||||
"""Mock the search pipeline and connector availability.
|
||||
|
||||
Args:
|
||||
connectors: List of DocumentSource types to pretend are available.
|
||||
Pass an empty list to simulate no connectors.
|
||||
"""
|
||||
controller = SearchPipelineController()
|
||||
|
||||
def mock_check_connectors_exist(db_session: Session) -> bool:
|
||||
return len(connectors) > 0
|
||||
|
||||
def mock_check_federated_connectors_exist(db_session: Session) -> bool:
|
||||
# For now, federated connectors are not mocked as available
|
||||
return False
|
||||
|
||||
def mock_check_user_files_exist(db_session: Session) -> bool:
|
||||
# For now, user files are not mocked as available
|
||||
return False
|
||||
|
||||
def mock_fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
|
||||
return connectors
|
||||
|
||||
def override_search_pipeline(
|
||||
chunk_search_request: ChunkSearchRequest,
|
||||
document_index: DocumentIndex,
|
||||
user: User | None,
|
||||
persona: Persona | None,
|
||||
db_session: Session,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
project_id: int | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
return controller.get_search_results(chunk_search_request.query)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.search_pipeline",
|
||||
new=override_search_pipeline,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_connectors_exist",
|
||||
new=mock_check_connectors_exist,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_federated_connectors_exist",
|
||||
new=mock_check_federated_connectors_exist,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.semantic_query_rephrase",
|
||||
return_value="",
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.keyword_query_expansion",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_runner.run_functions_tuples_in_parallel",
|
||||
new=run_functions_tuples_sequential,
|
||||
),
|
||||
patch(
|
||||
"onyx.db.connector.check_connectors_exist",
|
||||
new=mock_check_connectors_exist,
|
||||
),
|
||||
patch(
|
||||
"onyx.db.connector.check_federated_connectors_exist",
|
||||
new=mock_check_federated_connectors_exist,
|
||||
),
|
||||
patch(
|
||||
"onyx.db.connector.check_user_files_exist",
|
||||
new=mock_check_user_files_exist,
|
||||
),
|
||||
patch(
|
||||
"onyx.db.connector.fetch_unique_document_sources",
|
||||
new=mock_fetch_unique_document_sources,
|
||||
),
|
||||
):
|
||||
yield controller
|
||||
@@ -0,0 +1,97 @@
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.web_search import fetch_web_search_provider_by_name
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchProvider
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
from shared_configs.enums import WebSearchProviderType
|
||||
|
||||
|
||||
class MockWebSearchResult(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
snippet: str
|
||||
|
||||
def to_web_search_result(self) -> WebSearchResult:
|
||||
return WebSearchResult(
|
||||
title=self.title,
|
||||
link=self.link,
|
||||
snippet=self.snippet,
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
|
||||
|
||||
class WebProviderController(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def add_results(self, query: str, results: list[MockWebSearchResult]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MockWebProvider(WebSearchProvider, WebProviderController):
|
||||
def __init__(self) -> None:
|
||||
self._results: dict[str, list[MockWebSearchResult]] = defaultdict(list)
|
||||
|
||||
def add_results(self, query: str, results: list[MockWebSearchResult]) -> None:
|
||||
self._results[query] = results
|
||||
|
||||
def search(self, query: str) -> Sequence[WebSearchResult]:
|
||||
return list(
|
||||
map(lambda result: result.to_web_search_result(), self._results[query])
|
||||
)
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def add_web_provider_to_db(db_session: Session) -> None:
|
||||
# Write a provider to the database
|
||||
if fetch_web_search_provider_by_name(name="Test Provider 2", db_session=db_session):
|
||||
return
|
||||
|
||||
provider = InternetSearchProvider(
|
||||
name="Test Provider 2",
|
||||
provider_type=WebSearchProviderType.EXA.value,
|
||||
api_key="test-api-key",
|
||||
config={},
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
db_session.add(provider)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_web_provider_from_db(db_session: Session) -> None:
|
||||
provider = fetch_web_search_provider_by_name(
|
||||
name="Test Provider 2", db_session=db_session
|
||||
)
|
||||
if provider is not None:
|
||||
db_session.delete(provider)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_mock_web_provider(
|
||||
db_session: Session,
|
||||
) -> Generator[WebProviderController, None, None]:
|
||||
web_provider = MockWebProvider()
|
||||
|
||||
# Write the tool to the database
|
||||
add_web_provider_to_db(db_session)
|
||||
|
||||
# override the build function
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.build_search_provider_from_config",
|
||||
return_value=web_provider,
|
||||
):
|
||||
yield web_provider
|
||||
|
||||
delete_web_provider_from_db(db_session)
|
||||
@@ -8,14 +8,22 @@ import re
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
generate_opensearch_filtered_access_control_list,
|
||||
)
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
@@ -42,14 +50,22 @@ def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) ->
|
||||
|
||||
def _create_test_document_chunk(
|
||||
document_id: str,
|
||||
chunk_index: int,
|
||||
content: str,
|
||||
tenant_state: TenantState,
|
||||
chunk_index: int = 0,
|
||||
content_vector: list[float] | None = None,
|
||||
title: str | None = None,
|
||||
title_vector: list[float] | None = None,
|
||||
public: bool = True,
|
||||
hidden: bool = False,
|
||||
document_access: DocumentAccess = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
),
|
||||
source_type: DocumentSource = DocumentSource.FILE,
|
||||
last_updated: datetime | None = None,
|
||||
) -> DocumentChunk:
|
||||
if content_vector is None:
|
||||
# Generate dummy vector - 128 dimensions for fast testing.
|
||||
@@ -59,11 +75,6 @@ def _create_test_document_chunk(
|
||||
if title is not None and title_vector is None:
|
||||
title_vector = [0.2] * 128
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# We only store millisecond precision, so to make sure asserts work in this
|
||||
# test file manually lose some precision from datetime.now().
|
||||
now = now.replace(microsecond=(now.microsecond // 1000) * 1000)
|
||||
|
||||
return DocumentChunk(
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
@@ -71,11 +82,13 @@ def _create_test_document_chunk(
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type="test_source",
|
||||
source_type=source_type.value,
|
||||
metadata_list=None,
|
||||
last_updated=now,
|
||||
public=public,
|
||||
access_control_list=[],
|
||||
last_updated=last_updated,
|
||||
public=document_access.is_public,
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
document_access
|
||||
),
|
||||
hidden=hidden,
|
||||
global_boost=0,
|
||||
semantic_identifier="Test semantic identifier",
|
||||
@@ -331,6 +344,9 @@ class TestOpenSearchClient:
|
||||
chunk_index=0,
|
||||
content="Content to retrieve",
|
||||
tenant_state=tenant_state,
|
||||
# We only store second precision, so to make sure asserts work in
|
||||
# this test we'll deliberately lose some precision.
|
||||
last_updated=datetime.now(timezone.utc).replace(microsecond=0),
|
||||
)
|
||||
test_client.index_document(document=original_doc)
|
||||
|
||||
@@ -471,6 +487,8 @@ class TestOpenSearchClient:
|
||||
search_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="delete-me",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -483,6 +501,8 @@ class TestOpenSearchClient:
|
||||
keep_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="keep-me",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -510,7 +530,6 @@ class TestOpenSearchClient:
|
||||
chunk_index=0,
|
||||
content="Original content",
|
||||
tenant_state=tenant_state,
|
||||
public=True,
|
||||
hidden=False,
|
||||
)
|
||||
test_client.index_document(document=doc)
|
||||
@@ -561,10 +580,13 @@ class TestOpenSearchClient:
|
||||
properties_to_update={"hidden": True},
|
||||
)
|
||||
|
||||
def test_search_basic(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests basic search functionality."""
|
||||
"""Tests hybrid search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
@@ -574,24 +596,24 @@ class TestOpenSearchClient:
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index multiple documents with different content and vectors.
|
||||
# Index documents.
|
||||
docs = {
|
||||
"search-doc-1": _create_test_document_chunk(
|
||||
document_id="search-doc-1",
|
||||
"doc-1": _create_test_document_chunk(
|
||||
document_id="doc-1",
|
||||
chunk_index=0,
|
||||
content="Python programming language tutorial",
|
||||
content_vector=_generate_test_vector(0.1),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"search-doc-2": _create_test_document_chunk(
|
||||
document_id="search-doc-2",
|
||||
"doc-2": _create_test_document_chunk(
|
||||
document_id="doc-2",
|
||||
chunk_index=0,
|
||||
content="How to make cheese",
|
||||
content_vector=_generate_test_vector(0.2),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"search-doc-3": _create_test_document_chunk(
|
||||
document_id="search-doc-3",
|
||||
"doc-3": _create_test_document_chunk(
|
||||
document_id="doc-3",
|
||||
chunk_index=0,
|
||||
content="C++ for newborns",
|
||||
content_vector=_generate_test_vector(0.15),
|
||||
@@ -613,78 +635,10 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 3
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_chunk.document_id
|
||||
in ["search-doc-1", "search-doc-2", "search-doc-3"]
|
||||
for chunk in results
|
||||
)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert chunk.score
|
||||
|
||||
# Make sure there is some kind of match highlight for the first hit. We
|
||||
# don't expect highlights for any other hit.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents.
|
||||
docs = {
|
||||
"pipeline-doc-1": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-1",
|
||||
chunk_index=0,
|
||||
content="Machine learning algorithms for single-celled organisms",
|
||||
content_vector=_generate_test_vector(0.3),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"pipeline-doc-2": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-2",
|
||||
chunk_index=0,
|
||||
content="Deep learning shallow neural networks",
|
||||
content_vector=_generate_test_vector(0.35),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search query.
|
||||
query_text = "machine learning"
|
||||
query_vector = _generate_test_vector(0.32)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text=query_text,
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
# We're not worried about filtering here. tenant_id in this object
|
||||
# is not relevant.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -693,23 +647,26 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 2
|
||||
assert len(results) == len(docs)
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_chunk.document_id in ["pipeline-doc-1", "pipeline-doc-2"]
|
||||
for chunk in results
|
||||
)
|
||||
assert all(chunk.document_chunk.document_id in docs.keys() for chunk in results)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
for i, chunk in enumerate(results):
|
||||
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert chunk.score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Make sure there is some kind of match highlight only for the first
|
||||
# result. The other results are so bad they're not expected to have
|
||||
# match highlights.
|
||||
if i == 0:
|
||||
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_empty_index(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search on an empty index returns an empty list."""
|
||||
# Precondition.
|
||||
@@ -731,19 +688,28 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
# We're not worried about filtering here. tenant_id in this object
|
||||
# is not relevant.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
results = test_client.search(
|
||||
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 0
|
||||
|
||||
def test_search_filters(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests search filters for public/hidden documents and tenant isolation.
|
||||
Tests search filters for ACL, hidden documents, and tenant isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
@@ -757,29 +723,47 @@ class TestOpenSearchClient:
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc-1": _create_test_document_chunk(
|
||||
document_id="public-doc-1",
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"hidden-doc-1": _create_test_document_chunk(
|
||||
document_id="hidden-doc-1",
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-1": _create_test_document_chunk(
|
||||
document_id="private-doc-1",
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
public=False,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
@@ -787,7 +771,6 @@ class TestOpenSearchClient:
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
@@ -798,9 +781,6 @@ class TestOpenSearchClient:
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search with default filters (public=True, hidden=False).
|
||||
# The DocumentQuery.get_hybrid_search_query uses filters that should
|
||||
# only return public, non-hidden documents.
|
||||
query_text = "document content"
|
||||
query_vector = _generate_test_vector(0.6)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
@@ -809,24 +789,41 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
results = test_client.search(
|
||||
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document.
|
||||
assert len(results) == 1
|
||||
assert results[0].document_chunk.document_id == "public-doc-1"
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# NOTE: This test is not explicitly testing for how well results are
|
||||
# ordered; we're just assuming which doc will be the first result here.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == docs["public-doc-1"]
|
||||
assert results[0].document_chunk == docs["public-doc"]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == docs["private-doc-user-a"]
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
@@ -849,52 +846,54 @@ class TestOpenSearchClient:
|
||||
# Vectors closer to query_vector (0.1) should rank higher.
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
document_id="highly-relevant-1",
|
||||
document_id="highly-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence and machine learning transform technology",
|
||||
content_vector=_generate_test_vector(
|
||||
0.1
|
||||
), # Very close to query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="somewhat-relevant-1",
|
||||
document_id="somewhat-relevant",
|
||||
chunk_index=0,
|
||||
content="Computer programming with various languages",
|
||||
content_vector=_generate_test_vector(0.5), # Far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="not-very-relevant-1",
|
||||
document_id="not-very-relevant",
|
||||
chunk_index=0,
|
||||
content="Cooking recipes for delicious meals",
|
||||
content_vector=_generate_test_vector(
|
||||
0.9
|
||||
), # Very far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
# These should be filtered out by public/hidden filters.
|
||||
_create_test_document_chunk(
|
||||
document_id="hidden-but-relevant-1",
|
||||
document_id="hidden-but-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence research papers",
|
||||
content_vector=_generate_test_vector(0.05), # Very close but hidden.
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="private-but-relevant-1",
|
||||
document_id="private-but-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence industry analysis",
|
||||
content_vector=_generate_test_vector(0.08), # Very close but private.
|
||||
public=False,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
@@ -905,7 +904,7 @@ class TestOpenSearchClient:
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search query matching "highly-relevant-1" most closely.
|
||||
# Search query matching "highly-relevant" most closely.
|
||||
query_text = "artificial intelligence"
|
||||
query_vector = _generate_test_vector(0.1)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
@@ -914,6 +913,9 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# Explicitly pass in an empty list to enforce private doc filtering.
|
||||
index_filters=IndexFilters(access_control_list=[], tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -925,15 +927,15 @@ class TestOpenSearchClient:
|
||||
# Should only get public, non-hidden documents (3 out of 5).
|
||||
assert len(results) == 3
|
||||
result_ids = [chunk.document_chunk.document_id for chunk in results]
|
||||
assert "highly-relevant-1" in result_ids
|
||||
assert "somewhat-relevant-1" in result_ids
|
||||
assert "not-very-relevant-1" in result_ids
|
||||
assert "highly-relevant" in result_ids
|
||||
assert "somewhat-relevant" in result_ids
|
||||
assert "not-very-relevant" in result_ids
|
||||
# Filtered out by public/hidden constraints.
|
||||
assert "hidden-but-relevant-1" not in result_ids
|
||||
assert "private-but-relevant-1" not in result_ids
|
||||
assert "hidden-but-relevant" not in result_ids
|
||||
assert "private-but-relevant" not in result_ids
|
||||
|
||||
# Most relevant document should be first due to normalization pipeline.
|
||||
assert results[0].document_chunk.document_id == "highly-relevant-1"
|
||||
# Most relevant document should be first.
|
||||
assert results[0].document_chunk.document_id == "highly-relevant"
|
||||
|
||||
# Make sure there is some kind of match highlight for the most relevant
|
||||
# result.
|
||||
@@ -1014,6 +1016,8 @@ class TestOpenSearchClient:
|
||||
verify_query_x = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1026,6 +1030,8 @@ class TestOpenSearchClient:
|
||||
verify_query_y = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-y",
|
||||
tenant_state=tenant_y,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1113,6 +1119,8 @@ class TestOpenSearchClient:
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-1",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1133,3 +1141,176 @@ class TestOpenSearchClient:
|
||||
for chunk in doc1_chunks
|
||||
}
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
private ones.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_state,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for all documents.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="private-doc-user-a",
|
||||
tenant_state=tenant_state,
|
||||
# This is the input under test, notice None for acl.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
chunk_ids = test_client.search_for_document_ids(body=query_body)
|
||||
|
||||
# Postcondition.
|
||||
# Even though this doc is private, because we supplied None for acl we
|
||||
# were able to retrieve it.
|
||||
assert len(chunk_ids) == 1
|
||||
# Since this is a chunk ID, it will have the doc ID in it plus other
|
||||
# stuff we don't care about in this test.
|
||||
assert chunk_ids[0].startswith("private-doc-user-a")
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests the time cutoff filter works."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index docs with various ages.
|
||||
one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
one_week_ago = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
six_months_ago = datetime.now(timezone.utc) - timedelta(days=180)
|
||||
one_year_ago = datetime.now(timezone.utc) - timedelta(days=365)
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
document_id="one-day-ago",
|
||||
content="Good match",
|
||||
last_updated=one_day_ago,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="one-year-ago",
|
||||
content="Good match",
|
||||
last_updated=one_year_ago,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="no-last-updated",
|
||||
# Since we test for result ordering in the postconditions, let's
|
||||
# just make this content slightly less of a match with the query
|
||||
# so this test is not flaky from the ordering of the results.
|
||||
content="Still an ok match",
|
||||
last_updated=None,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for documents updated in the last week.
|
||||
last_week_search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text="Good match",
|
||||
query_vector=_generate_test_vector(0.1),
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=None, tenant_id=None, time_cutoff=one_week_ago
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
last_six_months_search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text="Good match",
|
||||
query_vector=_generate_test_vector(0.1),
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=None, tenant_id=None, time_cutoff=six_months_ago
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
last_week_results = test_client.search(
|
||||
body=last_week_search_body,
|
||||
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
last_six_months_results = test_client.search(
|
||||
body=last_six_months_search_body,
|
||||
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# We expect to only get one-day-ago.
|
||||
assert len(last_week_results) == 1
|
||||
assert last_week_results[0].document_chunk.document_id == "one-day-ago"
|
||||
# We expect to get one-day-ago and no-last-updated since six months >
|
||||
# ASSUMED_DOCUMENT_AGE_DAYS.
|
||||
assert len(last_six_months_results) == 2
|
||||
assert last_six_months_results[0].document_chunk.document_id == "one-day-ago"
|
||||
assert (
|
||||
last_six_months_results[1].document_chunk.document_id == "no-last-updated"
|
||||
)
|
||||
|
||||
@@ -476,8 +476,8 @@ class ChatSessionManager:
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
|
||||
@@ -31,7 +31,7 @@ class ProjectManager:
|
||||
) -> List[UserProjectSnapshot]:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -56,7 +56,7 @@ class ProjectManager:
|
||||
) -> bool:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def second_user(admin_user: DATestUser) -> DATestUser:
|
||||
# Ensure admin exists so this new user is created with BASIC role.
|
||||
try:
|
||||
return UserManager.create(name="second_basic_user")
|
||||
except HTTPError as e:
|
||||
response = e.response
|
||||
if response is None:
|
||||
raise
|
||||
if response.status_code not in (400, 409):
|
||||
raise
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
raise
|
||||
detail = payload.get("detail")
|
||||
if not _is_user_already_exists_detail(detail):
|
||||
raise
|
||||
print("Second basic user already exists; logging in instead.")
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("second_basic_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_user_already_exists_detail(detail: object) -> bool:
|
||||
if isinstance(detail, str):
|
||||
normalized = detail.lower()
|
||||
return (
|
||||
"already exists" in normalized
|
||||
or "register_user_already_exists" in normalized
|
||||
)
|
||||
if isinstance(detail, dict):
|
||||
code = detail.get("code")
|
||||
if isinstance(code, str) and code.lower() == "register_user_already_exists":
|
||||
return True
|
||||
message = detail.get("message")
|
||||
if isinstance(message, str) and "already exists" in message.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_chat_session(
|
||||
chat_session_id: str,
|
||||
user: DATestUser,
|
||||
is_shared: bool | None = None,
|
||||
include_deleted: bool | None = None,
|
||||
) -> requests.Response:
|
||||
params: dict[str, str] = {}
|
||||
if is_shared is not None:
|
||||
params["is_shared"] = str(is_shared).lower()
|
||||
if include_deleted is not None:
|
||||
params["include_deleted"] = str(include_deleted).lower()
|
||||
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
|
||||
params=params,
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def _set_sharing_status(
|
||||
chat_session_id: str, sharing_status: str, user: DATestUser
|
||||
) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}",
|
||||
json={"sharing_status": sharing_status},
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_private_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify private sessions are only accessible by the owner and never via share link."""
|
||||
# Create a private chat session owned by basic_user.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can access the private session normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Share link should be forbidden when the session is private.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users cannot access private sessions directly.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users also cannot access private sessions via share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_public_shared_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify shared sessions are accessible only via share link for non-owners."""
|
||||
# Create a private session, then mark it public.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
response = _set_sharing_status(str(chat_session.id), "public", basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can access normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can also access via share link.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Non-owner cannot access without share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Non-owner can access with share link for public sessions.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_deleted_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify deleted sessions return 404, with include_deleted gated by access checks."""
|
||||
# Create and soft-delete a session.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=chat_session, user_performing_action=basic_user
|
||||
)
|
||||
assert deletion_success is True
|
||||
|
||||
# Deleted sessions are not accessible normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Owner can fetch deleted session only with include_deleted.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("deleted") is True
|
||||
|
||||
# Non-owner should be blocked even with include_deleted.
|
||||
response = _get_chat_session(
|
||||
str(chat_session.id), second_user, include_deleted=True
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaLabelManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
reset: None, admin_user: DATestUser
|
||||
) -> None:
|
||||
persona_label = PersonaLabelManager.create(
|
||||
label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert persona_label.id is not None
|
||||
persona = PersonaManager.create(
|
||||
label_ids=[persona_label.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
updated_description = f"{persona.description}-updated"
|
||||
update_request = PersonaUpsertRequest(
|
||||
name=persona.name,
|
||||
description=updated_description,
|
||||
system_prompt=persona.system_prompt or "",
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
users=[],
|
||||
groups=[],
|
||||
label_ids=None,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=update_request.model_dump(mode="json", exclude_none=False),
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
fetched = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
fetched.raise_for_status()
|
||||
fetched_persona = fetched.json()
|
||||
|
||||
assert fetched_persona["description"] == updated_description
|
||||
fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
|
||||
assert persona_label.id in fetched_label_ids
|
||||
@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
|
||||
provider_id = _activate_exa_provider(admin_user)
|
||||
assert isinstance(provider_id, int)
|
||||
|
||||
search_request = {"queries": ["latest ai research news"], "max_results": 3}
|
||||
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
|
||||
|
||||
lite_response = requests.post(
|
||||
f"{API_SERVER_URL}/web-search/search-lite",
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for Asana connector configuration parsing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_ids,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 123 ", ["123"]),
|
||||
(" 123 , , 456 , ", ["123", "456"]),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_project_ids_normalization(
|
||||
project_ids: str | None, expected: list[str] | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id=" 1153293530468850 ",
|
||||
asana_project_ids=project_ids,
|
||||
asana_team_id=" 1210918501948021 ",
|
||||
)
|
||||
|
||||
assert connector.workspace_id == "1153293530468850"
|
||||
assert connector.project_ids_to_index == expected
|
||||
assert connector.asana_team_id == "1210918501948021"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 1210918501948021 ", "1210918501948021"),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_team_id_normalization(
|
||||
team_id: str | None, expected: str | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id="1153293530468850",
|
||||
asana_project_ids=None,
|
||||
asana_team_id=team_id,
|
||||
)
|
||||
|
||||
assert connector.asana_team_id == expected
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
|
||||
class TestConvertToMetadataValue:
|
||||
"""Tests for the _convert_to_metadata_value helper function."""
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""String values should be returned as-is."""
|
||||
assert _convert_to_metadata_value("hello") == "hello"
|
||||
assert _convert_to_metadata_value("") == ""
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
"""Boolean True should be converted to string 'True'."""
|
||||
assert _convert_to_metadata_value(True) == "True"
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
"""Boolean False should be converted to string 'False'."""
|
||||
assert _convert_to_metadata_value(False) == "False"
|
||||
|
||||
def test_integer_value(self) -> None:
|
||||
"""Integer values should be converted to string."""
|
||||
assert _convert_to_metadata_value(42) == "42"
|
||||
assert _convert_to_metadata_value(0) == "0"
|
||||
assert _convert_to_metadata_value(-100) == "-100"
|
||||
|
||||
def test_float_value(self) -> None:
|
||||
"""Float values should be converted to string."""
|
||||
assert _convert_to_metadata_value(3.14) == "3.14"
|
||||
assert _convert_to_metadata_value(0.0) == "0.0"
|
||||
assert _convert_to_metadata_value(-2.5) == "-2.5"
|
||||
|
||||
def test_list_of_strings(self) -> None:
|
||||
"""List of strings should remain as list of strings."""
|
||||
result = _convert_to_metadata_value(["a", "b", "c"])
|
||||
assert result == ["a", "b", "c"]
|
||||
|
||||
def test_list_of_mixed_types(self) -> None:
|
||||
"""List with mixed types should have all items converted to strings."""
|
||||
result = _convert_to_metadata_value([1, True, 3.14, "text"])
|
||||
assert result == ["1", "True", "3.14", "text"]
|
||||
|
||||
def test_empty_list(self) -> None:
|
||||
"""Empty list should return empty list."""
|
||||
assert _convert_to_metadata_value([]) == []
|
||||
|
||||
|
||||
class TestYieldDocBatches:
|
||||
"""Tests for the _yield_doc_batches method of SalesforceConnector."""
|
||||
|
||||
@pytest.fixture
|
||||
def connector(self) -> SalesforceConnector:
|
||||
"""Create a SalesforceConnector instance with mocked sf_client."""
|
||||
connector = SalesforceConnector(
|
||||
batch_size=10,
|
||||
requested_objects=["Opportunity"],
|
||||
)
|
||||
# Mock the sf_client property
|
||||
mock_sf_client = MagicMock()
|
||||
mock_sf_client.sf_instance = "test.salesforce.com"
|
||||
connector._sf_client = mock_sf_client
|
||||
return connector
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sf_db(self) -> MagicMock:
|
||||
"""Create a mock OnyxSalesforceSQLite object."""
|
||||
return MagicMock()
|
||||
|
||||
def _create_salesforce_object(
|
||||
self,
|
||||
object_id: str,
|
||||
object_type: str,
|
||||
data: dict[str, Any],
|
||||
) -> SalesforceObject:
|
||||
"""Helper to create a SalesforceObject with required fields."""
|
||||
# Ensure required fields are present
|
||||
data.setdefault(ID_FIELD, object_id)
|
||||
data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
|
||||
data.setdefault(NAME_FIELD, f"Test {object_type}")
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_metadata_type_conversion_for_opportunity(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that Opportunity metadata fields are properly type-converted."""
|
||||
parent_id = "006bm000006kyDpAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create a parent object with various data types in the fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Test Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"Account": "Acme Corp", # string - should become "account" metadata
|
||||
"FiscalQuarter": 2, # int - should be converted to "2"
|
||||
"FiscalYear": 2024, # int - should be converted to "2024"
|
||||
"IsClosed": False, # bool - should be converted to "False"
|
||||
"StageName": "Prospecting", # string
|
||||
"Type": "New Business", # string
|
||||
"Amount": 50000.50, # float - should be converted to "50000.50"
|
||||
"CloseDate": "2024-06-30", # string
|
||||
"Probability": 75, # int - should be converted to "75"
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z", # string
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
# Setup mock sf_db
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create a mock document that convert_sf_object_to_doc will return
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Test Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
# Track parent changes
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# Call _yield_doc_batches
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify we got one batch with one document
|
||||
assert len(batches) == 1
|
||||
docs = batches[0]
|
||||
assert len(docs) == 1
|
||||
|
||||
doc = docs[0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify metadata type conversions
|
||||
# All values should be strings (or list of strings)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["account"] == "Acme Corp" # string stays string
|
||||
assert doc.metadata["fiscal_quarter"] == "2" # int -> str
|
||||
assert doc.metadata["fiscal_year"] == "2024" # int -> str
|
||||
assert doc.metadata["is_closed"] == "False" # bool -> str
|
||||
assert doc.metadata["stage_name"] == "Prospecting" # string stays string
|
||||
assert doc.metadata["type"] == "New Business" # string stays string
|
||||
assert (
|
||||
doc.metadata["amount"] == "50000.5"
|
||||
) # float -> str (Python drops trailing zeros)
|
||||
assert doc.metadata["close_date"] == "2024-06-30" # string stays string
|
||||
assert doc.metadata["probability"] == "75" # int -> str
|
||||
assert doc.metadata["name"] == "Test Opportunity" # NAME_FIELD
|
||||
|
||||
# Verify parent was counted
|
||||
assert parents_changed == 1
|
||||
assert type_to_processed[parent_type] == 1
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_missing_optional_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing optional metadata fields are not added."""
|
||||
parent_id = "006bm000006kyDqAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create parent object with only some fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Minimal Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"StageName": "Closed Won",
|
||||
# Notably missing: Amount, Probability, FiscalQuarter, etc.
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Minimal Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only present fields should be in metadata
|
||||
assert "stage_name" in doc.metadata
|
||||
assert doc.metadata["stage_name"] == "Closed Won"
|
||||
assert "name" in doc.metadata
|
||||
assert doc.metadata["name"] == "Minimal Opportunity"
|
||||
|
||||
# Missing fields should not be in metadata
|
||||
assert "amount" not in doc.metadata
|
||||
assert "probability" not in doc.metadata
|
||||
assert "fiscal_quarter" not in doc.metadata
|
||||
assert "fiscal_year" not in doc.metadata
|
||||
assert "is_closed" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_contact_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test metadata conversion for Contact object type."""
|
||||
parent_id = "003bm00000EjHCjAAN"
|
||||
parent_type = "Contact"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "John Doe",
|
||||
MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
|
||||
"Account": "Globex Corp",
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z",
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="John Doe",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify Contact-specific metadata
|
||||
assert doc.metadata["object_type"] == "Contact"
|
||||
assert doc.metadata["account"] == "Globex Corp"
|
||||
assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
|
||||
assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_no_default_attributes_for_unknown_type(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that unknown object types only get object_type metadata."""
|
||||
parent_id = "001bm00000fd9Z3AAI"
|
||||
parent_type = "CustomObject__c"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Custom Record",
|
||||
MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
|
||||
"CustomField__c": "custom value",
|
||||
"NumberField__c": 123,
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Custom Record",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only object_type should be set for unknown types
|
||||
assert doc.metadata["object_type"] == "CustomObject__c"
|
||||
# Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
|
||||
assert "CustomField__c" not in doc.metadata
|
||||
assert "NumberField__c" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_skips_missing_parent_objects(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing parent objects are skipped gracefully."""
|
||||
parent_id = "006bm000006kyDrAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# get_record returns None for missing object
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = None
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Should yield one empty batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 0
|
||||
|
||||
# convert_sf_object_to_doc should not have been called
|
||||
mock_convert.assert_not_called()
|
||||
|
||||
# Parents changed should still be 0
|
||||
assert parents_changed == 0
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_multiple_documents_batching(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that multiple documents are correctly batched."""
|
||||
# Create 3 parent objects
|
||||
parent_ids = [
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
]
|
||||
parent_type = "Opportunity"
|
||||
|
||||
parent_objects = [
|
||||
self._create_salesforce_object(
|
||||
pid,
|
||||
parent_type,
|
||||
{
|
||||
ID_FIELD: pid,
|
||||
NAME_FIELD: f"Opportunity {i}",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"IsClosed": i % 2 == 0, # alternating bool values
|
||||
"Amount": 1000.0 * (i + 1),
|
||||
},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
|
||||
# Setup mock to return all three
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
|
||||
)
|
||||
mock_sf_db.get_record.side_effect = parent_objects
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create mock documents
|
||||
mock_docs = [
|
||||
Document(
|
||||
id=f"SALESFORCE_{pid}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=f"Opportunity {i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
mock_convert.side_effect = mock_docs
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {pid: parent_type for pid in parent_ids}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
# With batch_size=10, all 3 docs should be in one batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 3
|
||||
|
||||
# Verify each document has correct metadata
|
||||
for i, doc in enumerate(batches[0]):
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["is_closed"] == str(i % 2 == 0)
|
||||
assert doc.metadata["amount"] == str(1000.0 * (i + 1))
|
||||
|
||||
assert type_to_processed[parent_type] == 3
|
||||
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.users import delete_user_from_db
|
||||
|
||||
|
||||
def _mock_user(
|
||||
user_id: UUID | None = None, email: str = "test@example.com"
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.oauth_accounts = []
|
||||
return user
|
||||
|
||||
|
||||
def _make_query_chain() -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
return chain
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_nulls_out_document_set_ownership(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Verify DocumentSet.user_id is nulled out (update, not delete)
|
||||
doc_set_chain = query_chains[DocumentSet]
|
||||
doc_set_chain.filter.assert_called()
|
||||
doc_set_chain.filter.return_value.update.assert_called_once_with(
|
||||
{DocumentSet.user_id: None}
|
||||
)
|
||||
|
||||
# Verify Persona.user_id is nulled out (update, not delete)
|
||||
persona_chain = query_chains[Persona]
|
||||
persona_chain.filter.assert_called()
|
||||
persona_chain.filter.return_value.update.assert_called_once_with(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_cleans_up_join_tables(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Join tables should be deleted (not updated)
|
||||
for model in [DocumentSet__User, Persona__User, User__UserGroup, SamlAccount]:
|
||||
chain = query_chains[model]
|
||||
chain.filter.return_value.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_commits_and_removes_invited(
|
||||
_mock_ee: Any, mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user(email="deleted@example.com")
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_called_once_with(user)
|
||||
db_session.commit.assert_called_once()
|
||||
mock_remove_invited.assert_called_once_with("deleted@example.com")
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_deletes_oauth_accounts(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
oauth1 = MagicMock()
|
||||
oauth2 = MagicMock()
|
||||
user.oauth_accounts = [oauth1, oauth2]
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_any_call(oauth1)
|
||||
db_session.delete.assert_any_call(oauth2)
|
||||
205
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
205
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
|
||||
|
||||
# -- Table rendering tests --
|
||||
|
||||
|
||||
def test_table_renders_as_vertical_cards() -> None:
|
||||
message = (
|
||||
"| Feature | Status | Owner |\n"
|
||||
"|---------|--------|-------|\n"
|
||||
"| Auth | Done | Alice |\n"
|
||||
"| Search | In Progress | Bob |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
|
||||
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
|
||||
# Cards separated by blank line
|
||||
assert "Owner: Alice\n\n*Search*" in formatted
|
||||
# No raw pipe-and-dash table syntax
|
||||
assert "---|" not in formatted
|
||||
|
||||
|
||||
def test_table_single_column() -> None:
|
||||
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Alice*" in formatted
|
||||
assert "*Bob*" in formatted
|
||||
|
||||
|
||||
def test_table_embedded_in_text() -> None:
|
||||
message = (
|
||||
"Here are the results:\n\n"
|
||||
"| Item | Count |\n"
|
||||
"|------|-------|\n"
|
||||
"| Apples | 5 |\n"
|
||||
"\n"
|
||||
"That's all."
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Here are the results:" in formatted
|
||||
assert "*Apples*\n • Count: 5" in formatted
|
||||
assert "That's all." in formatted
|
||||
|
||||
|
||||
def test_table_with_formatted_cells() -> None:
|
||||
message = (
|
||||
"| Name | Link |\n"
|
||||
"|------|------|\n"
|
||||
"| **Alice** | [profile](https://example.com) |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Bold cell should not double-wrap: *Alice* not **Alice**
|
||||
assert "*Alice*" in formatted
|
||||
assert "**Alice**" not in formatted
|
||||
assert "<https://example.com|profile>" in formatted
|
||||
|
||||
|
||||
def test_table_with_alignment_specifiers() -> None:
|
||||
message = (
|
||||
"| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*a*\n • Center: b\n • Right: c" in formatted
|
||||
|
||||
|
||||
def test_two_tables_in_same_message_use_independent_headers() -> None:
|
||||
message = (
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"| X | Y | Z |\n"
|
||||
"|---|---|---|\n"
|
||||
"| p | q | r |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*1*\n • B: 2" in formatted
|
||||
assert "*p*\n • Y: q\n • Z: r" in formatted
|
||||
|
||||
|
||||
def test_table_empty_first_column_no_bare_asterisks() -> None:
|
||||
message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Empty title should not produce "**" (bare asterisks)
|
||||
assert "**" not in formatted
|
||||
assert " • Status: Done" in formatted
|
||||
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.utils import telemetry as telemetry_utils
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
|
||||
fetch_impl = Mock()
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_not_called()
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
event_telemetry = Mock()
|
||||
fetch_impl = Mock(return_value=event_telemetry)
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_called_once_with(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=telemetry_utils.noop_fallback,
|
||||
)
|
||||
event_telemetry.assert_called_once_with(
|
||||
"user@example.com",
|
||||
MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
{"origin": "web", "tenant_id": "tenant-1"},
|
||||
)
|
||||
@@ -221,6 +221,13 @@ services:
|
||||
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
||||
- ONYX_BOT_MAX_QPM=${ONYX_BOT_MAX_QPM:-}
|
||||
- ONYX_BOT_MAX_WAIT_TIME=${ONYX_BOT_MAX_WAIT_TIME:-}
|
||||
# Discord Bot Configuration (runs via supervisord, requires DISCORD_BOT_TOKEN to be set)
|
||||
# IMPORTANT: Only one Discord bot instance can run per token - do not scale background workers
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# Logging
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
|
||||
|
||||
@@ -63,6 +63,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -82,6 +82,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -129,6 +129,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# PRODUCTION: Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
|
||||
@@ -77,6 +77,13 @@ MINIO_ROOT_PASSWORD=minioadmin
|
||||
## CORS origins for MCP clients (comma-separated list)
|
||||
# MCP_SERVER_CORS_ORIGINS=
|
||||
|
||||
## Discord Bot Configuration
|
||||
## The Discord bot allows users to interact with Onyx from Discord servers
|
||||
## Bot token from Discord Developer Portal (required to enable the bot)
|
||||
# DISCORD_BOT_TOKEN=
|
||||
## Command prefix for bot commands (default: "!")
|
||||
# DISCORD_BOT_INVOKE_CHAR=!
|
||||
|
||||
## Celery Configuration
|
||||
# CELERY_BROKER_POOL_LIMIT=
|
||||
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=
|
||||
|
||||
@@ -582,29 +582,33 @@ else
|
||||
fi
|
||||
|
||||
# Ask for authentication schema
|
||||
echo ""
|
||||
print_info "Which authentication schema would you like to set up?"
|
||||
echo ""
|
||||
echo "1) Basic - Username/password authentication"
|
||||
echo "2) No Auth - Open access (development/testing)"
|
||||
echo ""
|
||||
read -p "Choose an option (1-2) [default 1]: " -r AUTH_CHOICE
|
||||
echo ""
|
||||
# echo ""
|
||||
# print_info "Which authentication schema would you like to set up?"
|
||||
# echo ""
|
||||
# echo "1) Basic - Username/password authentication"
|
||||
# echo "2) No Auth - Open access (development/testing)"
|
||||
# echo ""
|
||||
# read -p "Choose an option (1) [default 1]: " -r AUTH_CHOICE
|
||||
# echo ""
|
||||
|
||||
case "${AUTH_CHOICE:-1}" in
|
||||
1)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Selected: Basic authentication"
|
||||
;;
|
||||
2)
|
||||
AUTH_SCHEMA="disabled"
|
||||
print_info "Selected: No authentication"
|
||||
;;
|
||||
*)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Invalid choice, using basic authentication"
|
||||
;;
|
||||
esac
|
||||
# case "${AUTH_CHOICE:-1}" in
|
||||
# 1)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Selected: Basic authentication"
|
||||
# ;;
|
||||
# # 2)
|
||||
# # AUTH_SCHEMA="disabled"
|
||||
# # print_info "Selected: No authentication"
|
||||
# # ;;
|
||||
# *)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Invalid choice, using basic authentication"
|
||||
# ;;
|
||||
# esac
|
||||
|
||||
# TODO (jessica): Uncomment this once no auth users still have an account
|
||||
# Use basic auth by default
|
||||
AUTH_SCHEMA="basic"
|
||||
|
||||
# Create .env file from template
|
||||
print_info "Creating .env file with your selections..."
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.19
|
||||
version: 0.4.20
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
@@ -0,0 +1,98 @@
|
||||
{{- if .Values.discordbot.enabled }}
|
||||
# Discord bot MUST run as a single replica - Discord only allows one client connection per bot token.
|
||||
# Do NOT enable HPA or increase replicas. Message processing is offloaded to scalable API pods via HTTP.
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-discordbot
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
# CRITICAL: Discord bots cannot be horizontally scaled - only one WebSocket connection per token is allowed
|
||||
replicas: 1
|
||||
strategy:
|
||||
type: Recreate # Ensure old pod is terminated before new one starts to avoid duplicate connections
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "onyx.selectorLabels" . | nindent 6 }}
|
||||
{{- if .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml .Values.discordbot.deploymentLabels | nindent 6 }}
|
||||
{{- end }}
|
||||
template:
|
||||
metadata:
|
||||
annotations:
|
||||
checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
|
||||
{{- with .Values.discordbot.podAnnotations }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 8 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.podLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "onyx.serviceAccountName" . }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.podSecurityContext | nindent 8 }}
|
||||
{{- with .Values.discordbot.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: discordbot
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.discordbot.image.repository }}:{{ .Values.discordbot.image.tag | default .Values.global.version }}"
|
||||
imagePullPolicy: {{ .Values.global.pullPolicy }}
|
||||
command: ["python", "onyx/onyxbot/discord/client.py"]
|
||||
resources:
|
||||
{{- toYaml .Values.discordbot.resources | nindent 12 }}
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
env:
|
||||
{{- include "onyx.envSecrets" . | nindent 12}}
|
||||
# Discord bot token - required for bot to connect
|
||||
{{- if .Values.discordbot.botToken }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
value: {{ .Values.discordbot.botToken | quote }}
|
||||
{{- end }}
|
||||
{{- if .Values.discordbot.botTokenSecretName }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ .Values.discordbot.botTokenSecretName }}
|
||||
key: {{ .Values.discordbot.botTokenSecretKey | default "token" }}
|
||||
{{- end }}
|
||||
# Command prefix for bot commands (default: "!")
|
||||
{{- if .Values.discordbot.invokeChar }}
|
||||
- name: DISCORD_BOT_INVOKE_CHAR
|
||||
value: {{ .Values.discordbot.invokeChar | quote }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumes }}
|
||||
volumes:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -701,6 +701,44 @@ celery_worker_user_file_processing:
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
# Discord bot for Onyx
|
||||
# The bot offloads message processing to scalable API pods via HTTP requests.
|
||||
discordbot:
|
||||
enabled: false # Disabled by default - requires bot token configuration
|
||||
# Bot token can be provided directly or via a Kubernetes secret
|
||||
# Option 1: Direct token (not recommended for production)
|
||||
botToken: ""
|
||||
# Option 2: Reference a Kubernetes secret (recommended)
|
||||
botTokenSecretName: "" # Name of the secret containing the bot token
|
||||
botTokenSecretKey: "token" # Key within the secret (default: "token")
|
||||
# Command prefix for bot commands (default: "!")
|
||||
invokeChar: "!"
|
||||
image:
|
||||
repository: onyxdotapp/onyx-backend
|
||||
tag: "" # Overrides the image tag whose default is the chart appVersion.
|
||||
podAnnotations: {}
|
||||
podLabels:
|
||||
scope: onyx-backend
|
||||
app: discord-bot
|
||||
deploymentLabels:
|
||||
app: discord-bot
|
||||
podSecurityContext:
|
||||
{}
|
||||
securityContext:
|
||||
{}
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "512Mi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2000Mi"
|
||||
volumes: []
|
||||
volumeMounts: []
|
||||
nodeSelector: {}
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
slackbot:
|
||||
enabled: true
|
||||
replicaCount: 1
|
||||
@@ -1159,6 +1197,8 @@ configMap:
|
||||
ONYX_BOT_DISPLAY_ERROR_MSGS: ""
|
||||
ONYX_BOT_RESPOND_EVERY_CHANNEL: ""
|
||||
NOTIFY_SLACKBOT_NO_ANSWER: ""
|
||||
DISCORD_BOT_TOKEN: ""
|
||||
DISCORD_BOT_INVOKE_CHAR: ""
|
||||
# Logging
|
||||
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
|
||||
DISABLE_TELEMETRY: ""
|
||||
|
||||
3
desktop/.gitignore
vendored
3
desktop/.gitignore
vendored
@@ -22,3 +22,6 @@ npm-debug.log*
|
||||
# Local env files
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Generated files
|
||||
src-tauri/gen/schemas/acl-manifests.json
|
||||
|
||||
96
desktop/src-tauri/Cargo.lock
generated
96
desktop/src-tauri/Cargo.lock
generated
@@ -706,16 +706,6 @@ dependencies = [
|
||||
"typeid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
@@ -993,16 +983,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gethostname"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
@@ -1122,24 +1102,6 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "global-hotkey"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"keyboard-types",
|
||||
"objc2 0.6.3",
|
||||
"objc2-app-kit 0.3.2",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"thiserror 2.0.17",
|
||||
"windows-sys 0.59.0",
|
||||
"x11rb",
|
||||
"xkeysym",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gobject-sys"
|
||||
version = "0.18.0"
|
||||
@@ -1713,12 +1675,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.1"
|
||||
@@ -2248,7 +2204,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
"tauri-plugin-global-shortcut",
|
||||
"tauri-plugin-shell",
|
||||
"tauri-plugin-window-state",
|
||||
"tokio",
|
||||
@@ -2878,19 +2833,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -3605,21 +3547,6 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-global-shortcut"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
|
||||
dependencies = [
|
||||
"global-hotkey",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-plugin",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-shell"
|
||||
version = "2.3.3"
|
||||
@@ -5021,29 +4948,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
|
||||
dependencies = [
|
||||
"gethostname",
|
||||
"rustix",
|
||||
"x11rb-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb-protocol"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
|
||||
|
||||
[[package]]
|
||||
name = "xkeysym"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
|
||||
@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
|
||||
[dependencies]
|
||||
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.0"
|
||||
tauri-plugin-global-shortcut = "2.0"
|
||||
tauri-plugin-window-state = "2.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -20,7 +20,6 @@ use tauri::Wry;
|
||||
use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Shortcuts Setup
|
||||
// ============================================================================
|
||||
|
||||
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
|
||||
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
|
||||
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
|
||||
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
|
||||
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
|
||||
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
|
||||
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
|
||||
let app_handle = app.clone();
|
||||
|
||||
// Avoid hijacking the system-wide Cmd+R on macOS.
|
||||
#[cfg(target_os = "macos")]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
reload,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
app.global_shortcut().on_shortcuts(
|
||||
shortcuts,
|
||||
move |_app, shortcut, _event| {
|
||||
if shortcut == &new_chat {
|
||||
trigger_new_chat(&app_handle);
|
||||
}
|
||||
|
||||
if let Some(window) = app_handle.get_webview_window("main") {
|
||||
if shortcut == &reload {
|
||||
let _ = window.eval("window.location.reload()");
|
||||
} else if shortcut == &back {
|
||||
let _ = window.eval("window.history.back()");
|
||||
} else if shortcut == &forward {
|
||||
let _ = window.eval("window.history.forward()");
|
||||
} else if shortcut == &open_settings_shortcut {
|
||||
open_settings(&app_handle);
|
||||
}
|
||||
}
|
||||
|
||||
if shortcut == &new_window_shortcut {
|
||||
trigger_new_window(&app_handle);
|
||||
} else if shortcut == &show_app {
|
||||
focus_main_window(&app_handle);
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
|
||||
TRAY_MENU_OPEN_APP_ID,
|
||||
"Open Onyx",
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+Space"),
|
||||
None::<&str>,
|
||||
)?;
|
||||
let open_chat = MenuItem::with_id(
|
||||
app,
|
||||
@@ -666,7 +598,6 @@ fn main() {
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
@@ -698,11 +629,6 @@ fn main() {
|
||||
.setup(move |app| {
|
||||
let app_handle = app.handle();
|
||||
|
||||
// Setup global shortcuts
|
||||
if let Err(e) = setup_shortcuts(&app_handle) {
|
||||
eprintln!("Failed to setup shortcuts: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = setup_app_menu(&app_handle) {
|
||||
eprintln!("Failed to setup menu: {}", e);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@
|
||||
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background-900: #1a1a1a;
|
||||
--background-800: #262626;
|
||||
--text-light-05: rgba(255, 255, 255, 0.95);
|
||||
--text-light-03: rgba(255, 255, 255, 0.6);
|
||||
--white-10: rgba(255, 255, 255, 0.08);
|
||||
--white-15: rgba(255, 255, 255, 0.12);
|
||||
--white-20: rgba(255, 255, 255, 0.15);
|
||||
--white-30: rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
@@ -30,7 +41,11 @@
|
||||
|
||||
body {
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--background-900) 0%,
|
||||
var(--background-800) 100%
|
||||
);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
display: flex;
|
||||
@@ -39,6 +54,9 @@
|
||||
padding: 20px;
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
color 0.3s ease;
|
||||
}
|
||||
|
||||
.titlebar {
|
||||
@@ -69,16 +87,19 @@
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
background: var(--background-800);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
border 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .settings-panel {
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
@@ -93,17 +114,19 @@
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
background: var(--background-900);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
overflow: hidden;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-icon svg {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
color: #000;
|
||||
color: var(--text-light-05);
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
@@ -134,9 +157,10 @@
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
background: var(--background-900);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
@@ -176,7 +200,7 @@
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-800);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
@@ -186,8 +210,8 @@
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-900);
|
||||
box-shadow: 0 0 0 2px var(--white-10);
|
||||
}
|
||||
|
||||
.input-field::placeholder {
|
||||
@@ -231,7 +255,7 @@
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
background-color: var(--white-15);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
@@ -243,14 +267,18 @@
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
background-color: var(--background-800);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.dark .toggle-slider:before {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
background-color: var(--white-30);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
@@ -288,14 +316,15 @@
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
background: var(--white-10);
|
||||
border: 1px solid var(--white-15);
|
||||
border-radius: 4px;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
@@ -372,10 +401,34 @@
|
||||
const errorMessage = document.getElementById("errorMessage");
|
||||
const saveBtn = document.getElementById("saveBtn");
|
||||
|
||||
// Theme detection based on system preferences
|
||||
function applySystemTheme() {
|
||||
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
|
||||
|
||||
function updateTheme(e) {
|
||||
if (e.matches) {
|
||||
document.documentElement.classList.add("dark");
|
||||
document.body.classList.add("dark");
|
||||
} else {
|
||||
document.documentElement.classList.remove("dark");
|
||||
document.body.classList.remove("dark");
|
||||
}
|
||||
}
|
||||
|
||||
// Apply initial theme
|
||||
updateTheme(darkModeQuery);
|
||||
|
||||
// Listen for changes
|
||||
darkModeQuery.addEventListener("change", updateTheme);
|
||||
}
|
||||
|
||||
function showSettings() {
|
||||
document.body.classList.add("show-settings");
|
||||
}
|
||||
|
||||
// Apply system theme immediately
|
||||
applySystemTheme();
|
||||
|
||||
// Initialize the app
|
||||
async function init() {
|
||||
try {
|
||||
|
||||
@@ -113,6 +113,23 @@
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
function updateTitleBarTheme(isDark) {
|
||||
const titleBar = document.getElementById(TITLEBAR_ID);
|
||||
if (!titleBar) return;
|
||||
|
||||
if (isDark) {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
|
||||
} else {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
|
||||
}
|
||||
}
|
||||
|
||||
function buildTitleBar() {
|
||||
const titleBar = document.createElement("div");
|
||||
titleBar.id = TITLEBAR_ID;
|
||||
@@ -134,6 +151,11 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Apply initial styles matching current theme
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
// Apply styles matching Onyx design system with translucent glass effect
|
||||
titleBar.style.cssText = `
|
||||
position: fixed;
|
||||
@@ -156,8 +178,12 @@
|
||||
-webkit-backdrop-filter: blur(18px) saturate(180%);
|
||||
-webkit-app-region: drag;
|
||||
padding: 0 12px;
|
||||
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
|
||||
`;
|
||||
|
||||
// Apply correct theme
|
||||
updateTitleBarTheme(isDark);
|
||||
|
||||
return titleBar;
|
||||
}
|
||||
|
||||
@@ -168,6 +194,11 @@
|
||||
|
||||
const existing = document.getElementById(TITLEBAR_ID);
|
||||
if (existing?.parentElement === document.body) {
|
||||
// Update theme on existing titlebar
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -178,6 +209,14 @@
|
||||
const titleBar = buildTitleBar();
|
||||
document.body.insertBefore(titleBar, document.body.firstChild);
|
||||
injectStyles();
|
||||
|
||||
// Ensure theme is applied immediately after mount
|
||||
setTimeout(() => {
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function syncViewportHeight() {
|
||||
@@ -194,9 +233,66 @@
|
||||
}
|
||||
}
|
||||
|
||||
function observeThemeChanges() {
|
||||
let lastKnownTheme = null;
|
||||
|
||||
function checkAndUpdateTheme() {
|
||||
// Check both html and body for dark class (some apps use body)
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
if (lastKnownTheme !== isDark) {
|
||||
lastKnownTheme = isDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}
|
||||
}
|
||||
|
||||
// Immediate check on setup
|
||||
checkAndUpdateTheme();
|
||||
|
||||
// Watch for theme changes on the HTML element
|
||||
const themeObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
|
||||
themeObserver.observe(document.documentElement, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
|
||||
// Also observe body if it exists
|
||||
if (document.body) {
|
||||
const bodyObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
bodyObserver.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
}
|
||||
|
||||
// Also check periodically in case classList is manipulated directly
|
||||
// or the theme loads asynchronously after page load
|
||||
const intervalId = setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 300);
|
||||
|
||||
// Clean up after 30 seconds once theme should be stable
|
||||
setTimeout(() => {
|
||||
clearInterval(intervalId);
|
||||
// But keep checking every 2 seconds for manual theme changes
|
||||
setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 2000);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
function init() {
|
||||
mountTitleBar();
|
||||
syncViewportHeight();
|
||||
observeThemeChanges();
|
||||
|
||||
window.addEventListener("resize", syncViewportHeight, { passive: true });
|
||||
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
|
||||
passive: true,
|
||||
|
||||
21
extensions/chrome/LICENSE
Normal file
21
extensions/chrome/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 DanswerAI, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
30
extensions/chrome/README.md
Normal file
30
extensions/chrome/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Onyx Chrome Extension
|
||||
|
||||
The Onyx chrome extension lets you research, create, and automate with LLMs powered by your team's unique knowledge. Just hit Ctrl + O on Mac or Alt + O on Windows to instantly access Onyx in your browser:
|
||||
|
||||
💡 Know what your company knows, instantly with the Onyx sidebar
|
||||
💬 Chat: Onyx provides a natural language chat interface as the main way of interacting with the features.
|
||||
🌎 Internal Search: Ask questions and get answers from all your team's knowledge, powered by Onyx's 50+ connectors to all the tools your team uses
|
||||
🚀 With a simple Ctrl + O on Mac or Alt + O on Windows - instantly summarize information from any work application
|
||||
|
||||
⚡️ Get quick access to the work resources you need.
|
||||
🆕 Onyx new tab page puts all of your company’s knowledge at your fingertips
|
||||
🤖 Access custom AI Agents for unique use cases, and give them access to tools to take action.
|
||||
|
||||
—
|
||||
|
||||
Onyx connects with dozens of popular workplace apps like Google Drive, Jira, Confluence, Slack, and more. Use this extension if you have an account created by your team admin.
|
||||
|
||||
## Installation
|
||||
|
||||
For Onyx Cloud Users, please visit the Chrome Plugin Store (pending approval still)
|
||||
|
||||
## Development
|
||||
|
||||
- Load unpacked extension in your browser
|
||||
- Modify files in `src` directory
|
||||
- Refresh extension in Chrome
|
||||
|
||||
## Contributing
|
||||
|
||||
Submit issues or pull requests for improvements
|
||||
70
extensions/chrome/manifest.json
Normal file
70
extensions/chrome/manifest.json
Normal file
@@ -0,0 +1,70 @@
|
||||
{
|
||||
"manifest_version": 3,
|
||||
"name": "Onyx",
|
||||
"version": "1.0",
|
||||
"description": "Onyx lets you research, create, and automate with LLMs powered by your team's unique knowledge",
|
||||
"permissions": [
|
||||
"sidePanel",
|
||||
"storage",
|
||||
"activeTab",
|
||||
"tabs"
|
||||
],
|
||||
"host_permissions": ["<all_urls>"],
|
||||
"background": {
|
||||
"service_worker": "service_worker.js",
|
||||
"type": "module"
|
||||
},
|
||||
"action": {
|
||||
"default_icon": {
|
||||
"16": "public/icon16.png",
|
||||
"48": "public/icon48.png",
|
||||
"128": "public/icon128.png"
|
||||
},
|
||||
"default_popup": "src/pages/popup.html"
|
||||
},
|
||||
"icons": {
|
||||
"16": "public/icon16.png",
|
||||
"48": "public/icon48.png",
|
||||
"128": "public/icon128.png"
|
||||
},
|
||||
"options_page": "src/pages/options.html",
|
||||
"chrome_url_overrides": {
|
||||
"newtab": "src/pages/onyx_home.html"
|
||||
},
|
||||
"commands": {
|
||||
"toggleNewTabOverride": {
|
||||
"suggested_key": {
|
||||
"default": "Ctrl+Shift+O",
|
||||
"mac": "Command+Shift+O"
|
||||
},
|
||||
"description": "Toggle Onyx New Tab Override"
|
||||
},
|
||||
"openSidePanel": {
|
||||
"suggested_key": {
|
||||
"default": "Ctrl+O",
|
||||
"windows": "Alt+O",
|
||||
"mac": "MacCtrl+O"
|
||||
},
|
||||
"description": "Open Onyx Side Panel"
|
||||
}
|
||||
},
|
||||
"side_panel": {
|
||||
"default_path": "src/pages/panel.html"
|
||||
},
|
||||
"omnibox": {
|
||||
"keyword": "onyx"
|
||||
},
|
||||
"content_scripts": [
|
||||
{
|
||||
"matches": ["<all_urls>"],
|
||||
"js": ["src/utils/selection-icon.js"],
|
||||
"css": ["src/styles/selection-icon.css"]
|
||||
}
|
||||
],
|
||||
"web_accessible_resources": [
|
||||
{
|
||||
"resources": ["public/icon32.png"],
|
||||
"matches": ["<all_urls>"]
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
extensions/chrome/public/icon128.png
Normal file
BIN
extensions/chrome/public/icon128.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.5 KiB |
BIN
extensions/chrome/public/icon16.png
Normal file
BIN
extensions/chrome/public/icon16.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 235 B |
BIN
extensions/chrome/public/icon32.png
Normal file
BIN
extensions/chrome/public/icon32.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 551 B |
BIN
extensions/chrome/public/icon48.png
Normal file
BIN
extensions/chrome/public/icon48.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 KiB |
BIN
extensions/chrome/public/logo.png
Normal file
BIN
extensions/chrome/public/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.6 KiB |
276
extensions/chrome/service_worker.js
Normal file
276
extensions/chrome/service_worker.js
Normal file
@@ -0,0 +1,276 @@
|
||||
import {
|
||||
DEFAULT_ONYX_DOMAIN,
|
||||
CHROME_SPECIFIC_STORAGE_KEYS,
|
||||
ACTIONS,
|
||||
SIDE_PANEL_PATH,
|
||||
} from "./src/utils/constants.js";
|
||||
|
||||
// Track side panel state per window
|
||||
const sidePanelOpenState = new Map();
|
||||
|
||||
// Open welcome page on first install
|
||||
chrome.runtime.onInstalled.addListener((details) => {
|
||||
if (details.reason === "install") {
|
||||
chrome.storage.local.get(
|
||||
{ [CHROME_SPECIFIC_STORAGE_KEYS.ONBOARDING_COMPLETE]: false },
|
||||
(result) => {
|
||||
if (!result[CHROME_SPECIFIC_STORAGE_KEYS.ONBOARDING_COMPLETE]) {
|
||||
chrome.tabs.create({ url: "src/pages/welcome.html" });
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
async function setupSidePanel() {
|
||||
if (chrome.sidePanel) {
|
||||
try {
|
||||
// Don't auto-open side panel on action click since we have a popup menu
|
||||
await chrome.sidePanel.setPanelBehavior({
|
||||
openPanelOnActionClick: false,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error setting up side panel:", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function openSidePanel(tabId) {
|
||||
try {
|
||||
await chrome.sidePanel.open({ tabId });
|
||||
} catch (error) {
|
||||
console.error("Error opening side panel:", error);
|
||||
}
|
||||
}
|
||||
|
||||
async function sendToOnyx(info, tab) {
|
||||
const selectedText = encodeURIComponent(info.selectionText);
|
||||
const currentUrl = encodeURIComponent(tab.url);
|
||||
|
||||
try {
|
||||
const result = await chrome.storage.local.get({
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN,
|
||||
});
|
||||
const url = `${
|
||||
result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]
|
||||
}${SIDE_PANEL_PATH}?user-prompt=${selectedText}`;
|
||||
|
||||
await openSidePanel(tab.id);
|
||||
chrome.runtime.sendMessage({
|
||||
action: ACTIONS.OPEN_SIDE_PANEL_WITH_INPUT,
|
||||
url: url,
|
||||
pageUrl: tab.url,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error sending to Onyx:", error);
|
||||
}
|
||||
}
|
||||
|
||||
async function toggleNewTabOverride() {
|
||||
try {
|
||||
const result = await chrome.storage.local.get(
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB,
|
||||
);
|
||||
const newValue =
|
||||
!result[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB];
|
||||
await chrome.storage.local.set({
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: newValue,
|
||||
});
|
||||
|
||||
chrome.notifications.create({
|
||||
type: "basic",
|
||||
iconUrl: "icon.png",
|
||||
title: "Onyx New Tab",
|
||||
message: `New Tab Override ${newValue ? "enabled" : "disabled"}`,
|
||||
});
|
||||
|
||||
// Send a message to inform all tabs about the change
|
||||
chrome.tabs.query({}, (tabs) => {
|
||||
tabs.forEach((tab) => {
|
||||
chrome.tabs.sendMessage(tab.id, {
|
||||
action: "newTabOverrideToggled",
|
||||
value: newValue,
|
||||
});
|
||||
});
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error toggling new tab override:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: This listener won't fire when a popup is defined in manifest.json
|
||||
// The popup will show instead. This is kept as a fallback if popup is removed.
|
||||
chrome.action.onClicked.addListener((tab) => {
|
||||
openSidePanel(tab.id);
|
||||
});
|
||||
|
||||
chrome.commands.onCommand.addListener(async (command) => {
|
||||
if (command === ACTIONS.SEND_TO_ONYX) {
|
||||
try {
|
||||
const [tab] = await chrome.tabs.query({
|
||||
active: true,
|
||||
lastFocusedWindow: true,
|
||||
});
|
||||
if (tab) {
|
||||
const response = await chrome.tabs.sendMessage(tab.id, {
|
||||
action: ACTIONS.GET_SELECTED_TEXT,
|
||||
});
|
||||
const selectedText = response?.selectedText || "";
|
||||
sendToOnyx({ selectionText: selectedText }, tab);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error sending to Onyx:", error);
|
||||
}
|
||||
} else if (command === ACTIONS.TOGGLE_NEW_TAB_OVERRIDE) {
|
||||
toggleNewTabOverride();
|
||||
} else if (command === ACTIONS.CLOSE_SIDE_PANEL) {
|
||||
try {
|
||||
await chrome.sidePanel.hide();
|
||||
} catch (error) {
|
||||
console.error("Error closing side panel via command:", error);
|
||||
}
|
||||
} else if (command === ACTIONS.OPEN_SIDE_PANEL) {
|
||||
chrome.tabs.query({ active: true, lastFocusedWindow: true }, (tabs) => {
|
||||
if (tabs && tabs.length > 0) {
|
||||
const tab = tabs[0];
|
||||
const windowId = tab.windowId;
|
||||
const isOpen = sidePanelOpenState.get(windowId) || false;
|
||||
|
||||
if (isOpen) {
|
||||
chrome.sidePanel.setOptions({ enabled: false }, () => {
|
||||
chrome.sidePanel.setOptions({ enabled: true });
|
||||
sidePanelOpenState.set(windowId, false);
|
||||
});
|
||||
} else {
|
||||
chrome.sidePanel.open({ tabId: tab.id });
|
||||
sidePanelOpenState.set(windowId, true);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
} else {
|
||||
console.log("Unhandled command:", command);
|
||||
}
|
||||
});
|
||||
|
||||
chrome.runtime.onMessage.addListener((request, sender, sendResponse) => {
|
||||
if (request.action === ACTIONS.GET_CURRENT_ONYX_DOMAIN) {
|
||||
chrome.storage.local.get(
|
||||
{ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN },
|
||||
(result) => {
|
||||
sendResponse({
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]:
|
||||
result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN],
|
||||
});
|
||||
},
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if (request.action === ACTIONS.CLOSE_SIDE_PANEL) {
|
||||
closeSidePanel();
|
||||
chrome.storage.local.get(
|
||||
{ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN },
|
||||
(result) => {
|
||||
chrome.tabs.create({
|
||||
url: `${result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]}/auth/login`,
|
||||
active: true,
|
||||
});
|
||||
},
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if (request.action === ACTIONS.OPEN_SIDE_PANEL_WITH_INPUT) {
|
||||
const { selectedText, pageUrl } = request;
|
||||
const tabId = sender.tab?.id;
|
||||
const windowId = sender.tab?.windowId;
|
||||
|
||||
if (tabId && windowId) {
|
||||
chrome.storage.local.get(
|
||||
{ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN },
|
||||
(result) => {
|
||||
const encodedText = encodeURIComponent(selectedText);
|
||||
const onyxDomain = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN];
|
||||
const url = `${onyxDomain}${SIDE_PANEL_PATH}?user-prompt=${encodedText}`;
|
||||
|
||||
chrome.storage.session.set({
|
||||
pendingInput: {
|
||||
url: url,
|
||||
pageUrl: pageUrl,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
});
|
||||
|
||||
chrome.sidePanel
|
||||
.open({ windowId })
|
||||
.then(() => {
|
||||
chrome.runtime.sendMessage({
|
||||
action: ACTIONS.OPEN_ONYX_WITH_INPUT,
|
||||
url: url,
|
||||
pageUrl: pageUrl,
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
"[Onyx SW] Error opening side panel with text:",
|
||||
error,
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
} else {
|
||||
console.error("[Onyx SW] Missing tabId or windowId");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
});
|
||||
|
||||
chrome.storage.onChanged.addListener((changes, namespace) => {
|
||||
if (
|
||||
namespace === "local" &&
|
||||
changes[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]
|
||||
) {
|
||||
const newValue =
|
||||
changes[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]
|
||||
.newValue;
|
||||
|
||||
if (newValue === false) {
|
||||
chrome.runtime.openOptionsPage();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
chrome.windows.onRemoved.addListener((windowId) => {
|
||||
sidePanelOpenState.delete(windowId);
|
||||
});
|
||||
|
||||
chrome.omnibox.setDefaultSuggestion({
|
||||
description: 'Search Onyx for "%s"',
|
||||
});
|
||||
|
||||
chrome.omnibox.onInputEntered.addListener(async (text) => {
|
||||
try {
|
||||
const result = await chrome.storage.local.get({
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN,
|
||||
});
|
||||
|
||||
const domain = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN];
|
||||
const searchUrl = `${domain}/chat?user-prompt=${encodeURIComponent(text)}`;
|
||||
|
||||
chrome.tabs.update({ url: searchUrl });
|
||||
} catch (error) {
|
||||
console.error("Error handling omnibox search:", error);
|
||||
}
|
||||
});
|
||||
|
||||
chrome.omnibox.onInputChanged.addListener((text, suggest) => {
|
||||
if (text.trim()) {
|
||||
suggest([
|
||||
{
|
||||
content: text,
|
||||
description: `Search Onyx for "<match>${text}</match>"`,
|
||||
},
|
||||
]);
|
||||
}
|
||||
});
|
||||
|
||||
setupSidePanel();
|
||||
BIN
extensions/chrome/src/.DS_Store
vendored
Normal file
BIN
extensions/chrome/src/.DS_Store
vendored
Normal file
Binary file not shown.
76
extensions/chrome/src/pages/onyx_home.html
Normal file
76
extensions/chrome/src/pages/onyx_home.html
Normal file
@@ -0,0 +1,76 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta http-equiv="Permissions-Policy" content="clipboard-write=(self)" />
|
||||
<title>Onyx Home</title>
|
||||
<link rel="stylesheet" href="../styles/shared.css" />
|
||||
<style>
|
||||
body,
|
||||
html {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
width: 100%;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
html,
|
||||
body {
|
||||
background-color: #000;
|
||||
}
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: light) {
|
||||
html,
|
||||
body {
|
||||
background-color: #f6f6f6;
|
||||
}
|
||||
}
|
||||
|
||||
#background {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-size: cover;
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
transition: opacity 0.5s ease-in-out;
|
||||
}
|
||||
|
||||
#content {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
opacity: 0;
|
||||
transition: opacity 0.5s ease-in-out;
|
||||
}
|
||||
|
||||
iframe {
|
||||
border: none;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
visibility: hidden;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div id="background"></div>
|
||||
<div id="content">
|
||||
<iframe
|
||||
id="onyx-iframe"
|
||||
allowfullscreen
|
||||
allow="clipboard-read; clipboard-write"
|
||||
></iframe>
|
||||
</div>
|
||||
<script src="onyx_home.js" type="module"></script>
|
||||
</body>
|
||||
</html>
|
||||
248
extensions/chrome/src/pages/onyx_home.js
Normal file
248
extensions/chrome/src/pages/onyx_home.js
Normal file
@@ -0,0 +1,248 @@
|
||||
import {
|
||||
CHROME_MESSAGE,
|
||||
CHROME_SPECIFIC_STORAGE_KEYS,
|
||||
WEB_MESSAGE,
|
||||
} from "../utils/constants.js";
|
||||
import {
|
||||
showErrorModal,
|
||||
hideErrorModal,
|
||||
initErrorModal,
|
||||
} from "../utils/error-modal.js";
|
||||
import { getOnyxDomain } from "../utils/storage.js";
|
||||
|
||||
(function () {
|
||||
let mainIframe = document.getElementById("onyx-iframe");
|
||||
let preloadedIframe = null;
|
||||
const background = document.getElementById("background");
|
||||
const content = document.getElementById("content");
|
||||
const DEFAULT_LIGHT_BACKGROUND_IMAGE =
|
||||
"https://images.unsplash.com/photo-1692520883599-d543cfe6d43d?q=80&w=2666&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D";
|
||||
const DEFAULT_DARK_BACKGROUND_IMAGE =
|
||||
"https://images.unsplash.com/photo-1692520883599-d543cfe6d43d?q=80&w=2666&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D";
|
||||
|
||||
let iframeLoadTimeout;
|
||||
let iframeLoaded = false;
|
||||
|
||||
initErrorModal();
|
||||
|
||||
async function preloadChatInterface() {
|
||||
preloadedIframe = document.createElement("iframe");
|
||||
|
||||
const domain = await getOnyxDomain();
|
||||
preloadedIframe.src = domain + "/chat";
|
||||
preloadedIframe.style.opacity = "0";
|
||||
preloadedIframe.style.visibility = "hidden";
|
||||
preloadedIframe.style.transition = "opacity 0.3s ease-in";
|
||||
preloadedIframe.style.border = "none";
|
||||
preloadedIframe.style.width = "100%";
|
||||
preloadedIframe.style.height = "100%";
|
||||
preloadedIframe.style.position = "absolute";
|
||||
preloadedIframe.style.top = "0";
|
||||
preloadedIframe.style.left = "0";
|
||||
preloadedIframe.style.zIndex = "1";
|
||||
content.appendChild(preloadedIframe);
|
||||
}
|
||||
|
||||
function setIframeSrc(url) {
|
||||
mainIframe.src = url;
|
||||
startIframeLoadTimeout();
|
||||
iframeLoaded = false;
|
||||
}
|
||||
|
||||
function startIframeLoadTimeout() {
|
||||
clearTimeout(iframeLoadTimeout);
|
||||
iframeLoadTimeout = setTimeout(() => {
|
||||
if (!iframeLoaded) {
|
||||
try {
|
||||
if (
|
||||
mainIframe.contentWindow.location.pathname.includes("/auth/login")
|
||||
) {
|
||||
showLoginPage();
|
||||
} else {
|
||||
showErrorModal(mainIframe.src);
|
||||
}
|
||||
} catch (error) {
|
||||
showErrorModal(mainIframe.src);
|
||||
}
|
||||
}
|
||||
}, 2500);
|
||||
}
|
||||
|
||||
function showLoginPage() {
|
||||
background.style.opacity = "0";
|
||||
mainIframe.style.opacity = "1";
|
||||
mainIframe.style.visibility = "visible";
|
||||
content.style.opacity = "1";
|
||||
hideErrorModal();
|
||||
}
|
||||
|
||||
function setTheme(theme, customBackgroundImage) {
|
||||
const imageUrl =
|
||||
customBackgroundImage ||
|
||||
(theme === "dark"
|
||||
? DEFAULT_DARK_BACKGROUND_IMAGE
|
||||
: DEFAULT_LIGHT_BACKGROUND_IMAGE);
|
||||
background.style.backgroundImage = `url('${imageUrl}')`;
|
||||
}
|
||||
|
||||
function fadeInContent() {
|
||||
content.style.transition = "opacity 0.5s ease-in";
|
||||
mainIframe.style.transition = "opacity 0.5s ease-in";
|
||||
content.style.opacity = "0";
|
||||
mainIframe.style.opacity = "0";
|
||||
mainIframe.style.visibility = "visible";
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
content.style.opacity = "1";
|
||||
mainIframe.style.opacity = "1";
|
||||
|
||||
setTimeout(() => {
|
||||
background.style.transition = "opacity 0.3s ease-out";
|
||||
background.style.opacity = "0";
|
||||
}, 500);
|
||||
});
|
||||
}
|
||||
|
||||
function checkOnyxPreference() {
|
||||
chrome.storage.local.get(
|
||||
[
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB,
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN,
|
||||
],
|
||||
(items) => {
|
||||
let useOnyxAsDefaultNewTab =
|
||||
items[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB];
|
||||
|
||||
if (useOnyxAsDefaultNewTab === undefined) {
|
||||
useOnyxAsDefaultNewTab = !!(
|
||||
localStorage.getItem(
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB,
|
||||
) === "1"
|
||||
);
|
||||
chrome.storage.local.set({
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]:
|
||||
useOnyxAsDefaultNewTab,
|
||||
});
|
||||
}
|
||||
|
||||
if (!useOnyxAsDefaultNewTab) {
|
||||
chrome.tabs.update({
|
||||
url: "chrome://new-tab-page",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
setIframeSrc(
|
||||
items[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN] + "/chat/nrf",
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
function loadThemeAndBackground() {
|
||||
chrome.storage.local.get(
|
||||
[
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.THEME,
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.BACKGROUND_IMAGE,
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.DARK_BG_URL,
|
||||
CHROME_SPECIFIC_STORAGE_KEYS.LIGHT_BG_URL,
|
||||
],
|
||||
function (result) {
|
||||
const theme = result[CHROME_SPECIFIC_STORAGE_KEYS.THEME] || "light";
|
||||
const customBackgroundImage =
|
||||
result[CHROME_SPECIFIC_STORAGE_KEYS.BACKGROUND_IMAGE];
|
||||
const darkBgUrl = result[CHROME_SPECIFIC_STORAGE_KEYS.DARK_BG_URL];
|
||||
const lightBgUrl = result[CHROME_SPECIFIC_STORAGE_KEYS.LIGHT_BG_URL];
|
||||
|
||||
let backgroundImage;
|
||||
if (customBackgroundImage) {
|
||||
backgroundImage = customBackgroundImage;
|
||||
} else if (theme === "dark" && darkBgUrl) {
|
||||
backgroundImage = darkBgUrl;
|
||||
} else if (theme === "light" && lightBgUrl) {
|
||||
backgroundImage = lightBgUrl;
|
||||
}
|
||||
|
||||
setTheme(theme, backgroundImage);
|
||||
checkOnyxPreference();
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
function loadNewPage(newSrc) {
|
||||
if (preloadedIframe && preloadedIframe.contentWindow) {
|
||||
preloadedIframe.contentWindow.postMessage(
|
||||
{ type: WEB_MESSAGE.PAGE_CHANGE, href: newSrc },
|
||||
"*",
|
||||
);
|
||||
} else {
|
||||
console.error("Preloaded iframe not available");
|
||||
}
|
||||
}
|
||||
|
||||
function completePendingPageLoad() {
|
||||
if (preloadedIframe) {
|
||||
preloadedIframe.style.visibility = "visible";
|
||||
preloadedIframe.style.opacity = "1";
|
||||
preloadedIframe.style.zIndex = "1";
|
||||
mainIframe.style.zIndex = "2";
|
||||
mainIframe.style.opacity = "0";
|
||||
|
||||
setTimeout(() => {
|
||||
if (content.contains(mainIframe)) {
|
||||
content.removeChild(mainIframe);
|
||||
}
|
||||
|
||||
mainIframe = preloadedIframe;
|
||||
mainIframe.id = "onyx-iframe";
|
||||
mainIframe.style.zIndex = "";
|
||||
iframeLoaded = true;
|
||||
clearTimeout(iframeLoadTimeout);
|
||||
}, 200);
|
||||
} else {
|
||||
console.warn("No preloaded iframe available");
|
||||
}
|
||||
}
|
||||
|
||||
chrome.storage.onChanged.addListener(function (changes, namespace) {
|
||||
if (namespace === "local" && changes.useOnyxAsDefaultNewTab) {
|
||||
checkOnyxPreference();
|
||||
}
|
||||
});
|
||||
|
||||
window.addEventListener("message", function (event) {
|
||||
if (event.data.type === CHROME_MESSAGE.SET_DEFAULT_NEW_TAB) {
|
||||
chrome.storage.local.set({ useOnyxAsDefaultNewTab: event.data.value });
|
||||
} else if (event.data.type === CHROME_MESSAGE.ONYX_APP_LOADED) {
|
||||
clearTimeout(iframeLoadTimeout);
|
||||
hideErrorModal();
|
||||
fadeInContent();
|
||||
iframeLoaded = true;
|
||||
} else if (event.data.type === CHROME_MESSAGE.PREFERENCES_UPDATED) {
|
||||
const { theme, backgroundUrl } = event.data.payload;
|
||||
chrome.storage.local.set(
|
||||
{
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.THEME]: theme,
|
||||
[CHROME_SPECIFIC_STORAGE_KEYS.BACKGROUND_IMAGE]: backgroundUrl,
|
||||
},
|
||||
() => {},
|
||||
);
|
||||
} else if (event.data.type === CHROME_MESSAGE.LOAD_NEW_PAGE) {
|
||||
loadNewPage(event.data.href);
|
||||
} else if (event.data.type === CHROME_MESSAGE.LOAD_NEW_CHAT_PAGE) {
|
||||
completePendingPageLoad();
|
||||
}
|
||||
});
|
||||
|
||||
mainIframe.onload = function () {
|
||||
clearTimeout(iframeLoadTimeout);
|
||||
startIframeLoadTimeout();
|
||||
};
|
||||
|
||||
mainIframe.onerror = function (error) {
|
||||
showErrorModal(mainIframe.src);
|
||||
};
|
||||
|
||||
loadThemeAndBackground();
|
||||
preloadChatInterface();
|
||||
})();
|
||||
515
extensions/chrome/src/pages/options.html
Normal file
515
extensions/chrome/src/pages/options.html
Normal file
@@ -0,0 +1,515 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta http-equiv="Permissions-Policy" content="clipboard-write=(self)" />
|
||||
<title>Onyx - Settings</title>
|
||||
<link rel="stylesheet" href="../styles/shared.css" />
|
||||
<style>
|
||||
:root {
|
||||
--background-900: #0a0a0a;
|
||||
--background-800: #1a1a1a;
|
||||
--text-light-05: rgba(255, 255, 255, 0.95);
|
||||
--text-light-03: rgba(255, 255, 255, 0.6);
|
||||
--white-10: rgba(255, 255, 255, 0.1);
|
||||
--white-15: rgba(255, 255, 255, 0.15);
|
||||
--white-20: rgba(255, 255, 255, 0.2);
|
||||
--white-30: rgba(255, 255, 255, 0.3);
|
||||
--white-40: rgba(255, 255, 255, 0.4);
|
||||
--white-80: rgba(255, 255, 255, 0.8);
|
||||
--black-40: rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--background-900) 0%,
|
||||
var(--background-800) 100%
|
||||
);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
body.light-theme {
|
||||
--background-900: #f5f5f5;
|
||||
--background-800: #ffffff;
|
||||
--text-light-05: rgba(0, 0, 0, 0.95);
|
||||
--text-light-03: rgba(0, 0, 0, 0.6);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
}
|
||||
|
||||
body.light-theme .settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
body.light-theme .settings-header {
|
||||
border-bottom: 1px solid rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
body.light-theme .settings-icon {
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
body.light-theme .theme-toggle {
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
body.light-theme .theme-toggle:hover {
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
}
|
||||
|
||||
body.light-theme .theme-toggle svg {
|
||||
stroke: rgba(0, 0, 0, 0.95);
|
||||
}
|
||||
|
||||
body.light-theme .settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
}
|
||||
|
||||
body.light-theme .setting-divider {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
body.light-theme .input-field {
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
color: rgba(0, 0, 0, 0.95);
|
||||
}
|
||||
|
||||
body.light-theme .input-field:focus {
|
||||
outline: none;
|
||||
border-color: rgba(0, 0, 0, 0.25);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
color: rgba(0, 0, 0, 0.95);
|
||||
}
|
||||
|
||||
body.light-theme .status-container {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
}
|
||||
|
||||
body.light-theme .button.secondary {
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
color: rgba(0, 0, 0, 0.95);
|
||||
}
|
||||
|
||||
body.light-theme .button.secondary:hover {
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
}
|
||||
|
||||
body.light-theme .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
body.light-theme input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
body.light-theme .toggle-slider:before {
|
||||
background-color: white;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.settings-container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
margin: 0 auto;
|
||||
padding: 40px 20px;
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(10, 10, 10, 0.95),
|
||||
rgba(26, 26, 26, 0.95)
|
||||
);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
padding: 24px;
|
||||
border-bottom: 1px solid var(--white-10);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.settings-header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.settings-icon {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.settings-icon img {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: contain;
|
||||
padding: 6px;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: var(--text-light-05);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.theme-toggle {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 6px 12px;
|
||||
border-radius: 999px;
|
||||
background: var(--white-10);
|
||||
border: 1px solid var(--white-10);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.theme-toggle:hover {
|
||||
background: var(--white-15);
|
||||
}
|
||||
|
||||
.theme-toggle svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
stroke: var(--text-light-05);
|
||||
}
|
||||
|
||||
.settings-content {
|
||||
padding: 24px;
|
||||
}
|
||||
|
||||
.settings-section {
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
|
||||
.settings-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--text-light-03);
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.setting-row-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.setting-label {
|
||||
font-size: 14px;
|
||||
font-weight: 400;
|
||||
color: var(--text-light-05);
|
||||
}
|
||||
|
||||
.setting-description {
|
||||
font-size: 12px;
|
||||
color: var(--text-light-03);
|
||||
}
|
||||
|
||||
.setting-divider {
|
||||
height: 1px;
|
||||
background: var(--white-10);
|
||||
margin: 0 4px;
|
||||
}
|
||||
|
||||
.input-field {
|
||||
width: 100%;
|
||||
padding: 10px 12px;
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
box-shadow: 0 0 0 2px rgba(255, 255, 255, 0.1);
|
||||
color: var(--text-light-05);
|
||||
}
|
||||
|
||||
.input-field::placeholder {
|
||||
color: var(--text-light-03);
|
||||
}
|
||||
|
||||
.setting-row .input-field {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.toggle-switch {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 44px;
|
||||
height: 24px;
|
||||
}
|
||||
|
||||
.toggle-switch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
}
|
||||
|
||||
.toggle-slider {
|
||||
position: absolute;
|
||||
cursor: pointer;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(255, 255, 255, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
|
||||
.toggle-slider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(255, 255, 255, 0.4);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
transform: translateX(20px);
|
||||
}
|
||||
|
||||
.status-container {
|
||||
margin-top: 20px;
|
||||
padding: 12px;
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 8px;
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s;
|
||||
}
|
||||
|
||||
.status-container.show {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.status-message {
|
||||
margin: 0 0 12px 0;
|
||||
color: var(--text-light-05);
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.button {
|
||||
padding: 10px 20px;
|
||||
border-radius: 8px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
transition: all 0.2s;
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
}
|
||||
|
||||
.button.secondary {
|
||||
background: var(--white-10);
|
||||
color: var(--text-light-05);
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.button.secondary:hover {
|
||||
background: var(--white-15);
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 4px;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.settings-container {
|
||||
padding: 20px 16px;
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.settings-content {
|
||||
padding: 20px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="settings-container">
|
||||
<div class="settings-panel">
|
||||
<div class="settings-header">
|
||||
<div class="settings-header-left">
|
||||
<div class="settings-icon">
|
||||
<img src="../../public/icon48.png" alt="Onyx" />
|
||||
</div>
|
||||
<h1 class="settings-title">Settings</h1>
|
||||
</div>
|
||||
<button
|
||||
class="theme-toggle"
|
||||
id="themeToggle"
|
||||
aria-label="Toggle theme"
|
||||
>
|
||||
<svg
|
||||
id="themeIcon"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<circle cx="12" cy="12" r="4"></circle>
|
||||
<path
|
||||
d="M12 2v2m0 16v2M4.93 4.93l1.41 1.41m11.32 11.32l1.41 1.41M2 12h2m16 0h2M4.93 19.07l1.41-1.41M17.66 6.34l1.41-1.41"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="settings-content">
|
||||
<!-- General Section -->
|
||||
<section class="settings-section">
|
||||
<div class="section-title">General</div>
|
||||
<div class="settings-group">
|
||||
<div class="setting-row">
|
||||
<div class="setting-row-content">
|
||||
<label class="setting-label" for="onyxDomain"
|
||||
>Root Domain</label
|
||||
>
|
||||
<div class="setting-description">
|
||||
The root URL for your Onyx instance
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-divider"></div>
|
||||
<div class="setting-row" style="padding: 12px">
|
||||
<input
|
||||
type="text"
|
||||
id="onyxDomain"
|
||||
class="input-field"
|
||||
placeholder="https://cloud.onyx.app"
|
||||
/>
|
||||
</div>
|
||||
<div class="setting-divider"></div>
|
||||
<div class="setting-row">
|
||||
<div class="setting-row-content">
|
||||
<label class="setting-label" for="useOnyxAsDefault"
|
||||
>Use Onyx as new tab page</label
|
||||
>
|
||||
</div>
|
||||
<label class="toggle-switch">
|
||||
<input type="checkbox" id="useOnyxAsDefault" />
|
||||
<span class="toggle-slider"></span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Search Engine Section -->
|
||||
<section class="settings-section">
|
||||
<div class="section-title">Search Engine</div>
|
||||
<div class="settings-group">
|
||||
<div class="setting-row">
|
||||
<div class="setting-row-content">
|
||||
<label class="setting-label">Use Onyx in Address Bar</label>
|
||||
<div class="setting-description">
|
||||
Type <kbd>onyx</kbd> followed by a space in Chrome's address
|
||||
bar, then enter your search query and press Enter
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-divider"></div>
|
||||
<div class="setting-row">
|
||||
<div class="setting-row-content">
|
||||
<div class="setting-description">
|
||||
Searches will be directed to your configured Onyx instance
|
||||
at the Root Domain above
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Status Message -->
|
||||
<div id="statusContainer" class="status-container">
|
||||
<p id="status" class="status-message"></p>
|
||||
<button id="newTab" class="button secondary" style="display: none">
|
||||
Open New Tab to Test
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script type="module" src="options.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
142
extensions/chrome/src/pages/options.js
Normal file
142
extensions/chrome/src/pages/options.js
Normal file
@@ -0,0 +1,142 @@
|
||||
import {
  CHROME_SPECIFIC_STORAGE_KEYS,
  DEFAULT_ONYX_DOMAIN,
} from "../utils/constants.js";

/**
 * Options page controller.
 *
 * Persists the Onyx root domain, the "use Onyx as new tab" flag, and the
 * light/dark theme to chrome.storage.local, and flashes a transient status
 * message after every save.
 */
document.addEventListener("DOMContentLoaded", function () {
  const domainInput = document.getElementById("onyxDomain");
  const useOnyxAsDefaultToggle = document.getElementById("useOnyxAsDefault");
  const statusContainer = document.getElementById("statusContainer");
  const statusElement = document.getElementById("status");
  const newTabButton = document.getElementById("newTab");
  const themeToggle = document.getElementById("themeToggle");
  const themeIcon = document.getElementById("themeIcon");

  let currentTheme = "dark";

  // Timers are kept in locals (not as expando properties on DOM nodes, as
  // the original did with domainInput.saveTimeout) so they can be cancelled
  // reliably.
  let saveDebounceTimer = null; // debounces saves while typing the domain
  let hideStatusTimer = null; // pending auto-hide of the status message

  /** Swap the header toggle icon: sun for light theme, moon for dark. */
  function updateThemeIcon(theme) {
    if (!themeIcon) return;

    if (theme === "light") {
      themeIcon.innerHTML = `
      <circle cx="12" cy="12" r="4"></circle>
      <path d="M12 2v2m0 16v2M4.93 4.93l1.41 1.41m11.32 11.32l1.41 1.41M2 12h2m16 0h2M4.93 19.07l1.41-1.41M17.66 6.34l1.41-1.41"></path>
    `;
    } else {
      themeIcon.innerHTML = `
      <path d="M21 12.79A9 9 0 1 1 11.21 3 7 7 0 0 0 21 12.79z"></path>
    `;
    }
  }

  /** Populate the form from chrome.storage.local, falling back to defaults. */
  function loadStoredValues() {
    chrome.storage.local.get(
      {
        [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN,
        [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: false,
        [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: "dark",
      },
      (result) => {
        if (domainInput)
          domainInput.value = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN];
        if (useOnyxAsDefaultToggle)
          useOnyxAsDefaultToggle.checked =
            result[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB];

        currentTheme = result[CHROME_SPECIFIC_STORAGE_KEYS.THEME] || "dark";
        updateThemeIcon(currentTheme);

        document.body.className = currentTheme === "light" ? "light-theme" : "";
      },
    );
  }

  /** Persist all settings and surface a confirmation message. */
  function saveSettings() {
    const domain = domainInput.value.trim();
    const useOnyxAsDefault = useOnyxAsDefaultToggle
      ? useOnyxAsDefaultToggle.checked
      : false;

    chrome.storage.local.set(
      {
        [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: domain,
        [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]:
          useOnyxAsDefault,
        [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: currentTheme,
      },
      () => {
        showStatusMessage(
          useOnyxAsDefault
            ? "Settings updated. Open a new tab to test it out. Click on the extension icon to bring up Onyx from any page."
            : "Settings updated.",
        );
      },
    );
  }

  /**
   * Show the status banner for 5 seconds.
   *
   * Any previously scheduled auto-hide is cancelled first so that a rapid
   * sequence of saves cannot hide the newest message early (the original
   * implementation stacked setTimeout calls and never cleared them).
   */
  function showStatusMessage(message) {
    if (statusElement) {
      const useOnyxAsDefault = useOnyxAsDefaultToggle
        ? useOnyxAsDefaultToggle.checked
        : false;

      statusElement.textContent =
        message ||
        (useOnyxAsDefault
          ? "Settings updated. Open a new tab to test it out. Click on the extension icon to bring up Onyx from any page."
          : "Settings updated.");

      // The "Open New Tab to Test" helper button is only relevant when the
      // new-tab override is enabled.
      if (newTabButton) {
        newTabButton.style.display = useOnyxAsDefault ? "block" : "none";
      }
    }

    if (statusContainer) {
      statusContainer.classList.add("show");
    }

    clearTimeout(hideStatusTimer);
    hideStatusTimer = setTimeout(hideStatusMessage, 5000);
  }

  /** Fade out the status banner (CSS handles the transition). */
  function hideStatusMessage() {
    if (statusContainer) {
      statusContainer.classList.remove("show");
    }
  }

  /** Flip light/dark, update the UI, and persist the choice. */
  function toggleTheme() {
    currentTheme = currentTheme === "light" ? "dark" : "light";
    updateThemeIcon(currentTheme);

    document.body.className = currentTheme === "light" ? "light-theme" : "";

    chrome.storage.local.set({
      [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: currentTheme,
    });
  }

  function openNewTab() {
    chrome.tabs.create({});
  }

  if (domainInput) {
    // Debounce: save one second after the user stops typing the domain.
    domainInput.addEventListener("input", () => {
      clearTimeout(saveDebounceTimer);
      saveDebounceTimer = setTimeout(saveSettings, 1000);
    });
  }

  if (useOnyxAsDefaultToggle) {
    useOnyxAsDefaultToggle.addEventListener("change", saveSettings);
  }

  if (themeToggle) {
    themeToggle.addEventListener("click", toggleTheme);
  }

  if (newTabButton) {
    newTabButton.addEventListener("click", openNewTab);
  }

  loadStoredValues();
});
|
||||
91
extensions/chrome/src/pages/panel.html
Normal file
91
extensions/chrome/src/pages/panel.html
Normal file
@@ -0,0 +1,91 @@
|
||||
<!doctype html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <meta http-equiv="Permissions-Policy" content="clipboard-write=(self)" />
    <title>Onyx Panel</title>
    <link rel="stylesheet" href="../styles/shared.css" />
    <style>
      body,
      html {
        margin: 0;
        padding: 0;
        width: 100%;
        height: 100vh;
        overflow: hidden;
      }

      /* Full-screen overlay shown while the embedded Onyx app boots;
         panel.js fades it out once the iframe reports it has loaded. */
      #loading-screen {
        position: fixed;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: #f5f5f5;
        display: flex;
        flex-direction: column;
        justify-content: center;
        align-items: center;
        z-index: 1000;
        transition: opacity 0.5s ease-in-out;
      }

      #logo {
        width: 100px;
        height: 100px;
        background-image: url("/public/logo.png");
        background-size: contain;
        background-repeat: no-repeat;
        background-position: center;
        animation: pulse 2s infinite;
      }

      /* Gentle breathing animation for the logo while loading. */
      @keyframes pulse {
        0% {
          transform: scale(1);
        }

        50% {
          transform: scale(1.1);
        }

        100% {
          transform: scale(1);
        }
      }

      #loading-text {
        color: #0a0a0a;
        margin-top: 20px;
        font-size: 1.125rem;
        font-weight: 600;
        text-align: center;
      }

      /* The iframe fills the panel; it starts transparent and panel.js
         raises its opacity when the app inside signals readiness. */
      iframe {
        border: none;
        width: 100%;
        height: 100%;
        position: absolute;
        top: 0;
        left: 0;
        opacity: 0;
        transition: opacity 0.5s ease-in-out;
      }
    </style>
  </head>

  <body>
    <!-- Loading overlay: visible until panel.js hides it. -->
    <div id="loading-screen">
      <div id="logo"></div>
      <div id="loading-text">Loading Onyx...</div>
    </div>
    <!-- The hosted Onyx app; panel.js sets the src at runtime. -->
    <iframe
      id="onyx-panel-iframe"
      allow="clipboard-read; clipboard-write"
    ></iframe>
    <script src="../utils/error-modal.js" type="module"></script>
    <script src="panel.js" type="module"></script>
  </body>
</html>
|
||||
127
extensions/chrome/src/pages/panel.js
Normal file
127
extensions/chrome/src/pages/panel.js
Normal file
@@ -0,0 +1,127 @@
|
||||
import { showErrorModal, showAuthModal } from "../utils/error-modal.js";
import {
  ACTIONS,
  CHROME_MESSAGE,
  WEB_MESSAGE,
  CHROME_SPECIFIC_STORAGE_KEYS,
  SIDE_PANEL_PATH,
  DEFAULT_ONYX_DOMAIN,
} from "../utils/constants.js";

/**
 * Side-panel bootstrap.
 *
 * Decides what URL the embedded Onyx iframe should load (pending selection
 * input, the configured domain, or the default), shows a loading screen
 * until the app inside the iframe reports it has loaded, and relays page
 * changes from the background script into the iframe.
 */
(function () {
  const iframe = document.getElementById("onyx-panel-iframe");
  const loadingScreen = document.getElementById("loading-screen");

  let currentUrl = ""; // last page URL forwarded to the iframe
  let iframeLoaded = false; // set once the app posts ONYX_APP_LOADED
  let iframeLoadTimeout; // fallback timer that surfaces an error modal
  let authRequired = false; // set if the app posts AUTH_REQUIRED

  /**
   * Consume a "pendingInput" record left in session storage (e.g. by the
   * selection-icon flow). Returns true if it was fresh (< 5s old) and the
   * iframe was pointed at it; stale records are discarded either way.
   */
  async function checkPendingInput() {
    try {
      const result = await chrome.storage.session.get("pendingInput");
      if (result.pendingInput) {
        const { url, pageUrl, timestamp } = result.pendingInput;
        if (Date.now() - timestamp < 5000) {
          setIframeSrc(url, pageUrl);
          await chrome.storage.session.remove("pendingInput");
          return true;
        }
        await chrome.storage.session.remove("pendingInput");
      }
    } catch (error) {
      console.error("[Onyx Panel] Error checking pending input:", error);
    }
    return false;
  }

  /** Reset the loading UI and pick the iframe's initial URL. */
  async function initializePanel() {
    loadingScreen.style.display = "flex";
    loadingScreen.style.opacity = "1";
    iframe.style.opacity = "0";

    // Check for pending input first (from selection icon click)
    const hasPendingInput = await checkPendingInput();
    if (!hasPendingInput) {
      loadOnyxDomain();
    }
  }

  function setIframeSrc(url, pageUrl) {
    iframe.src = url;
    currentUrl = pageUrl;
  }

  /** Notify the embedded app that the user navigated to a new page. */
  function sendWebsiteToIframe(pageUrl) {
    if (iframe.contentWindow && pageUrl !== currentUrl) {
      iframe.contentWindow.postMessage(
        {
          type: WEB_MESSAGE.PAGE_CHANGE,
          url: pageUrl,
        },
        "*",
      );
      currentUrl = pageUrl;
    }
  }

  /**
   * If the app has not signalled readiness within 2.5s, show either the
   * auth modal (when the app reported AUTH_REQUIRED) or the generic
   * connection-error modal.
   */
  function startIframeLoadTimeout() {
    iframeLoadTimeout = setTimeout(() => {
      if (!iframeLoaded) {
        if (authRequired) {
          showAuthModal();
        } else {
          showErrorModal(iframe.src);
        }
      }
    }, 2500);
  }

  /** Handle postMessage traffic coming back from the embedded app. */
  function handleMessage(event) {
    if (event.data.type === CHROME_MESSAGE.ONYX_APP_LOADED) {
      clearTimeout(iframeLoadTimeout);
      iframeLoaded = true;
      showIframe();
      if (iframe.contentWindow) {
        iframe.contentWindow.postMessage({ type: "PANEL_READY" }, "*");
      }
    } else if (event.data.type === CHROME_MESSAGE.AUTH_REQUIRED) {
      authRequired = true;
    }
  }

  /** Cross-fade from the loading screen to the iframe. */
  function showIframe() {
    iframe.style.opacity = "1";
    loadingScreen.style.opacity = "0";
    setTimeout(() => {
      loadingScreen.style.display = "none";
    }, 500);
  }

  /**
   * Ask the background script for the configured Onyx domain and point the
   * iframe at its side-panel path, falling back to the packaged default.
   */
  async function loadOnyxDomain() {
    const response = await chrome.runtime.sendMessage({
      action: ACTIONS.GET_CURRENT_ONYX_DOMAIN,
    });
    if (response && response[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]) {
      setIframeSrc(
        response[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN] + SIDE_PANEL_PATH,
        "",
      );
    } else {
      // BUG FIX: the original called getOnyxDomain(), which is not defined
      // or imported anywhere in this module, so the fallback path threw a
      // ReferenceError. Use the imported DEFAULT_ONYX_DOMAIN instead.
      console.warn("Onyx domain not found, using default");
      setIframeSrc(DEFAULT_ONYX_DOMAIN + SIDE_PANEL_PATH, "");
    }
  }

  chrome.runtime.onMessage.addListener((request, sender, sendResponse) => {
    if (request.action === ACTIONS.OPEN_ONYX_WITH_INPUT) {
      setIframeSrc(request.url, request.pageUrl);
    } else if (request.action === ACTIONS.UPDATE_PAGE_URL) {
      sendWebsiteToIframe(request.pageUrl);
    }
  });

  window.addEventListener("message", handleMessage);

  initializePanel();
  startIframeLoadTimeout();
})();
|
||||
252
extensions/chrome/src/pages/popup.html
Normal file
252
extensions/chrome/src/pages/popup.html
Normal file
@@ -0,0 +1,252 @@
|
||||
<!doctype html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <meta http-equiv="Permissions-Policy" content="clipboard-write=(self)" />
    <title>Onyx</title>
    <link rel="stylesheet" href="../styles/shared.css" />
    <style>
      /* Popup-local palette variables (dark gradient + translucent whites). */
      :root {
        --background-900: #0a0a0a;
        --background-800: #1a1a1a;
        --text-light-05: rgba(255, 255, 255, 0.95);
        --text-light-03: rgba(255, 255, 255, 0.6);
        --white-10: rgba(255, 255, 255, 0.1);
        --white-15: rgba(255, 255, 255, 0.15);
        --white-20: rgba(255, 255, 255, 0.2);
      }

      * {
        box-sizing: border-box;
      }

      /* Fixed-width popup body with a dark diagonal gradient. */
      body {
        width: 300px;
        margin: 0;
        padding: 0;
        font-family: var(--font-hanken-grotesk);
        background: linear-gradient(
          135deg,
          var(--background-900) 0%,
          var(--background-800) 100%
        );
        color: var(--text-light-05);
      }

      .popup-container {
        padding: 16px;
      }

      .popup-header {
        display: flex;
        align-items: center;
        gap: 12px;
        padding-bottom: 16px;
        border-bottom: 1px solid var(--white-10);
        margin-bottom: 16px;
      }

      .popup-icon {
        width: 36px;
        height: 36px;
        border-radius: 10px;
        background: white;
        display: flex;
        align-items: center;
        justify-content: center;
        overflow: hidden;
      }

      .popup-icon img {
        width: 100%;
        height: 100%;
        object-fit: contain;
        padding: 4px;
      }

      .popup-title {
        margin: 0;
        font-size: 18px;
        font-weight: 600;
        color: var(--text-light-05);
      }

      .menu-button-content {
        display: flex;
        align-items: center;
        justify-content: space-between;
        width: 100%;
      }

      .menu-button-text {
        display: flex;
        align-items: center;
        gap: 10px;
      }

      /* Keyboard-shortcut hint pinned to the right edge of a menu button. */
      .menu-button-shortcut {
        font-size: 11px;
        color: var(--text-light-03);
        font-weight: 400;
        margin-left: auto;
      }

      .settings-group {
        background: rgba(255, 255, 255, 0.05);
        border-radius: 12px;
        padding: 4px;
        margin-bottom: 12px;
      }

      .setting-row {
        display: flex;
        justify-content: space-between;
        align-items: center;
        padding: 12px;
      }

      .setting-label {
        font-size: 14px;
        font-weight: 400;
        color: var(--text-light-05);
      }

      .setting-divider {
        height: 1px;
        background: var(--white-10);
        margin: 0 4px;
      }

      .menu-button {
        background: rgba(255, 255, 255, 0.05);
        border: none;
        padding: 12px;
        width: 100%;
        text-align: left;
        cursor: pointer;
        font-size: 14px;
        color: var(--text-light-05);
        font-weight: 400;
        transition: background 0.2s;
        border-radius: 12px;
        font-family: var(--font-hanken-grotesk);
      }

      .menu-button:hover {
        background: rgba(255, 255, 255, 0.1);
      }

      .menu-button svg {
        width: 18px;
        height: 18px;
        stroke: var(--text-light-05);
        fill: none;
        stroke-width: 2;
        stroke-linecap: round;
        stroke-linejoin: round;
      }

      .button-group {
        display: flex;
        flex-direction: column;
        gap: 8px;
      }

      /* Custom toggle: the real checkbox is hidden and the slider span is
         styled via the adjacent-sibling selectors below. */
      .toggle-switch {
        position: relative;
        display: inline-block;
        width: 44px;
        height: 24px;
      }

      .toggle-switch input {
        opacity: 0;
        width: 0;
        height: 0;
      }

      .toggle-slider {
        position: absolute;
        cursor: pointer;
        top: 0;
        left: 0;
        right: 0;
        bottom: 0;
        background-color: rgba(255, 255, 255, 0.2);
        transition: 0.3s;
        border-radius: 24px;
      }

      .toggle-slider:before {
        position: absolute;
        content: "";
        height: 18px;
        width: 18px;
        left: 3px;
        bottom: 3px;
        background-color: white;
        transition: 0.3s;
        border-radius: 50%;
      }

      input:checked + .toggle-slider {
        background-color: rgba(255, 255, 255, 0.4);
      }

      input:checked + .toggle-slider:before {
        transform: translateX(20px);
      }
    </style>
  </head>
  <body>
    <div class="popup-container">
      <div class="popup-header">
        <div class="popup-icon">
          <img src="../../public/icon48.png" alt="Onyx" />
        </div>
        <h2 class="popup-title">Onyx</h2>
      </div>

      <!-- "Use Onyx as new tab page" toggle; wired up in popup.js. -->
      <div class="settings-group">
        <div class="setting-row">
          <label class="setting-label" for="defaultNewTabToggle">
            Use Onyx as new tab page
          </label>
          <label class="toggle-switch">
            <input type="checkbox" id="defaultNewTabToggle" />
            <span class="toggle-slider"></span>
          </label>
        </div>
      </div>

      <!-- Quick actions: open the side panel, open the options page. -->
      <div class="button-group">
        <button class="menu-button" id="openSidePanel">
          <div class="menu-button-content">
            <div class="menu-button-text">
              <svg viewBox="0 0 24 24">
                <rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect>
                <line x1="15" y1="3" x2="15" y2="21"></line>
              </svg>
              Open Onyx Panel
            </div>
            <span class="menu-button-shortcut">Ctrl+O</span>
          </div>
        </button>

        <button class="menu-button" id="openOptions">
          <div class="menu-button-text">
            <svg viewBox="0 0 24 24">
              <circle cx="12" cy="12" r="3"></circle>
              <path
                d="M19.4 15a1.65 1.65 0 0 0 .33 1.82l.06.06a2 2 0 0 1 0 2.83 2 2 0 0 1-2.83 0l-.06-.06a1.65 1.65 0 0 0-1.82-.33 1.65 1.65 0 0 0-1 1.51V21a2 2 0 0 1-2 2 2 2 0 0 1-2-2v-.09A1.65 1.65 0 0 0 9 19.4a1.65 1.65 0 0 0-1.82.33l-.06.06a2 2 0 0 1-2.83 0 2 2 0 0 1 0-2.83l.06-.06a1.65 1.65 0 0 0 .33-1.82 1.65 1.65 0 0 0-1.51-1H3a2 2 0 0 1-2-2 2 2 0 0 1 2-2h.09A1.65 1.65 0 0 0 4.6 9a1.65 1.65 0 0 0-.33-1.82l-.06-.06a2 2 0 0 1 0-2.83 2 2 0 0 1 2.83 0l.06.06a1.65 1.65 0 0 0 1.82.33H9a1.65 1.65 0 0 0 1-1.51V3a2 2 0 0 1 2-2 2 2 0 0 1 2 2v.09a1.65 1.65 0 0 0 1 1.51 1.65 1.65 0 0 0 1.82-.33l.06-.06a2 2 0 0 1 2.83 0 2 2 0 0 1 0 2.83l-.06.06a1.65 1.65 0 0 0-.33 1.82V9a1.65 1.65 0 0 0 1.51 1H21a2 2 0 0 1 2 2 2 2 0 0 1-2 2h-.09a1.65 1.65 0 0 0-1.51 1z"
              ></path>
            </svg>
            Extension Settings
          </div>
        </button>
      </div>
    </div>
    <script type="module" src="popup.js"></script>
  </body>
</html>
|
||||
58
extensions/chrome/src/pages/popup.js
Normal file
58
extensions/chrome/src/pages/popup.js
Normal file
@@ -0,0 +1,58 @@
|
||||
import { CHROME_SPECIFIC_STORAGE_KEYS } from "../utils/constants.js";

/**
 * Popup menu controller.
 *
 * Mirrors the "use Onyx as new tab" preference stored in
 * chrome.storage.local and wires up the two quick actions: opening the
 * side panel for the active tab and opening the extension options page.
 */
document.addEventListener("DOMContentLoaded", async () => {
  const newTabToggle = document.getElementById("defaultNewTabToggle");
  const sidePanelBtn = document.getElementById("openSidePanel");
  const optionsBtn = document.getElementById("openOptions");

  // Reflect the stored preference (defaulting to false) in the checkbox.
  const restoreToggleState = async () => {
    const stored = await chrome.storage.local.get({
      [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: false,
    });
    if (newTabToggle) {
      newTabToggle.checked =
        stored[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB];
    }
  };

  // Write the checkbox's current state back to storage.
  const persistToggleState = async () => {
    await chrome.storage.local.set({
      [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]:
        newTabToggle.checked,
    });
  };

  // Open the side panel on the active tab, then close the popup.
  const launchSidePanel = async () => {
    try {
      const [activeTab] = await chrome.tabs.query({
        active: true,
        currentWindow: true,
      });
      if (activeTab && chrome.sidePanel) {
        await chrome.sidePanel.open({ tabId: activeTab.id });
        window.close();
      }
    } catch (error) {
      console.error("Error opening side panel:", error);
    }
  };

  // Open the extension's options page, then close the popup.
  const launchOptionsPage = () => {
    chrome.runtime.openOptionsPage();
    window.close();
  };

  await restoreToggleState();

  newTabToggle?.addEventListener("change", persistToggleState);
  sidePanelBtn?.addEventListener("click", launchSidePanel);
  optionsBtn?.addEventListener("click", launchOptionsPage);
});
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user