mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 21:55:46 +00:00
Compare commits
42 Commits
test-tests
...
v2.11.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed46504a1a | ||
|
|
7a24b34516 | ||
|
|
7a7ffa9051 | ||
|
|
3053ab518c | ||
|
|
be38d3500f | ||
|
|
753a3bc093 | ||
|
|
2ba8fafe78 | ||
|
|
b77b580ebd | ||
|
|
3eee98b932 | ||
|
|
a97eb02fef | ||
|
|
c5061495a2 | ||
|
|
c20b0789ae | ||
|
|
d99848717b | ||
|
|
aaca55c415 | ||
|
|
9d7ffd1e4a | ||
|
|
a249161827 | ||
|
|
e126346a91 | ||
|
|
a96682fa73 | ||
|
|
3920371d56 | ||
|
|
e5a257345c | ||
|
|
a49df511e2 | ||
|
|
d5d2a8a1a6 | ||
|
|
b2f46b264c | ||
|
|
c6ad363fbd | ||
|
|
e313119f9a | ||
|
|
3a2a542a03 | ||
|
|
413aeba4a1 | ||
|
|
46028aa2bb | ||
|
|
454943c4a6 | ||
|
|
87946266de | ||
|
|
144030c5ca | ||
|
|
a557d76041 | ||
|
|
605e808158 | ||
|
|
8fec88c90d | ||
|
|
e54969a693 | ||
|
|
1da2b2f28f | ||
|
|
eb7b91e08e | ||
|
|
3339000968 | ||
|
|
d9db849e94 | ||
|
|
046408359c | ||
|
|
4b8cca190f | ||
|
|
52a312a63b |
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
39
.vscode/launch.json
vendored
39
.vscode/launch.json
vendored
@@ -149,6 +149,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
@@ -587,6 +605,27 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
|
||||
@@ -97,10 +97,14 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
DOCUMENT_SYNC_PREFIX = "documentsync"
|
||||
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
|
||||
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
|
||||
r.delete(DOCUMENT_SYNC_FENCE_KEY)
|
||||
return
|
||||
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
|
||||
|
||||
# Add to the tracking taskset in Redis BEFORE creating the celery task
|
||||
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
|
||||
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
|
||||
|
||||
# Create the Celery task
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -341,6 +341,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
NOTE: Whatever formatting strategy is used here to generate a key-value
|
||||
string must be replicated when constructing query filters.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_to_metadata_value(value: Any) -> str | list[str]:
|
||||
"""Convert a Salesforce field value to a valid metadata value.
|
||||
|
||||
Document metadata expects str | list[str], but Salesforce returns
|
||||
various types (bool, float, int, etc.). This function ensures all
|
||||
values are properly converted to strings.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
return str(value)
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
# # gc.collect()
|
||||
# return all_types
|
||||
|
||||
def _yield_doc_batches(
|
||||
self,
|
||||
sf_db: OnyxSalesforceSQLite,
|
||||
type_to_processed: dict[str, int],
|
||||
changed_ids_to_type: dict[str, str],
|
||||
parent_types: set[str],
|
||||
increment_parents_changed: Callable[[], None],
|
||||
) -> GenerateDocumentsOutput:
|
||||
""" """
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
parent_object.data[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
increment_parents_changed()
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def _full_sync(
|
||||
self,
|
||||
temp_dir: str,
|
||||
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
examined_ids = 0
|
||||
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
f"records={num_records}"
|
||||
)
|
||||
|
||||
# yield an empty list to keep the connector alive
|
||||
yield docs_to_yield
|
||||
|
||||
new_ids = sf_db.update_from_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=ctx.parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = (
|
||||
type_to_processed.get(parent_type, 0) + 1
|
||||
)
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
def increment_parents_changed() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
yield from self._yield_doc_batches(
|
||||
sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
ctx.parent_types,
|
||||
increment_parents_changed,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
|
||||
if sf_attribute in record:
|
||||
doc.metadata[canonical_attribute] = record[sf_attribute]
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
record[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
return return_context
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if MULTI_TENANT:
|
||||
# if multi tenant, we cannot expect the sqlite db to be cached/present
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._full_sync(temp_dir)
|
||||
|
||||
# nuke the db since we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
|
||||
os.remove(sqlite_db_path)
|
||||
return self._full_sync(BASE_DATA_PATH)
|
||||
# Always use a temp directory for SQLite - the database is rebuilt
|
||||
# from scratch each time via CSV downloads, so there's no caching benefit
|
||||
# from persisting it. Using temp dirs also avoids collisions between
|
||||
# multiple CC pairs and eliminates stale WAL/SHM file issues.
|
||||
# TODO(evan): make this thing checkpointed and persist/load db from filestore
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield from self._full_sync(temp_dir)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source will synchronize updated parent objects one by one."""
|
||||
|
||||
if start == 0:
|
||||
# nuke the db if we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(
|
||||
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
|
||||
)
|
||||
os.remove(sqlite_db_path)
|
||||
|
||||
return self._delta_sync(BASE_DATA_PATH, start, end)
|
||||
|
||||
# Always use a temp directory - see comment in load_from_state()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
yield from self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
SQLITE_DISK_IO_ERROR = "disk I/O error"
|
||||
|
||||
|
||||
class OnyxSalesforceSQLite:
|
||||
"""Notes on context management using 'with self.conn':
|
||||
|
||||
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
|
||||
def apply_schema(self) -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist.
|
||||
|
||||
Non-destructive operation.
|
||||
Non-destructive operation. If a disk I/O error is encountered (often due
|
||||
to stale WAL/SHM files from a previous crash), this method will attempt
|
||||
to recover by removing the corrupted files and recreating the database.
|
||||
"""
|
||||
try:
|
||||
self._apply_schema_impl()
|
||||
except sqlite3.OperationalError as e:
|
||||
if SQLITE_DISK_IO_ERROR not in str(e):
|
||||
raise
|
||||
|
||||
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
|
||||
self._recover_from_corruption()
|
||||
self._apply_schema_impl()
|
||||
|
||||
def _recover_from_corruption(self) -> None:
|
||||
"""Recover from SQLite corruption by removing all database files and reconnecting."""
|
||||
logger.info(f"Removing corrupted SQLite files: {self.filename}")
|
||||
|
||||
# Close existing connection
|
||||
self.close()
|
||||
|
||||
# Remove all SQLite files (main db, WAL, SHM)
|
||||
remove_sqlite_db_files(self.filename)
|
||||
|
||||
# Reconnect - this will create a fresh database
|
||||
self.connect()
|
||||
|
||||
logger.info("SQLite recovery complete, fresh database created")
|
||||
|
||||
def _apply_schema_impl(self) -> None:
|
||||
"""Internal implementation of apply_schema."""
|
||||
if self._conn is None:
|
||||
raise RuntimeError("Database connection is closed")
|
||||
|
||||
|
||||
@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
|
||||
return os.path.join(directory, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def remove_sqlite_db_files(db_path: str) -> None:
|
||||
"""Remove SQLite database and all associated files (WAL, SHM).
|
||||
|
||||
SQLite in WAL mode creates additional files:
|
||||
- .sqlite-wal: Write-ahead log
|
||||
- .sqlite-shm: Shared memory file
|
||||
|
||||
If these files become stale (e.g., after a crash), they can cause
|
||||
'disk I/O error' when trying to open the database. This function
|
||||
ensures all related files are removed.
|
||||
"""
|
||||
files_to_remove = [
|
||||
db_path,
|
||||
f"{db_path}-wal",
|
||||
f"{db_path}-shm",
|
||||
]
|
||||
for file_path in files_to_remove:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
# NOTE: only used with shelves, deprecated at this point
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
|
||||
@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
# NOTE: These strings must be formatted in the same way as the output of
|
||||
# DocumentAccess::to_acl.
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
@@ -327,6 +329,15 @@ def delete_user_from_db(
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
# Null out ownership on document sets and personas so they're
|
||||
# preserved for other users instead of being cascade-deleted
|
||||
db_session.query(DocumentSet).filter(
|
||||
DocumentSet.user_id == user_to_delete.id
|
||||
).update({DocumentSet.user_id: None})
|
||||
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
|
||||
@@ -28,8 +28,8 @@ of "minimum value clipping".
|
||||
## On time decay and boosting
|
||||
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
|
||||
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
|
||||
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
|
||||
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
|
||||
Same logic applies to additive boosting.
|
||||
|
||||
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
|
||||
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
|
||||
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
|
||||
being fetched and returned to the user. But there are other issues of including these:
|
||||
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
|
||||
contents. If there are lots of updates, this may miss
|
||||
contents. If there are lots of updates, this may miss.
|
||||
- There is not a good way to normalize this field, the best is to clip it on the bottom.
|
||||
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
|
||||
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@@ -68,6 +70,18 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def generate_opensearch_filtered_access_control_list(
|
||||
access: DocumentAccess,
|
||||
) -> list[str]:
|
||||
"""Generates an access control list with PUBLIC_DOC_PAT removed.
|
||||
|
||||
In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
|
||||
"""
|
||||
access_control_list = access.to_acl()
|
||||
access_control_list.discard(PUBLIC_DOC_PAT)
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -152,10 +166,9 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
metadata_suffix=chunk.metadata_suffix_keyword,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
chunk.access
|
||||
),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
image_file_id=chunk.image_file_id,
|
||||
@@ -578,8 +591,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here so we don't have to think about passing in the
|
||||
# appropriate types into this dict.
|
||||
if update_request.access is not None:
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
|
||||
update_request.access.to_acl()
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
|
||||
generate_opensearch_filtered_access_control_list(
|
||||
update_request.access
|
||||
)
|
||||
)
|
||||
if update_request.document_sets is not None:
|
||||
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
|
||||
@@ -625,13 +640,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
# TODO(andrei): Remove this from the new interface at some point; we
|
||||
# should not be exposing this.
|
||||
batch_retrieval: bool = False,
|
||||
# TODO(andrei): Add a param for whether to retrieve hidden docs.
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
TODO(andrei): Consider implementing this method to retrieve on document
|
||||
@@ -646,6 +659,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
max_chunk_size=chunk_request.max_chunk_size,
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
@@ -672,9 +687,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
@@ -688,6 +700,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_candidates=1000, # TODO(andrei): Magic number.
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
|
||||
@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
def serialize_datetime_fields_to_epoch_seconds(
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
Serializes datetime fields to seconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
return int(value.timestamp())
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses seconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
|
||||
@@ -354,11 +353,9 @@ class DocumentSchema:
|
||||
},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
"format": "epoch_second",
|
||||
# For some reason date defaults to False, even though it
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
@@ -366,14 +363,21 @@ class DocumentSchema:
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
# is its own field.
|
||||
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
# should have no effect on queries.
|
||||
PUBLIC_FIELD_NAME: {"type": "boolean"},
|
||||
# Access control list for the doc, excluding public access,
|
||||
# which is covered above.
|
||||
# If a user's access set contains at least one entry from this
|
||||
# set, the user should be able to retrieve this document. This
|
||||
# only applies if public is set to false; public non-hidden
|
||||
# documents are always visible to anyone in a given tenancy
|
||||
# regardless of this field.
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# Whether the doc is hidden from search results. Should clobber
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
# Whether the doc is hidden from search results.
|
||||
# Should clobber all other access search filters, namely
|
||||
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
|
||||
# search implementations to guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
@@ -447,7 +451,6 @@ class DocumentSchema:
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
@@ -91,6 +106,11 @@ assert (
|
||||
# given search. This value is configurable in the index settings.
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
|
||||
|
||||
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
|
||||
# that the document was last updated this many days ago for the purpose of time
|
||||
# cutoff filtering during retrieval.
|
||||
ASSUMED_DOCUMENT_AGE_DAYS = 90
|
||||
|
||||
|
||||
class DocumentQuery:
|
||||
"""
|
||||
@@ -103,6 +123,8 @@ class DocumentQuery:
|
||||
def get_from_document_id_query(
|
||||
document_id: str,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
max_chunk_size: int,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -120,6 +142,8 @@ class DocumentQuery:
|
||||
document_id: Onyx document ID. Notably not an OpenSearch document
|
||||
ID, which points to what Onyx would refer to as a chunk.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the document retrieval query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
max_chunk_size: Document chunks are categorized by the maximum
|
||||
number of tokens they can hold. This parameter specifies the
|
||||
maximum size category of document chunks to retrieve.
|
||||
@@ -136,28 +160,21 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final ID search query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
filter_clauses.append(range_clause)
|
||||
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
max_chunk_size=max_chunk_size,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
final_get_ids_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
# We include this to make sure OpenSearch does not revert to
|
||||
@@ -195,15 +212,22 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final delete query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
# Delete hidden docs too.
|
||||
include_hidden=True,
|
||||
access_control_list=None,
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=document_id,
|
||||
)
|
||||
final_delete_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
}
|
||||
@@ -217,19 +241,25 @@ class DocumentQuery:
|
||||
num_candidates: int,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final hybrid search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
NOTE: This query can be directly supplied to the OpenSearch client, but
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
num_candidates: The number of candidates to consider for vector
|
||||
num_candidates: The number of neighbors to consider for vector
|
||||
similarity search. Generally more candidates improves search
|
||||
quality at the cost of performance.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the hybrid search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final hybrid search query.
|
||||
@@ -243,31 +273,47 @@ class DocumentQuery:
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, num_candidates
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
}
|
||||
}
|
||||
],
|
||||
# TODO(andrei): When revisiting our hybrid query logic see if
|
||||
# this needs to be nested one level down.
|
||||
"filter": hybrid_search_filters,
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Applied to all the sub-queries. Source:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# Does AND for each filter in the list.
|
||||
"filter": {"bool": {"filter": hybrid_search_filters}},
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: By default, hybrid search retrieves "size"-many results from
|
||||
# each OpenSearch shard before aggregation. Source:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
}
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
@@ -294,7 +340,8 @@ class DocumentQuery:
|
||||
pipeline.
|
||||
|
||||
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
|
||||
in a single hybrid query.
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -305,6 +352,7 @@ class DocumentQuery:
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the title.
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -313,6 +361,7 @@ class DocumentQuery:
|
||||
},
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the content.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -322,36 +371,273 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
# Either fuzzy match on the analyzed title (boosted 2x), or
|
||||
# exact match on exact title keywords (no OpenSearch
|
||||
# analysis done on the title). See
|
||||
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
# Returns the score of the best match of the fields above.
|
||||
# See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
|
||||
"type": "best_fields",
|
||||
}
|
||||
},
|
||||
# Fuzzy match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
|
||||
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
|
||||
# Exact match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
|
||||
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
|
||||
]
|
||||
|
||||
return hybrid_search_queries
|
||||
|
||||
@staticmethod
|
||||
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
|
||||
"""Returns filters for hybrid search.
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
include_hidden: bool,
|
||||
access_control_list: list[str] | None,
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
max_chunk_size: int | None = None,
|
||||
document_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns filters to be passed into the "filter" key of a search query.
|
||||
|
||||
For now only fetches public and not hidden documents.
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
|
||||
The return of this function is not sufficient to be directly supplied to
|
||||
the OpenSearch client. See get_hybrid_search_query.
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
access_control_list: Access control list for the documents to
|
||||
retrieve. If None, there is no restriction on the documents that
|
||||
can be retrieved. If not None, only public documents can be
|
||||
retrieved, or non-public documents where at least one acl
|
||||
provided here is present in the document's acl list.
|
||||
source_types: If supplied, only documents of one of these source
|
||||
types will be retrieved.
|
||||
tags: If supplied, only documents with an entry in their metadata
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
updated time, we assume some default age of
|
||||
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
|
||||
updated.
|
||||
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
|
||||
None, no minimum chunk index will be applied.
|
||||
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
|
||||
None, no maximum chunk index will be applied.
|
||||
max_chunk_size: The type of chunk to retrieve, specified by the
|
||||
maximum number of tokens it can hold. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
|
||||
TODO(andrei): Add ACL filters and stuff.
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
"""
|
||||
hybrid_search_filters: list[dict[str, Any]] = [
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
)
|
||||
return tag_filter
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
|
||||
}
|
||||
}
|
||||
)
|
||||
if time_cutoff < datetime.now(timezone.utc) - timedelta(
|
||||
days=ASSUMED_DOCUMENT_AGE_DAYS
|
||||
):
|
||||
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
|
||||
# ago, we include documents which have no
|
||||
# LAST_UPDATED_FIELD_NAME value.
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
|
||||
}
|
||||
}
|
||||
)
|
||||
return time_cutoff_filter
|
||||
|
||||
def _get_chunk_index_filter(
|
||||
min_chunk_index: int | None, max_chunk_index: int | None
|
||||
) -> dict[str, Any]:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
return range_clause
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
|
||||
|
||||
if access_control_list is not None:
|
||||
# If an access control list is provided, the caller can only
|
||||
# retrieve public documents, and non-public documents where at least
|
||||
# one acl provided here is present in the document's acl list. If
|
||||
# there is explicitly no list provided, we make no restrictions on
|
||||
# the documents that can be retrieved.
|
||||
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
|
||||
|
||||
if source_types:
|
||||
# If at least one source type is provided, the caller will only
|
||||
# retrieve documents whose source type is present in this input
|
||||
# list.
|
||||
filter_clauses.append(_get_source_type_filter(source_types))
|
||||
|
||||
if tags:
|
||||
# If at least one tag is provided, the caller will only retrieve
|
||||
# documents where at least one tag provided here is present in the
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
if document_sets:
|
||||
# If at least one document set is provided, the caller will only
|
||||
# retrieve documents where at least one document set provided here
|
||||
# is present in the document's document sets list.
|
||||
filter_clauses.append(_get_document_set_filter(document_sets))
|
||||
|
||||
if user_file_ids:
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs. Note that these IDs correspond to Onyx documents whereas
|
||||
# the entries retrieved from the document index correspond to Onyx
|
||||
# document chunks.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
# cutoff. For documents which do not have a value for
|
||||
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
|
||||
# purposes of time cutoff.
|
||||
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
filter_clauses.append(
|
||||
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
if max_chunk_size is not None:
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
)
|
||||
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
return hybrid_search_filters
|
||||
|
||||
return filter_clauses
|
||||
|
||||
@staticmethod
|
||||
def _get_match_highlights_configuration() -> dict[str, Any]:
|
||||
@@ -378,4 +664,5 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return match_highlights_configuration
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# Discord Bot Multitenant Architecture
|
||||
|
||||
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
|
||||
|
||||
## Overview
|
||||
|
||||
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
|
||||
|
||||
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
|
||||
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ OnyxDiscordClient │
|
||||
│ │
|
||||
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
|
||||
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
|
||||
│ │ tenant_id → api_key │ │ message, │ │
|
||||
│ │ │ │ api_key=<per-tenant>, │ │
|
||||
│ └─────────────────────────┘ │ persona_id=... │ │
|
||||
│ │ ) │ │
|
||||
│ └─────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Component Details
|
||||
|
||||
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
|
||||
|
||||
The `DiscordCacheManager` maintains two critical in-memory mappings:
|
||||
|
||||
```python
|
||||
class DiscordCacheManager:
|
||||
_guild_tenants: dict[int, str] # guild_id → tenant_id
|
||||
_api_keys: dict[str, str] # tenant_id → api_key
|
||||
_lock: asyncio.Lock # Concurrency control
|
||||
```
|
||||
|
||||
#### Key Responsibilities
|
||||
|
||||
| Function | Purpose |
|
||||
|----------|---------|
|
||||
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
|
||||
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
|
||||
| `refresh_all()` | Full cache rebuild from database |
|
||||
| `refresh_guild()` | Incremental update for single guild |
|
||||
|
||||
#### API Key Provisioning Strategy
|
||||
|
||||
API keys are **lazily provisioned** - only created when first needed:
|
||||
|
||||
```python
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
needs_key = tenant_id not in self._api_keys
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# Load guild configs
|
||||
configs = get_discord_bot_configs(db)
|
||||
guild_ids = [c.guild_id for c in configs if c.enabled]
|
||||
|
||||
# Only provision API key if not already cached
|
||||
api_key = None
|
||||
if needs_key:
|
||||
api_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
|
||||
return guild_ids, api_key
|
||||
```
|
||||
|
||||
This optimization avoids repeated database calls for API key generation.
|
||||
|
||||
#### Concurrency Control
|
||||
|
||||
All write operations acquire an async lock to prevent race conditions:
|
||||
|
||||
```python
|
||||
async def refresh_all(self) -> None:
|
||||
async with self._lock:
|
||||
# Safe to modify _guild_tenants and _api_keys
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
# Update mappings...
|
||||
```
|
||||
|
||||
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
|
||||
|
||||
---
|
||||
|
||||
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
|
||||
|
||||
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
|
||||
|
||||
#### Key Design: Per-Request API Key Injection
|
||||
|
||||
```python
|
||||
class OnyxAPIClient:
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str, # Injected per-request
|
||||
persona_id: int | None,
|
||||
...
|
||||
) -> ChatFullResponse:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
|
||||
}
|
||||
# Make request...
|
||||
```
|
||||
|
||||
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
|
||||
|
||||
```python
|
||||
# Same client, different tenants
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coordination Flow
|
||||
|
||||
### Message Processing Pipeline
|
||||
|
||||
When a Discord message arrives, the client coordinates cache and API client:
|
||||
|
||||
```python
|
||||
async def on_message(self, message: Message) -> None:
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Step 1: Cache lookup - guild → tenant
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
if not tenant_id:
|
||||
return # Guild not registered
|
||||
|
||||
# Step 2: Cache lookup - tenant → API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Step 3: API call with tenant-specific credentials
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key, # Tenant-specific
|
||||
persona_id=persona_id, # Tenant-specific
|
||||
api_client=self.api_client,
|
||||
)
|
||||
```
|
||||
|
||||
### Startup Sequence
|
||||
|
||||
```python
|
||||
async def setup_hook(self) -> None:
|
||||
# 1. Initialize API client (create aiohttp session)
|
||||
await self.api_client.initialize()
|
||||
|
||||
# 2. Populate cache with all tenants
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# 3. Start background refresh task
|
||||
self._cache_refresh_task = self.loop.create_task(
|
||||
self._periodic_cache_refresh() # Every 60 seconds
|
||||
)
|
||||
```
|
||||
|
||||
### Shutdown Sequence
|
||||
|
||||
```python
|
||||
async def close(self) -> None:
|
||||
# 1. Cancel background refresh
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
|
||||
# 2. Close Discord connection
|
||||
await super().close()
|
||||
|
||||
# 3. Close API client session
|
||||
await self.api_client.close()
|
||||
|
||||
# 4. Clear cache
|
||||
self.cache.clear()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tenant Isolation Mechanisms
|
||||
|
||||
### 1. Per-Tenant API Keys
|
||||
|
||||
Each tenant has a dedicated service API key:
|
||||
|
||||
```python
|
||||
# backend/onyx/db/discord_bot.py
|
||||
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
return regenerate_key(existing)
|
||||
|
||||
# Create LIMITED role key (chat-only permissions)
|
||||
return insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Minimal permissions
|
||||
),
|
||||
user_id=None, # Service account (system-owned)
|
||||
).api_key
|
||||
```
|
||||
|
||||
### 2. Database Context Variables
|
||||
|
||||
The cache uses context variables for proper tenant-scoped DB sessions:
|
||||
|
||||
```python
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# All DB operations scoped to this tenant
|
||||
...
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
```
|
||||
|
||||
### 3. Enterprise Gating Support
|
||||
|
||||
Gated tenants are filtered during cache refresh:
|
||||
|
||||
```python
|
||||
gated_tenants = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
if tenant_id in gated_tenants:
|
||||
continue # Skip gated tenants
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Refresh Strategy
|
||||
|
||||
| Trigger | Method | Scope |
|
||||
|---------|--------|-------|
|
||||
| Startup | `refresh_all()` | All tenants |
|
||||
| Periodic (60s) | `refresh_all()` | All tenants |
|
||||
| Guild registration | `refresh_guild()` | Single tenant |
|
||||
|
||||
### Error Handling
|
||||
|
||||
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
|
||||
- **Missing API key**: Bot silently ignores messages from that guild
|
||||
- **Network errors**: Logged, cache continues with stale data until next refresh
|
||||
|
||||
---
|
||||
|
||||
## Key Design Insights
|
||||
|
||||
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
|
||||
|
||||
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
|
||||
|
||||
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
|
||||
|
||||
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
|
||||
|
||||
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
|
||||
|
||||
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
|
||||
|
||||
---
|
||||
|
||||
## File References
|
||||
|
||||
| Component | Path |
|
||||
|-----------|------|
|
||||
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
|
||||
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
|
||||
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
|
||||
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
|
||||
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
|
||||
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,12 +1,149 @@
|
||||
from mistune import Markdown # type: ignore[import-untyped]
|
||||
from mistune import Renderer
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
return Markdown(renderer=SlackRenderer()).render(message)
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(Renderer):
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def header(self, text: str, level: int, raw: str | None = None) -> str:
|
||||
return f"*{text}*\n"
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
|
||||
def double_emphasis(self, text: str) -> str:
|
||||
def strong(self, text: str) -> str:
|
||||
return f"*{text}*"
|
||||
|
||||
def strikethrough(self, text: str) -> str:
|
||||
return f"~{text}~"
|
||||
|
||||
def list(self, body: str, ordered: bool = True) -> str:
|
||||
lines = body.split("\n")
|
||||
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
|
||||
lines = text.split("\n")
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("li: "):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
|
||||
def link(self, link: str, title: str | None, content: str | None) -> str:
|
||||
escaped_link = self.escape_special(link)
|
||||
if content:
|
||||
return f"<{escaped_link}|{content}>"
|
||||
def link(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
if text:
|
||||
return f"<{escaped_url}|{text}>"
|
||||
if title:
|
||||
return f"<{escaped_link}|{title}>"
|
||||
return f"<{escaped_link}>"
|
||||
return f"<{escaped_url}|{title}>"
|
||||
return f"<{escaped_url}>"
|
||||
|
||||
def image(self, src: str, title: str | None, text: str | None) -> str:
|
||||
escaped_src = self.escape_special(src)
|
||||
def image(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
display_text = title or text
|
||||
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
|
||||
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
|
||||
|
||||
def codespan(self, text: str) -> str:
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, text: str, lang: str | None) -> str:
|
||||
return f"```\n{text}\n```\n"
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
|
||||
def autolink(self, link: str, is_email: bool) -> str:
|
||||
return link if is_email else self.link(link, None, None)
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -32,6 +32,7 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -45,6 +45,7 @@ class RedisConnectorPrune:
|
||||
) # connectorpruning_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
|
||||
@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -84,7 +84,8 @@ def patch_document_set(
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
object_is_public=document_set_update_request.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
try:
|
||||
update_document_set(
|
||||
@@ -125,7 +126,8 @@ def delete_document_set(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
object_is_public=document_set.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
@router.get("/", tags=PUBLIC_API_TAGS)
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import user_needs_to_be_verified
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
|
||||
|
||||
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
|
||||
async def get_auth_type() -> AuthTypeResponse:
|
||||
user_count = await get_user_count()
|
||||
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
|
||||
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
|
||||
# associated with until they auth.
|
||||
has_users = True
|
||||
if AUTH_TYPE != AuthType.CLOUD:
|
||||
user_count = await get_user_count()
|
||||
has_users = user_count > 0
|
||||
|
||||
return AuthTypeResponse(
|
||||
auth_type=AUTH_TYPE,
|
||||
requires_verification=user_needs_to_be_verified(),
|
||||
anonymous_user_enabled=anonymous_user_enabled(),
|
||||
password_min_length=PASSWORD_MIN_LENGTH,
|
||||
has_users=user_count > 0,
|
||||
has_users=has_users,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import remove_chat_message_feedback
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -266,7 +267,35 @@ def get_chat_session(
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
try:
|
||||
# If we failed to get a chat session, try to retrieve the session with
|
||||
# less restrictive filters in order to identify what exactly mismatched
|
||||
# so we can bubble up an accurate error code andmessage.
|
||||
existing_chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
is_shared=False,
|
||||
include_deleted=True,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
if not include_deleted and existing_chat_session.deleted:
|
||||
raise HTTPException(status_code=404, detail="Chat session has been deleted")
|
||||
|
||||
if is_shared:
|
||||
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Chat session is not shared"
|
||||
)
|
||||
elif user_id is not None and existing_chat_session.user_id not in (
|
||||
user_id,
|
||||
None,
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
# for chat-seeding: if the session is unassigned, assign it now. This is done here
|
||||
# to avoid another back and forth between FE -> BE before starting the first
|
||||
|
||||
@@ -580,7 +580,7 @@ def translate_assistant_message_to_packets(
|
||||
# Determine stop reason - check if message indicates user cancelled
|
||||
stop_reason: str | None = None
|
||||
if chat_message.message:
|
||||
if "Generation was stopped" in chat_message.message:
|
||||
if "generation was stopped" in chat_message.message.lower():
|
||||
stop_reason = "user_cancelled"
|
||||
|
||||
# Add overall stop packet at the end
|
||||
|
||||
@@ -573,7 +573,7 @@ mcp==1.25.0
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==0.8.4
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
|
||||
@@ -298,7 +298,7 @@ numpy==2.4.1
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.4.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -191,6 +191,18 @@ autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Listens for Discord messages and responds with answers
|
||||
# for all guilds/channels that the OnyxBot has been added to.
|
||||
# If not configured, will continue to probe every 3 minutes for a Discord bot token.
|
||||
[program:discord_bot]
|
||||
command=python onyx/onyxbot/discord/client.py
|
||||
stdout_logfile=/var/log/discord_bot.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Pushes all logs from the above programs to stdout
|
||||
# No log rotation here, since it's stdout it's handled by the Docker container logging
|
||||
[program:log-redirect-handler]
|
||||
@@ -206,6 +218,7 @@ command=tail -qF
|
||||
/var/log/celery_worker_user_file_processing.log
|
||||
/var/log/celery_worker_docfetching.log
|
||||
/var/log/slack_bot.log
|
||||
/var/log/discord_bot.log
|
||||
/var/log/supervisord_watchdog_celery_beat.log
|
||||
/var/log/mcp_server.log
|
||||
/var/log/mcp_server.err.log
|
||||
|
||||
@@ -8,14 +8,22 @@ import re
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
generate_opensearch_filtered_access_control_list,
|
||||
)
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
@@ -42,14 +50,22 @@ def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) ->
|
||||
|
||||
def _create_test_document_chunk(
|
||||
document_id: str,
|
||||
chunk_index: int,
|
||||
content: str,
|
||||
tenant_state: TenantState,
|
||||
chunk_index: int = 0,
|
||||
content_vector: list[float] | None = None,
|
||||
title: str | None = None,
|
||||
title_vector: list[float] | None = None,
|
||||
public: bool = True,
|
||||
hidden: bool = False,
|
||||
document_access: DocumentAccess = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
),
|
||||
source_type: DocumentSource = DocumentSource.FILE,
|
||||
last_updated: datetime | None = None,
|
||||
) -> DocumentChunk:
|
||||
if content_vector is None:
|
||||
# Generate dummy vector - 128 dimensions for fast testing.
|
||||
@@ -59,11 +75,6 @@ def _create_test_document_chunk(
|
||||
if title is not None and title_vector is None:
|
||||
title_vector = [0.2] * 128
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# We only store millisecond precision, so to make sure asserts work in this
|
||||
# test file manually lose some precision from datetime.now().
|
||||
now = now.replace(microsecond=(now.microsecond // 1000) * 1000)
|
||||
|
||||
return DocumentChunk(
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
@@ -71,11 +82,13 @@ def _create_test_document_chunk(
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type="test_source",
|
||||
source_type=source_type.value,
|
||||
metadata_list=None,
|
||||
last_updated=now,
|
||||
public=public,
|
||||
access_control_list=[],
|
||||
last_updated=last_updated,
|
||||
public=document_access.is_public,
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
document_access
|
||||
),
|
||||
hidden=hidden,
|
||||
global_boost=0,
|
||||
semantic_identifier="Test semantic identifier",
|
||||
@@ -331,6 +344,9 @@ class TestOpenSearchClient:
|
||||
chunk_index=0,
|
||||
content="Content to retrieve",
|
||||
tenant_state=tenant_state,
|
||||
# We only store second precision, so to make sure asserts work in
|
||||
# this test we'll deliberately lose some precision.
|
||||
last_updated=datetime.now(timezone.utc).replace(microsecond=0),
|
||||
)
|
||||
test_client.index_document(document=original_doc)
|
||||
|
||||
@@ -471,6 +487,8 @@ class TestOpenSearchClient:
|
||||
search_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="delete-me",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -483,6 +501,8 @@ class TestOpenSearchClient:
|
||||
keep_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="keep-me",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -510,7 +530,6 @@ class TestOpenSearchClient:
|
||||
chunk_index=0,
|
||||
content="Original content",
|
||||
tenant_state=tenant_state,
|
||||
public=True,
|
||||
hidden=False,
|
||||
)
|
||||
test_client.index_document(document=doc)
|
||||
@@ -561,10 +580,13 @@ class TestOpenSearchClient:
|
||||
properties_to_update={"hidden": True},
|
||||
)
|
||||
|
||||
def test_search_basic(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests basic search functionality."""
|
||||
"""Tests hybrid search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
@@ -574,24 +596,24 @@ class TestOpenSearchClient:
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index multiple documents with different content and vectors.
|
||||
# Index documents.
|
||||
docs = {
|
||||
"search-doc-1": _create_test_document_chunk(
|
||||
document_id="search-doc-1",
|
||||
"doc-1": _create_test_document_chunk(
|
||||
document_id="doc-1",
|
||||
chunk_index=0,
|
||||
content="Python programming language tutorial",
|
||||
content_vector=_generate_test_vector(0.1),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"search-doc-2": _create_test_document_chunk(
|
||||
document_id="search-doc-2",
|
||||
"doc-2": _create_test_document_chunk(
|
||||
document_id="doc-2",
|
||||
chunk_index=0,
|
||||
content="How to make cheese",
|
||||
content_vector=_generate_test_vector(0.2),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"search-doc-3": _create_test_document_chunk(
|
||||
document_id="search-doc-3",
|
||||
"doc-3": _create_test_document_chunk(
|
||||
document_id="doc-3",
|
||||
chunk_index=0,
|
||||
content="C++ for newborns",
|
||||
content_vector=_generate_test_vector(0.15),
|
||||
@@ -613,78 +635,10 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 3
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_chunk.document_id
|
||||
in ["search-doc-1", "search-doc-2", "search-doc-3"]
|
||||
for chunk in results
|
||||
)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert chunk.score
|
||||
|
||||
# Make sure there is some kind of match highlight for the first hit. We
|
||||
# don't expect highlights for any other hit.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents.
|
||||
docs = {
|
||||
"pipeline-doc-1": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-1",
|
||||
chunk_index=0,
|
||||
content="Machine learning algorithms for single-celled organisms",
|
||||
content_vector=_generate_test_vector(0.3),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"pipeline-doc-2": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-2",
|
||||
chunk_index=0,
|
||||
content="Deep learning shallow neural networks",
|
||||
content_vector=_generate_test_vector(0.35),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search query.
|
||||
query_text = "machine learning"
|
||||
query_vector = _generate_test_vector(0.32)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text=query_text,
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
# We're not worried about filtering here. tenant_id in this object
|
||||
# is not relevant.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -693,23 +647,26 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 2
|
||||
assert len(results) == len(docs)
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_chunk.document_id in ["pipeline-doc-1", "pipeline-doc-2"]
|
||||
for chunk in results
|
||||
)
|
||||
assert all(chunk.document_chunk.document_id in docs.keys() for chunk in results)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
for i, chunk in enumerate(results):
|
||||
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert chunk.score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Make sure there is some kind of match highlight only for the first
|
||||
# result. The other results are so bad they're not expected to have
|
||||
# match highlights.
|
||||
if i == 0:
|
||||
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_empty_index(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search on an empty index returns an empty list."""
|
||||
# Precondition.
|
||||
@@ -731,19 +688,28 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
# We're not worried about filtering here. tenant_id in this object
|
||||
# is not relevant.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
results = test_client.search(
|
||||
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) == 0
|
||||
|
||||
def test_search_filters(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests search filters for public/hidden documents and tenant isolation.
|
||||
Tests search filters for ACL, hidden documents, and tenant isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
@@ -757,29 +723,47 @@ class TestOpenSearchClient:
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc-1": _create_test_document_chunk(
|
||||
document_id="public-doc-1",
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"hidden-doc-1": _create_test_document_chunk(
|
||||
document_id="hidden-doc-1",
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-1": _create_test_document_chunk(
|
||||
document_id="private-doc-1",
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
public=False,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
@@ -787,7 +771,6 @@ class TestOpenSearchClient:
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
@@ -798,9 +781,6 @@ class TestOpenSearchClient:
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search with default filters (public=True, hidden=False).
|
||||
# The DocumentQuery.get_hybrid_search_query uses filters that should
|
||||
# only return public, non-hidden documents.
|
||||
query_text = "document content"
|
||||
query_vector = _generate_test_vector(0.6)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
@@ -809,24 +789,41 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
results = test_client.search(
|
||||
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document.
|
||||
assert len(results) == 1
|
||||
assert results[0].document_chunk.document_id == "public-doc-1"
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# NOTE: This test is not explicitly testing for how well results are
|
||||
# ordered; we're just assuming which doc will be the first result here.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == docs["public-doc-1"]
|
||||
assert results[0].document_chunk == docs["public-doc"]
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == docs["private-doc-user-a"]
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
def test_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
@@ -849,52 +846,54 @@ class TestOpenSearchClient:
|
||||
# Vectors closer to query_vector (0.1) should rank higher.
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
document_id="highly-relevant-1",
|
||||
document_id="highly-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence and machine learning transform technology",
|
||||
content_vector=_generate_test_vector(
|
||||
0.1
|
||||
), # Very close to query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="somewhat-relevant-1",
|
||||
document_id="somewhat-relevant",
|
||||
chunk_index=0,
|
||||
content="Computer programming with various languages",
|
||||
content_vector=_generate_test_vector(0.5), # Far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="not-very-relevant-1",
|
||||
document_id="not-very-relevant",
|
||||
chunk_index=0,
|
||||
content="Cooking recipes for delicious meals",
|
||||
content_vector=_generate_test_vector(
|
||||
0.9
|
||||
), # Very far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
# These should be filtered out by public/hidden filters.
|
||||
_create_test_document_chunk(
|
||||
document_id="hidden-but-relevant-1",
|
||||
document_id="hidden-but-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence research papers",
|
||||
content_vector=_generate_test_vector(0.05), # Very close but hidden.
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="private-but-relevant-1",
|
||||
document_id="private-but-relevant",
|
||||
chunk_index=0,
|
||||
content="Artificial intelligence industry analysis",
|
||||
content_vector=_generate_test_vector(0.08), # Very close but private.
|
||||
public=False,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
@@ -905,7 +904,7 @@ class TestOpenSearchClient:
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Search query matching "highly-relevant-1" most closely.
|
||||
# Search query matching "highly-relevant" most closely.
|
||||
query_text = "artificial intelligence"
|
||||
query_vector = _generate_test_vector(0.1)
|
||||
search_body = DocumentQuery.get_hybrid_search_query(
|
||||
@@ -914,6 +913,9 @@ class TestOpenSearchClient:
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# Explicitly pass in an empty list to enforce private doc filtering.
|
||||
index_filters=IndexFilters(access_control_list=[], tenant_id=None),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -925,15 +927,15 @@ class TestOpenSearchClient:
|
||||
# Should only get public, non-hidden documents (3 out of 5).
|
||||
assert len(results) == 3
|
||||
result_ids = [chunk.document_chunk.document_id for chunk in results]
|
||||
assert "highly-relevant-1" in result_ids
|
||||
assert "somewhat-relevant-1" in result_ids
|
||||
assert "not-very-relevant-1" in result_ids
|
||||
assert "highly-relevant" in result_ids
|
||||
assert "somewhat-relevant" in result_ids
|
||||
assert "not-very-relevant" in result_ids
|
||||
# Filtered out by public/hidden constraints.
|
||||
assert "hidden-but-relevant-1" not in result_ids
|
||||
assert "private-but-relevant-1" not in result_ids
|
||||
assert "hidden-but-relevant" not in result_ids
|
||||
assert "private-but-relevant" not in result_ids
|
||||
|
||||
# Most relevant document should be first due to normalization pipeline.
|
||||
assert results[0].document_chunk.document_id == "highly-relevant-1"
|
||||
# Most relevant document should be first.
|
||||
assert results[0].document_chunk.document_id == "highly-relevant"
|
||||
|
||||
# Make sure there is some kind of match highlight for the most relevant
|
||||
# result.
|
||||
@@ -1014,6 +1016,8 @@ class TestOpenSearchClient:
|
||||
verify_query_x = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1026,6 +1030,8 @@ class TestOpenSearchClient:
|
||||
verify_query_y = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-y",
|
||||
tenant_state=tenant_y,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1113,6 +1119,8 @@ class TestOpenSearchClient:
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-1",
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -1133,3 +1141,176 @@ class TestOpenSearchClient:
|
||||
for chunk in doc1_chunks
|
||||
}
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
private ones.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_state,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for all documents.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="private-doc-user-a",
|
||||
tenant_state=tenant_state,
|
||||
# This is the input under test, notice None for acl.
|
||||
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
chunk_ids = test_client.search_for_document_ids(body=query_body)
|
||||
|
||||
# Postcondition.
|
||||
# Even though this doc is private, because we supplied None for acl we
|
||||
# were able to retrieve it.
|
||||
assert len(chunk_ids) == 1
|
||||
# Since this is a chunk ID, it will have the doc ID in it plus other
|
||||
# stuff we don't care about in this test.
|
||||
assert chunk_ids[0].startswith("private-doc-user-a")
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests the time cutoff filter works."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index docs with various ages.
|
||||
one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
one_week_ago = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
six_months_ago = datetime.now(timezone.utc) - timedelta(days=180)
|
||||
one_year_ago = datetime.now(timezone.utc) - timedelta(days=365)
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
document_id="one-day-ago",
|
||||
content="Good match",
|
||||
last_updated=one_day_ago,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="one-year-ago",
|
||||
content="Good match",
|
||||
last_updated=one_year_ago,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="no-last-updated",
|
||||
# Since we test for result ordering in the postconditions, let's
|
||||
# just make this content slightly less of a match with the query
|
||||
# so this test is not flaky from the ordering of the results.
|
||||
content="Still an ok match",
|
||||
last_updated=None,
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for documents updated in the last week.
|
||||
last_week_search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text="Good match",
|
||||
query_vector=_generate_test_vector(0.1),
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=None, tenant_id=None, time_cutoff=one_week_ago
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
last_six_months_search_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text="Good match",
|
||||
query_vector=_generate_test_vector(0.1),
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_state,
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=None, tenant_id=None, time_cutoff=six_months_ago
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
last_week_results = test_client.search(
|
||||
body=last_week_search_body,
|
||||
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
last_six_months_results = test_client.search(
|
||||
body=last_six_months_search_body,
|
||||
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# We expect to only get one-day-ago.
|
||||
assert len(last_week_results) == 1
|
||||
assert last_week_results[0].document_chunk.document_id == "one-day-ago"
|
||||
# We expect to get one-day-ago and no-last-updated since six months >
|
||||
# ASSUMED_DOCUMENT_AGE_DAYS.
|
||||
assert len(last_six_months_results) == 2
|
||||
assert last_six_months_results[0].document_chunk.document_id == "one-day-ago"
|
||||
assert (
|
||||
last_six_months_results[1].document_chunk.document_id == "no-last-updated"
|
||||
)
|
||||
|
||||
@@ -476,8 +476,8 @@ class ChatSessionManager:
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
|
||||
@@ -31,7 +31,7 @@ class ProjectManager:
|
||||
) -> List[UserProjectSnapshot]:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -56,7 +56,7 @@ class ProjectManager:
|
||||
) -> bool:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def second_user(admin_user: DATestUser) -> DATestUser:
|
||||
# Ensure admin exists so this new user is created with BASIC role.
|
||||
try:
|
||||
return UserManager.create(name="second_basic_user")
|
||||
except HTTPError as e:
|
||||
response = e.response
|
||||
if response is None:
|
||||
raise
|
||||
if response.status_code not in (400, 409):
|
||||
raise
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
raise
|
||||
detail = payload.get("detail")
|
||||
if not _is_user_already_exists_detail(detail):
|
||||
raise
|
||||
print("Second basic user already exists; logging in instead.")
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("second_basic_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_user_already_exists_detail(detail: object) -> bool:
|
||||
if isinstance(detail, str):
|
||||
normalized = detail.lower()
|
||||
return (
|
||||
"already exists" in normalized
|
||||
or "register_user_already_exists" in normalized
|
||||
)
|
||||
if isinstance(detail, dict):
|
||||
code = detail.get("code")
|
||||
if isinstance(code, str) and code.lower() == "register_user_already_exists":
|
||||
return True
|
||||
message = detail.get("message")
|
||||
if isinstance(message, str) and "already exists" in message.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_chat_session(
|
||||
chat_session_id: str,
|
||||
user: DATestUser,
|
||||
is_shared: bool | None = None,
|
||||
include_deleted: bool | None = None,
|
||||
) -> requests.Response:
|
||||
params: dict[str, str] = {}
|
||||
if is_shared is not None:
|
||||
params["is_shared"] = str(is_shared).lower()
|
||||
if include_deleted is not None:
|
||||
params["include_deleted"] = str(include_deleted).lower()
|
||||
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
|
||||
params=params,
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def _set_sharing_status(
|
||||
chat_session_id: str, sharing_status: str, user: DATestUser
|
||||
) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}",
|
||||
json={"sharing_status": sharing_status},
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_private_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify private sessions are only accessible by the owner and never via share link."""
|
||||
# Create a private chat session owned by basic_user.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can access the private session normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Share link should be forbidden when the session is private.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users cannot access private sessions directly.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users also cannot access private sessions via share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_public_shared_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify shared sessions are accessible only via share link for non-owners."""
|
||||
# Create a private session, then mark it public.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
response = _set_sharing_status(str(chat_session.id), "public", basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can access normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can also access via share link.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Non-owner cannot access without share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Non-owner can access with share link for public sessions.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_deleted_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify deleted sessions return 404, with include_deleted gated by access checks."""
|
||||
# Create and soft-delete a session.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=chat_session, user_performing_action=basic_user
|
||||
)
|
||||
assert deletion_success is True
|
||||
|
||||
# Deleted sessions are not accessible normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Owner can fetch deleted session only with include_deleted.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("deleted") is True
|
||||
|
||||
# Non-owner should be blocked even with include_deleted.
|
||||
response = _get_chat_session(
|
||||
str(chat_session.id), second_user, include_deleted=True
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaLabelManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
reset: None, admin_user: DATestUser
|
||||
) -> None:
|
||||
persona_label = PersonaLabelManager.create(
|
||||
label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert persona_label.id is not None
|
||||
persona = PersonaManager.create(
|
||||
label_ids=[persona_label.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
updated_description = f"{persona.description}-updated"
|
||||
update_request = PersonaUpsertRequest(
|
||||
name=persona.name,
|
||||
description=updated_description,
|
||||
system_prompt=persona.system_prompt or "",
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
users=[],
|
||||
groups=[],
|
||||
label_ids=None,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=update_request.model_dump(mode="json", exclude_none=False),
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
fetched = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
fetched.raise_for_status()
|
||||
fetched_persona = fetched.json()
|
||||
|
||||
assert fetched_persona["description"] == updated_description
|
||||
fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
|
||||
assert persona_label.id in fetched_label_ids
|
||||
@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
|
||||
provider_id = _activate_exa_provider(admin_user)
|
||||
assert isinstance(provider_id, int)
|
||||
|
||||
search_request = {"queries": ["latest ai research news"], "max_results": 3}
|
||||
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
|
||||
|
||||
lite_response = requests.post(
|
||||
f"{API_SERVER_URL}/web-search/search-lite",
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for Asana connector configuration parsing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_ids,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 123 ", ["123"]),
|
||||
(" 123 , , 456 , ", ["123", "456"]),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_project_ids_normalization(
|
||||
project_ids: str | None, expected: list[str] | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id=" 1153293530468850 ",
|
||||
asana_project_ids=project_ids,
|
||||
asana_team_id=" 1210918501948021 ",
|
||||
)
|
||||
|
||||
assert connector.workspace_id == "1153293530468850"
|
||||
assert connector.project_ids_to_index == expected
|
||||
assert connector.asana_team_id == "1210918501948021"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 1210918501948021 ", "1210918501948021"),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_team_id_normalization(
|
||||
team_id: str | None, expected: str | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id="1153293530468850",
|
||||
asana_project_ids=None,
|
||||
asana_team_id=team_id,
|
||||
)
|
||||
|
||||
assert connector.asana_team_id == expected
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
|
||||
class TestConvertToMetadataValue:
|
||||
"""Tests for the _convert_to_metadata_value helper function."""
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""String values should be returned as-is."""
|
||||
assert _convert_to_metadata_value("hello") == "hello"
|
||||
assert _convert_to_metadata_value("") == ""
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
"""Boolean True should be converted to string 'True'."""
|
||||
assert _convert_to_metadata_value(True) == "True"
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
"""Boolean False should be converted to string 'False'."""
|
||||
assert _convert_to_metadata_value(False) == "False"
|
||||
|
||||
def test_integer_value(self) -> None:
|
||||
"""Integer values should be converted to string."""
|
||||
assert _convert_to_metadata_value(42) == "42"
|
||||
assert _convert_to_metadata_value(0) == "0"
|
||||
assert _convert_to_metadata_value(-100) == "-100"
|
||||
|
||||
def test_float_value(self) -> None:
|
||||
"""Float values should be converted to string."""
|
||||
assert _convert_to_metadata_value(3.14) == "3.14"
|
||||
assert _convert_to_metadata_value(0.0) == "0.0"
|
||||
assert _convert_to_metadata_value(-2.5) == "-2.5"
|
||||
|
||||
def test_list_of_strings(self) -> None:
|
||||
"""List of strings should remain as list of strings."""
|
||||
result = _convert_to_metadata_value(["a", "b", "c"])
|
||||
assert result == ["a", "b", "c"]
|
||||
|
||||
def test_list_of_mixed_types(self) -> None:
|
||||
"""List with mixed types should have all items converted to strings."""
|
||||
result = _convert_to_metadata_value([1, True, 3.14, "text"])
|
||||
assert result == ["1", "True", "3.14", "text"]
|
||||
|
||||
def test_empty_list(self) -> None:
|
||||
"""Empty list should return empty list."""
|
||||
assert _convert_to_metadata_value([]) == []
|
||||
|
||||
|
||||
class TestYieldDocBatches:
|
||||
"""Tests for the _yield_doc_batches method of SalesforceConnector."""
|
||||
|
||||
@pytest.fixture
|
||||
def connector(self) -> SalesforceConnector:
|
||||
"""Create a SalesforceConnector instance with mocked sf_client."""
|
||||
connector = SalesforceConnector(
|
||||
batch_size=10,
|
||||
requested_objects=["Opportunity"],
|
||||
)
|
||||
# Mock the sf_client property
|
||||
mock_sf_client = MagicMock()
|
||||
mock_sf_client.sf_instance = "test.salesforce.com"
|
||||
connector._sf_client = mock_sf_client
|
||||
return connector
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sf_db(self) -> MagicMock:
|
||||
"""Create a mock OnyxSalesforceSQLite object."""
|
||||
return MagicMock()
|
||||
|
||||
def _create_salesforce_object(
|
||||
self,
|
||||
object_id: str,
|
||||
object_type: str,
|
||||
data: dict[str, Any],
|
||||
) -> SalesforceObject:
|
||||
"""Helper to create a SalesforceObject with required fields."""
|
||||
# Ensure required fields are present
|
||||
data.setdefault(ID_FIELD, object_id)
|
||||
data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
|
||||
data.setdefault(NAME_FIELD, f"Test {object_type}")
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_metadata_type_conversion_for_opportunity(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that Opportunity metadata fields are properly type-converted."""
|
||||
parent_id = "006bm000006kyDpAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create a parent object with various data types in the fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Test Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"Account": "Acme Corp", # string - should become "account" metadata
|
||||
"FiscalQuarter": 2, # int - should be converted to "2"
|
||||
"FiscalYear": 2024, # int - should be converted to "2024"
|
||||
"IsClosed": False, # bool - should be converted to "False"
|
||||
"StageName": "Prospecting", # string
|
||||
"Type": "New Business", # string
|
||||
"Amount": 50000.50, # float - should be converted to "50000.50"
|
||||
"CloseDate": "2024-06-30", # string
|
||||
"Probability": 75, # int - should be converted to "75"
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z", # string
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
# Setup mock sf_db
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create a mock document that convert_sf_object_to_doc will return
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Test Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
# Track parent changes
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# Call _yield_doc_batches
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify we got one batch with one document
|
||||
assert len(batches) == 1
|
||||
docs = batches[0]
|
||||
assert len(docs) == 1
|
||||
|
||||
doc = docs[0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify metadata type conversions
|
||||
# All values should be strings (or list of strings)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["account"] == "Acme Corp" # string stays string
|
||||
assert doc.metadata["fiscal_quarter"] == "2" # int -> str
|
||||
assert doc.metadata["fiscal_year"] == "2024" # int -> str
|
||||
assert doc.metadata["is_closed"] == "False" # bool -> str
|
||||
assert doc.metadata["stage_name"] == "Prospecting" # string stays string
|
||||
assert doc.metadata["type"] == "New Business" # string stays string
|
||||
assert (
|
||||
doc.metadata["amount"] == "50000.5"
|
||||
) # float -> str (Python drops trailing zeros)
|
||||
assert doc.metadata["close_date"] == "2024-06-30" # string stays string
|
||||
assert doc.metadata["probability"] == "75" # int -> str
|
||||
assert doc.metadata["name"] == "Test Opportunity" # NAME_FIELD
|
||||
|
||||
# Verify parent was counted
|
||||
assert parents_changed == 1
|
||||
assert type_to_processed[parent_type] == 1
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_missing_optional_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing optional metadata fields are not added."""
|
||||
parent_id = "006bm000006kyDqAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create parent object with only some fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Minimal Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"StageName": "Closed Won",
|
||||
# Notably missing: Amount, Probability, FiscalQuarter, etc.
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Minimal Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only present fields should be in metadata
|
||||
assert "stage_name" in doc.metadata
|
||||
assert doc.metadata["stage_name"] == "Closed Won"
|
||||
assert "name" in doc.metadata
|
||||
assert doc.metadata["name"] == "Minimal Opportunity"
|
||||
|
||||
# Missing fields should not be in metadata
|
||||
assert "amount" not in doc.metadata
|
||||
assert "probability" not in doc.metadata
|
||||
assert "fiscal_quarter" not in doc.metadata
|
||||
assert "fiscal_year" not in doc.metadata
|
||||
assert "is_closed" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_contact_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test metadata conversion for Contact object type."""
|
||||
parent_id = "003bm00000EjHCjAAN"
|
||||
parent_type = "Contact"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "John Doe",
|
||||
MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
|
||||
"Account": "Globex Corp",
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z",
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="John Doe",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify Contact-specific metadata
|
||||
assert doc.metadata["object_type"] == "Contact"
|
||||
assert doc.metadata["account"] == "Globex Corp"
|
||||
assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
|
||||
assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_no_default_attributes_for_unknown_type(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that unknown object types only get object_type metadata."""
|
||||
parent_id = "001bm00000fd9Z3AAI"
|
||||
parent_type = "CustomObject__c"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Custom Record",
|
||||
MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
|
||||
"CustomField__c": "custom value",
|
||||
"NumberField__c": 123,
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Custom Record",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only object_type should be set for unknown types
|
||||
assert doc.metadata["object_type"] == "CustomObject__c"
|
||||
# Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
|
||||
assert "CustomField__c" not in doc.metadata
|
||||
assert "NumberField__c" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_skips_missing_parent_objects(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing parent objects are skipped gracefully."""
|
||||
parent_id = "006bm000006kyDrAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# get_record returns None for missing object
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = None
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Should yield one empty batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 0
|
||||
|
||||
# convert_sf_object_to_doc should not have been called
|
||||
mock_convert.assert_not_called()
|
||||
|
||||
# Parents changed should still be 0
|
||||
assert parents_changed == 0
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_multiple_documents_batching(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that multiple documents are correctly batched."""
|
||||
# Create 3 parent objects
|
||||
parent_ids = [
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
]
|
||||
parent_type = "Opportunity"
|
||||
|
||||
parent_objects = [
|
||||
self._create_salesforce_object(
|
||||
pid,
|
||||
parent_type,
|
||||
{
|
||||
ID_FIELD: pid,
|
||||
NAME_FIELD: f"Opportunity {i}",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"IsClosed": i % 2 == 0, # alternating bool values
|
||||
"Amount": 1000.0 * (i + 1),
|
||||
},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
|
||||
# Setup mock to return all three
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
|
||||
)
|
||||
mock_sf_db.get_record.side_effect = parent_objects
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create mock documents
|
||||
mock_docs = [
|
||||
Document(
|
||||
id=f"SALESFORCE_{pid}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=f"Opportunity {i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
mock_convert.side_effect = mock_docs
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {pid: parent_type for pid in parent_ids}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
# With batch_size=10, all 3 docs should be in one batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 3
|
||||
|
||||
# Verify each document has correct metadata
|
||||
for i, doc in enumerate(batches[0]):
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["is_closed"] == str(i % 2 == 0)
|
||||
assert doc.metadata["amount"] == str(1000.0 * (i + 1))
|
||||
|
||||
assert type_to_processed[parent_type] == 3
|
||||
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.users import delete_user_from_db
|
||||
|
||||
|
||||
def _mock_user(
|
||||
user_id: UUID | None = None, email: str = "test@example.com"
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.oauth_accounts = []
|
||||
return user
|
||||
|
||||
|
||||
def _make_query_chain() -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
return chain
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_nulls_out_document_set_ownership(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Verify DocumentSet.user_id is nulled out (update, not delete)
|
||||
doc_set_chain = query_chains[DocumentSet]
|
||||
doc_set_chain.filter.assert_called()
|
||||
doc_set_chain.filter.return_value.update.assert_called_once_with(
|
||||
{DocumentSet.user_id: None}
|
||||
)
|
||||
|
||||
# Verify Persona.user_id is nulled out (update, not delete)
|
||||
persona_chain = query_chains[Persona]
|
||||
persona_chain.filter.assert_called()
|
||||
persona_chain.filter.return_value.update.assert_called_once_with(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_cleans_up_join_tables(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Join tables should be deleted (not updated)
|
||||
for model in [DocumentSet__User, Persona__User, User__UserGroup, SamlAccount]:
|
||||
chain = query_chains[model]
|
||||
chain.filter.return_value.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_commits_and_removes_invited(
|
||||
_mock_ee: Any, mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user(email="deleted@example.com")
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_called_once_with(user)
|
||||
db_session.commit.assert_called_once()
|
||||
mock_remove_invited.assert_called_once_with("deleted@example.com")
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_deletes_oauth_accounts(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
oauth1 = MagicMock()
|
||||
oauth2 = MagicMock()
|
||||
user.oauth_accounts = [oauth1, oauth2]
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_any_call(oauth1)
|
||||
db_session.delete.assert_any_call(oauth2)
|
||||
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.utils import telemetry as telemetry_utils
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
|
||||
fetch_impl = Mock()
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_not_called()
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
event_telemetry = Mock()
|
||||
fetch_impl = Mock(return_value=event_telemetry)
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_called_once_with(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=telemetry_utils.noop_fallback,
|
||||
)
|
||||
event_telemetry.assert_called_once_with(
|
||||
"user@example.com",
|
||||
MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
{"origin": "web", "tenant_id": "tenant-1"},
|
||||
)
|
||||
@@ -221,6 +221,13 @@ services:
|
||||
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
||||
- ONYX_BOT_MAX_QPM=${ONYX_BOT_MAX_QPM:-}
|
||||
- ONYX_BOT_MAX_WAIT_TIME=${ONYX_BOT_MAX_WAIT_TIME:-}
|
||||
# Discord Bot Configuration (runs via supervisord, requires DISCORD_BOT_TOKEN to be set)
|
||||
# IMPORTANT: Only one Discord bot instance can run per token - do not scale background workers
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# Logging
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
|
||||
|
||||
@@ -63,6 +63,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -82,6 +82,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -129,6 +129,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# PRODUCTION: Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
|
||||
@@ -77,6 +77,13 @@ MINIO_ROOT_PASSWORD=minioadmin
|
||||
## CORS origins for MCP clients (comma-separated list)
|
||||
# MCP_SERVER_CORS_ORIGINS=
|
||||
|
||||
## Discord Bot Configuration
|
||||
## The Discord bot allows users to interact with Onyx from Discord servers
|
||||
## Bot token from Discord Developer Portal (required to enable the bot)
|
||||
# DISCORD_BOT_TOKEN=
|
||||
## Command prefix for bot commands (default: "!")
|
||||
# DISCORD_BOT_INVOKE_CHAR=!
|
||||
|
||||
## Celery Configuration
|
||||
# CELERY_BROKER_POOL_LIMIT=
|
||||
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=
|
||||
|
||||
@@ -582,29 +582,33 @@ else
|
||||
fi
|
||||
|
||||
# Ask for authentication schema
|
||||
echo ""
|
||||
print_info "Which authentication schema would you like to set up?"
|
||||
echo ""
|
||||
echo "1) Basic - Username/password authentication"
|
||||
echo "2) No Auth - Open access (development/testing)"
|
||||
echo ""
|
||||
read -p "Choose an option (1-2) [default 1]: " -r AUTH_CHOICE
|
||||
echo ""
|
||||
# echo ""
|
||||
# print_info "Which authentication schema would you like to set up?"
|
||||
# echo ""
|
||||
# echo "1) Basic - Username/password authentication"
|
||||
# echo "2) No Auth - Open access (development/testing)"
|
||||
# echo ""
|
||||
# read -p "Choose an option (1) [default 1]: " -r AUTH_CHOICE
|
||||
# echo ""
|
||||
|
||||
case "${AUTH_CHOICE:-1}" in
|
||||
1)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Selected: Basic authentication"
|
||||
;;
|
||||
2)
|
||||
AUTH_SCHEMA="disabled"
|
||||
print_info "Selected: No authentication"
|
||||
;;
|
||||
*)
|
||||
AUTH_SCHEMA="basic"
|
||||
print_info "Invalid choice, using basic authentication"
|
||||
;;
|
||||
esac
|
||||
# case "${AUTH_CHOICE:-1}" in
|
||||
# 1)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Selected: Basic authentication"
|
||||
# ;;
|
||||
# # 2)
|
||||
# # AUTH_SCHEMA="disabled"
|
||||
# # print_info "Selected: No authentication"
|
||||
# # ;;
|
||||
# *)
|
||||
# AUTH_SCHEMA="basic"
|
||||
# print_info "Invalid choice, using basic authentication"
|
||||
# ;;
|
||||
# esac
|
||||
|
||||
# TODO (jessica): Uncomment this once no auth users still have an account
|
||||
# Use basic auth by default
|
||||
AUTH_SCHEMA="basic"
|
||||
|
||||
# Create .env file from template
|
||||
print_info "Creating .env file with your selections..."
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.19
|
||||
version: 0.4.20
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
@@ -0,0 +1,98 @@
|
||||
{{- if .Values.discordbot.enabled }}
|
||||
# Discord bot MUST run as a single replica - Discord only allows one client connection per bot token.
|
||||
# Do NOT enable HPA or increase replicas. Message processing is offloaded to scalable API pods via HTTP.
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-discordbot
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
# CRITICAL: Discord bots cannot be horizontally scaled - only one WebSocket connection per token is allowed
|
||||
replicas: 1
|
||||
strategy:
|
||||
type: Recreate # Ensure old pod is terminated before new one starts to avoid duplicate connections
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "onyx.selectorLabels" . | nindent 6 }}
|
||||
{{- if .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml .Values.discordbot.deploymentLabels | nindent 6 }}
|
||||
{{- end }}
|
||||
template:
|
||||
metadata:
|
||||
annotations:
|
||||
checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
|
||||
{{- with .Values.discordbot.podAnnotations }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 8 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.podLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "onyx.serviceAccountName" . }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.podSecurityContext | nindent 8 }}
|
||||
{{- with .Values.discordbot.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: discordbot
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.discordbot.image.repository }}:{{ .Values.discordbot.image.tag | default .Values.global.version }}"
|
||||
imagePullPolicy: {{ .Values.global.pullPolicy }}
|
||||
command: ["python", "onyx/onyxbot/discord/client.py"]
|
||||
resources:
|
||||
{{- toYaml .Values.discordbot.resources | nindent 12 }}
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
env:
|
||||
{{- include "onyx.envSecrets" . | nindent 12}}
|
||||
# Discord bot token - required for bot to connect
|
||||
{{- if .Values.discordbot.botToken }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
value: {{ .Values.discordbot.botToken | quote }}
|
||||
{{- end }}
|
||||
{{- if .Values.discordbot.botTokenSecretName }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ .Values.discordbot.botTokenSecretName }}
|
||||
key: {{ .Values.discordbot.botTokenSecretKey | default "token" }}
|
||||
{{- end }}
|
||||
# Command prefix for bot commands (default: "!")
|
||||
{{- if .Values.discordbot.invokeChar }}
|
||||
- name: DISCORD_BOT_INVOKE_CHAR
|
||||
value: {{ .Values.discordbot.invokeChar | quote }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumes }}
|
||||
volumes:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -701,6 +701,44 @@ celery_worker_user_file_processing:
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
# Discord bot for Onyx
|
||||
# The bot offloads message processing to scalable API pods via HTTP requests.
|
||||
discordbot:
|
||||
enabled: false # Disabled by default - requires bot token configuration
|
||||
# Bot token can be provided directly or via a Kubernetes secret
|
||||
# Option 1: Direct token (not recommended for production)
|
||||
botToken: ""
|
||||
# Option 2: Reference a Kubernetes secret (recommended)
|
||||
botTokenSecretName: "" # Name of the secret containing the bot token
|
||||
botTokenSecretKey: "token" # Key within the secret (default: "token")
|
||||
# Command prefix for bot commands (default: "!")
|
||||
invokeChar: "!"
|
||||
image:
|
||||
repository: onyxdotapp/onyx-backend
|
||||
tag: "" # Overrides the image tag whose default is the chart appVersion.
|
||||
podAnnotations: {}
|
||||
podLabels:
|
||||
scope: onyx-backend
|
||||
app: discord-bot
|
||||
deploymentLabels:
|
||||
app: discord-bot
|
||||
podSecurityContext:
|
||||
{}
|
||||
securityContext:
|
||||
{}
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "512Mi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2000Mi"
|
||||
volumes: []
|
||||
volumeMounts: []
|
||||
nodeSelector: {}
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
slackbot:
|
||||
enabled: true
|
||||
replicaCount: 1
|
||||
@@ -1159,6 +1197,8 @@ configMap:
|
||||
ONYX_BOT_DISPLAY_ERROR_MSGS: ""
|
||||
ONYX_BOT_RESPOND_EVERY_CHANNEL: ""
|
||||
NOTIFY_SLACKBOT_NO_ANSWER: ""
|
||||
DISCORD_BOT_TOKEN: ""
|
||||
DISCORD_BOT_INVOKE_CHAR: ""
|
||||
# Logging
|
||||
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
|
||||
DISABLE_TELEMETRY: ""
|
||||
|
||||
3
desktop/.gitignore
vendored
3
desktop/.gitignore
vendored
@@ -22,3 +22,6 @@ npm-debug.log*
|
||||
# Local env files
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Generated files
|
||||
src-tauri/gen/schemas/acl-manifests.json
|
||||
|
||||
96
desktop/src-tauri/Cargo.lock
generated
96
desktop/src-tauri/Cargo.lock
generated
@@ -706,16 +706,6 @@ dependencies = [
|
||||
"typeid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
@@ -993,16 +983,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gethostname"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
@@ -1122,24 +1102,6 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "global-hotkey"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"keyboard-types",
|
||||
"objc2 0.6.3",
|
||||
"objc2-app-kit 0.3.2",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"thiserror 2.0.17",
|
||||
"windows-sys 0.59.0",
|
||||
"x11rb",
|
||||
"xkeysym",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gobject-sys"
|
||||
version = "0.18.0"
|
||||
@@ -1713,12 +1675,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.1"
|
||||
@@ -2248,7 +2204,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
"tauri-plugin-global-shortcut",
|
||||
"tauri-plugin-shell",
|
||||
"tauri-plugin-window-state",
|
||||
"tokio",
|
||||
@@ -2878,19 +2833,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -3605,21 +3547,6 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-global-shortcut"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
|
||||
dependencies = [
|
||||
"global-hotkey",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-plugin",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-shell"
|
||||
version = "2.3.3"
|
||||
@@ -5021,29 +4948,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
|
||||
dependencies = [
|
||||
"gethostname",
|
||||
"rustix",
|
||||
"x11rb-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb-protocol"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
|
||||
|
||||
[[package]]
|
||||
name = "xkeysym"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
|
||||
@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
|
||||
[dependencies]
|
||||
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.0"
|
||||
tauri-plugin-global-shortcut = "2.0"
|
||||
tauri-plugin-window-state = "2.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -20,7 +20,6 @@ use tauri::Wry;
|
||||
use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Shortcuts Setup
|
||||
// ============================================================================
|
||||
|
||||
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
|
||||
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
|
||||
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
|
||||
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
|
||||
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
|
||||
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
|
||||
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
|
||||
let app_handle = app.clone();
|
||||
|
||||
// Avoid hijacking the system-wide Cmd+R on macOS.
|
||||
#[cfg(target_os = "macos")]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
reload,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
app.global_shortcut().on_shortcuts(
|
||||
shortcuts,
|
||||
move |_app, shortcut, _event| {
|
||||
if shortcut == &new_chat {
|
||||
trigger_new_chat(&app_handle);
|
||||
}
|
||||
|
||||
if let Some(window) = app_handle.get_webview_window("main") {
|
||||
if shortcut == &reload {
|
||||
let _ = window.eval("window.location.reload()");
|
||||
} else if shortcut == &back {
|
||||
let _ = window.eval("window.history.back()");
|
||||
} else if shortcut == &forward {
|
||||
let _ = window.eval("window.history.forward()");
|
||||
} else if shortcut == &open_settings_shortcut {
|
||||
open_settings(&app_handle);
|
||||
}
|
||||
}
|
||||
|
||||
if shortcut == &new_window_shortcut {
|
||||
trigger_new_window(&app_handle);
|
||||
} else if shortcut == &show_app {
|
||||
focus_main_window(&app_handle);
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
|
||||
TRAY_MENU_OPEN_APP_ID,
|
||||
"Open Onyx",
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+Space"),
|
||||
None::<&str>,
|
||||
)?;
|
||||
let open_chat = MenuItem::with_id(
|
||||
app,
|
||||
@@ -666,7 +598,6 @@ fn main() {
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
@@ -698,11 +629,6 @@ fn main() {
|
||||
.setup(move |app| {
|
||||
let app_handle = app.handle();
|
||||
|
||||
// Setup global shortcuts
|
||||
if let Err(e) = setup_shortcuts(&app_handle) {
|
||||
eprintln!("Failed to setup shortcuts: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = setup_app_menu(&app_handle) {
|
||||
eprintln!("Failed to setup menu: {}", e);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@
|
||||
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background-900: #1a1a1a;
|
||||
--background-800: #262626;
|
||||
--text-light-05: rgba(255, 255, 255, 0.95);
|
||||
--text-light-03: rgba(255, 255, 255, 0.6);
|
||||
--white-10: rgba(255, 255, 255, 0.08);
|
||||
--white-15: rgba(255, 255, 255, 0.12);
|
||||
--white-20: rgba(255, 255, 255, 0.15);
|
||||
--white-30: rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
@@ -30,7 +41,11 @@
|
||||
|
||||
body {
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--background-900) 0%,
|
||||
var(--background-800) 100%
|
||||
);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
display: flex;
|
||||
@@ -39,6 +54,9 @@
|
||||
padding: 20px;
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
color 0.3s ease;
|
||||
}
|
||||
|
||||
.titlebar {
|
||||
@@ -69,16 +87,19 @@
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
background: var(--background-800);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
border 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .settings-panel {
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
@@ -93,17 +114,19 @@
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
background: var(--background-900);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
overflow: hidden;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-icon svg {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
color: #000;
|
||||
color: var(--text-light-05);
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
@@ -134,9 +157,10 @@
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
background: var(--background-900);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
@@ -176,7 +200,7 @@
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-800);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
@@ -186,8 +210,8 @@
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-900);
|
||||
box-shadow: 0 0 0 2px var(--white-10);
|
||||
}
|
||||
|
||||
.input-field::placeholder {
|
||||
@@ -231,7 +255,7 @@
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
background-color: var(--white-15);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
@@ -243,14 +267,18 @@
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
background-color: var(--background-800);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.dark .toggle-slider:before {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
background-color: var(--white-30);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
@@ -288,14 +316,15 @@
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
background: var(--white-10);
|
||||
border: 1px solid var(--white-15);
|
||||
border-radius: 4px;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
@@ -372,10 +401,34 @@
|
||||
const errorMessage = document.getElementById("errorMessage");
|
||||
const saveBtn = document.getElementById("saveBtn");
|
||||
|
||||
// Theme detection based on system preferences
|
||||
function applySystemTheme() {
|
||||
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
|
||||
|
||||
function updateTheme(e) {
|
||||
if (e.matches) {
|
||||
document.documentElement.classList.add("dark");
|
||||
document.body.classList.add("dark");
|
||||
} else {
|
||||
document.documentElement.classList.remove("dark");
|
||||
document.body.classList.remove("dark");
|
||||
}
|
||||
}
|
||||
|
||||
// Apply initial theme
|
||||
updateTheme(darkModeQuery);
|
||||
|
||||
// Listen for changes
|
||||
darkModeQuery.addEventListener("change", updateTheme);
|
||||
}
|
||||
|
||||
function showSettings() {
|
||||
document.body.classList.add("show-settings");
|
||||
}
|
||||
|
||||
// Apply system theme immediately
|
||||
applySystemTheme();
|
||||
|
||||
// Initialize the app
|
||||
async function init() {
|
||||
try {
|
||||
|
||||
@@ -113,6 +113,23 @@
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
function updateTitleBarTheme(isDark) {
|
||||
const titleBar = document.getElementById(TITLEBAR_ID);
|
||||
if (!titleBar) return;
|
||||
|
||||
if (isDark) {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
|
||||
} else {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
|
||||
}
|
||||
}
|
||||
|
||||
function buildTitleBar() {
|
||||
const titleBar = document.createElement("div");
|
||||
titleBar.id = TITLEBAR_ID;
|
||||
@@ -134,6 +151,11 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Apply initial styles matching current theme
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
// Apply styles matching Onyx design system with translucent glass effect
|
||||
titleBar.style.cssText = `
|
||||
position: fixed;
|
||||
@@ -156,8 +178,12 @@
|
||||
-webkit-backdrop-filter: blur(18px) saturate(180%);
|
||||
-webkit-app-region: drag;
|
||||
padding: 0 12px;
|
||||
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
|
||||
`;
|
||||
|
||||
// Apply correct theme
|
||||
updateTitleBarTheme(isDark);
|
||||
|
||||
return titleBar;
|
||||
}
|
||||
|
||||
@@ -168,6 +194,11 @@
|
||||
|
||||
const existing = document.getElementById(TITLEBAR_ID);
|
||||
if (existing?.parentElement === document.body) {
|
||||
// Update theme on existing titlebar
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -178,6 +209,14 @@
|
||||
const titleBar = buildTitleBar();
|
||||
document.body.insertBefore(titleBar, document.body.firstChild);
|
||||
injectStyles();
|
||||
|
||||
// Ensure theme is applied immediately after mount
|
||||
setTimeout(() => {
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function syncViewportHeight() {
|
||||
@@ -194,9 +233,66 @@
|
||||
}
|
||||
}
|
||||
|
||||
function observeThemeChanges() {
|
||||
let lastKnownTheme = null;
|
||||
|
||||
function checkAndUpdateTheme() {
|
||||
// Check both html and body for dark class (some apps use body)
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
if (lastKnownTheme !== isDark) {
|
||||
lastKnownTheme = isDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}
|
||||
}
|
||||
|
||||
// Immediate check on setup
|
||||
checkAndUpdateTheme();
|
||||
|
||||
// Watch for theme changes on the HTML element
|
||||
const themeObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
|
||||
themeObserver.observe(document.documentElement, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
|
||||
// Also observe body if it exists
|
||||
if (document.body) {
|
||||
const bodyObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
bodyObserver.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
}
|
||||
|
||||
// Also check periodically in case classList is manipulated directly
|
||||
// or the theme loads asynchronously after page load
|
||||
const intervalId = setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 300);
|
||||
|
||||
// Clean up after 30 seconds once theme should be stable
|
||||
setTimeout(() => {
|
||||
clearInterval(intervalId);
|
||||
// But keep checking every 2 seconds for manual theme changes
|
||||
setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 2000);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
function init() {
|
||||
mountTitleBar();
|
||||
syncViewportHeight();
|
||||
observeThemeChanges();
|
||||
|
||||
window.addEventListener("resize", syncViewportHeight, { passive: true });
|
||||
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
|
||||
passive: true,
|
||||
|
||||
@@ -119,7 +119,7 @@ backend = [
|
||||
"shapely==2.0.6",
|
||||
"stripe==10.12.0",
|
||||
"urllib3==2.6.3",
|
||||
"mistune==0.8.4",
|
||||
"mistune==3.2.0",
|
||||
"sendgrid==6.12.5",
|
||||
"exa_py==1.15.4",
|
||||
"braintrust==0.3.9",
|
||||
@@ -142,7 +142,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.4.0",
|
||||
"onyx-devtools==0.6.2",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
26
uv.lock
generated
26
uv.lock
generated
@@ -3897,11 +3897,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mistune"
|
||||
version = "0.8.4"
|
||||
version = "3.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2d/a4/509f6e7783ddd35482feda27bc7f72e65b5e7dc910eca4ab2164daf9c577/mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", size = 58322, upload-time = "2018-10-11T06:59:27.908Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/09/ec/4b43dae793655b7d8a25f76119624350b4d65eb663459eb9603d7f1f0345/mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4", size = 16220, upload-time = "2018-10-11T06:59:26.044Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4766,7 +4766,7 @@ requires-dist = [
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
|
||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
|
||||
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.25.0" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==0.8.4" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
|
||||
@@ -4775,7 +4775,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.5.9" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.4.0" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.2" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4878,20 +4878,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.4.0"
|
||||
version = "0.6.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/d8/f68d15c12d27d4525d10697ac7e2d67d6122fb59ccab219afb2973bc33ad/onyx_devtools-0.4.0-py3-none-any.whl", hash = "sha256:3eb821bce7ec8651d57e937d4d8483e1c2c4bc51df8cbab2dbcc05e3740ec96c", size = 2870841, upload-time = "2026-01-23T04:44:32.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/04/6376342389494b51fd89e554dfdaf0d3809b8d1473bc9b72abd2d7dba21e/onyx_devtools-0.4.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:144e518abad3031ffef189445a69356fca1da2a4fb40c7b8431550133bfc4eef", size = 2890308, upload-time = "2026-01-23T04:44:37.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/c1/859b32fb3eff7e67179d971ace36313ae64e7fc9a242b45e606138b0041f/onyx_devtools-0.4.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0cc74d561f08a9c894adf8de79855b4fc72eb70e823a75e29db7f625ad366bd7", size = 2696160, upload-time = "2026-01-23T04:44:30.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/1b/f1e3f574e9917779d22e3fcb28f8ac1888c250e7452a523f64a6ab8a1759/onyx_devtools-0.4.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:d69de76a97d7f9ff8c473afffbf544a65265645d726f3d70cc12dbbd7e364222", size = 2602134, upload-time = "2026-01-23T04:44:31.716Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/4a/a5d11640fdc23c9bf0e8617ce13793a587e49a64be2d20badf7e9b045e0a/onyx_devtools-0.4.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:fa84980ce8830e35432831aadc19ff465dbc723605aa80c50e0debc58457b70f", size = 2870864, upload-time = "2026-01-23T04:44:31.5Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/9f/6a7e02fbf47bcaea4d02b0ed92bea6e2c09408be7654fb3b57a1ba9863f2/onyx_devtools-0.4.0-py3-none-win_amd64.whl", hash = "sha256:8451efe3e137157696decf8b60a19fb3f0c52ae9f2d9b7c5bc6e667900e7c61e", size = 2953545, upload-time = "2026-01-23T04:44:38.11Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/42/f7a5b99ade06d215fb99de41181d51a9a984f83afb15afa15ce79ecab635/onyx_devtools-0.4.0-py3-none-win_arm64.whl", hash = "sha256:53a5942c922d7049650e934c43f9c057d046f8d53bc68935ebf7e93baa29afc3", size = 2665984, upload-time = "2026-01-23T04:44:29.399Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/20/d9f6089616044b0fb6e097cbae82122de24f3acd97820be4868d5c28ee3f/onyx_devtools-0.6.2-py3-none-any.whl", hash = "sha256:e48d14695d39d62ec3247a4c76ea56604bc5fb635af84c4ff3e9628bcc67b4fb", size = 3785941, upload-time = "2026-02-25T22:33:43.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/f5/f754a717f6b011050eb52ef09895cfa2f048f567f4aa3d5e0f773657dea4/onyx_devtools-0.6.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:505f9910a04868ab62d99bb483dc37c9f4ad94fa80e6ac0e6a10b86351c31420", size = 3832182, upload-time = "2026-02-25T22:33:43.283Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/35/6e653398c62078e87ebb0d03dc944df6691d92ca427c92867309d2d803b7/onyx_devtools-0.6.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:edec98e3acc0fa22cf9102c2070409ea7bcf99d7ded72bd8cb184ece8171c36a", size = 3576948, upload-time = "2026-02-25T22:33:42.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/97/cff707c5c3d2acd714365b1023f0100676abc99816a29558319e8ef01d5f/onyx_devtools-0.6.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97abab61216866cdccd8c0a7e27af328776083756ce4fb57c4bd723030449e3b", size = 3439359, upload-time = "2026-02-25T22:33:44.684Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/98/3b768d18e5599178834b966b447075626d224e048d6eb264d89d19abacb4/onyx_devtools-0.6.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:681b038ab6f1457409d14b2490782c7a8014fc0f0f1b9cd69bb2b7199f99aef1", size = 3785959, upload-time = "2026-02-25T22:33:44.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/38/9b047f9e61c14ccf22b8f386c7a57da3965f90737453f3a577a97da45cdf/onyx_devtools-0.6.2-py3-none-win_amd64.whl", hash = "sha256:a2063be6be104b50a7538cf0d26c7f7ab9159d53327dd6f3e91db05d793c95f3", size = 3878776, upload-time = "2026-02-25T22:33:45.229Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/0f/742f644bae84f5f8f7b500094a2f58da3ff8027fc739944622577e2e2850/onyx_devtools-0.6.2-py3-none-win_arm64.whl", hash = "sha256:00fb90a49a15c932b5cacf818b1b4918e5b5c574bde243dc1828b57690dd5046", size = 3501112, upload-time = "2026-02-25T22:33:41.512Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowDownDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowDownDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowLeftDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowLeftDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowRightDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowRightDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import type { SVGProps } from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBracketCurly = (props: SVGProps<SVGSVGElement>) => (
|
||||
const SvgBracketCurly = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 15 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
21
web/lib/opal/src/icons/branch.tsx
Normal file
21
web/lib/opal/src/icons/branch.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBranch = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M4.75001 5C5.71651 5 6.50001 4.2165 6.50001 3.25C6.50001 2.2835 5.7165 1.5 4.75 1.5C3.78351 1.5 3.00001 2.2835 3.00001 3.25C3.00001 4.2165 3.78351 5 4.75001 5ZM4.75001 5L4.75001 6.24999M4.75 11C3.7835 11 3 11.7835 3 12.75C3 13.7165 3.7835 14.5 4.75 14.5C5.7165 14.5 6.5 13.7165 6.5 12.75C6.5 11.7835 5.71649 11 4.75 11ZM4.75 11L4.75001 6.24999M10.5 8.74997C10.5 9.71646 11.2835 10.5 12.25 10.5C13.2165 10.5 14 9.71646 14 8.74997C14 7.78347 13.2165 7 12.25 7C11.2835 7 10.5 7.78347 10.5 8.74997ZM10.5 8.74997L7.25001 8.74999C5.8693 8.74999 4.75001 7.6307 4.75001 6.24999"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgBranch;
|
||||
16
web/lib/opal/src/icons/circle.tsx
Normal file
16
web/lib/opal/src/icons/circle.tsx
Normal file
@@ -0,0 +1,16 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgCircle = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<circle cx="8" cy="8" r="6" strokeWidth={1.5} />
|
||||
</svg>
|
||||
);
|
||||
export default SvgCircle;
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgClaude = (props: IconProps) => {
|
||||
const SvgClaude = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgClipboard = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgClipboard = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgCornerRightUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgCornerRightUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
21
web/lib/opal/src/icons/download.tsx
Normal file
21
web/lib/opal/src/icons/download.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgDownload = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M14 10V12.6667C14 13.3929 13.3929 14 12.6667 14H3.33333C2.60711 14 2 13.3929 2 12.6667V10M4.66667 6.66667L8 10M8 10L11.3333 6.66667M8 10L8 2"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgDownload;
|
||||
@@ -24,6 +24,7 @@ export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBooksLineSmall } from "@opal/icons/books-line-small";
|
||||
export { default as SvgBooksStackSmall } from "@opal/icons/books-stack-small";
|
||||
export { default as SvgBracketCurly } from "@opal/icons/bracket-curly";
|
||||
export { default as SvgBranch } from "@opal/icons/branch";
|
||||
export { default as SvgBubbleText } from "@opal/icons/bubble-text";
|
||||
export { default as SvgCalendar } from "@opal/icons/calendar";
|
||||
export { default as SvgCheck } from "@opal/icons/check";
|
||||
@@ -36,6 +37,7 @@ export { default as SvgChevronLeft } from "@opal/icons/chevron-left";
|
||||
export { default as SvgChevronRight } from "@opal/icons/chevron-right";
|
||||
export { default as SvgChevronUp } from "@opal/icons/chevron-up";
|
||||
export { default as SvgChevronUpSmall } from "@opal/icons/chevron-up-small";
|
||||
export { default as SvgCircle } from "@opal/icons/circle";
|
||||
export { default as SvgClaude } from "@opal/icons/claude";
|
||||
export { default as SvgClipboard } from "@opal/icons/clipboard";
|
||||
export { default as SvgClock } from "@opal/icons/clock";
|
||||
@@ -46,6 +48,7 @@ export { default as SvgCopy } from "@opal/icons/copy";
|
||||
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
|
||||
export { default as SvgCpu } from "@opal/icons/cpu";
|
||||
export { default as SvgDevKit } from "@opal/icons/dev-kit";
|
||||
export { default as SvgDownload } from "@opal/icons/download";
|
||||
export { default as SvgDiscordMono } from "@opal/icons/DiscordMono";
|
||||
export { default as SvgDownloadCloud } from "@opal/icons/download-cloud";
|
||||
export { default as SvgEdit } from "@opal/icons/edit";
|
||||
@@ -135,6 +138,7 @@ export { default as SvgStep3End } from "@opal/icons/step3-end";
|
||||
export { default as SvgStop } from "@opal/icons/stop";
|
||||
export { default as SvgStopCircle } from "@opal/icons/stop-circle";
|
||||
export { default as SvgSun } from "@opal/icons/sun";
|
||||
export { default as SvgTerminal } from "@opal/icons/terminal";
|
||||
export { default as SvgTerminalSmall } from "@opal/icons/terminal-small";
|
||||
export { default as SvgTextLinesSmall } from "@opal/icons/text-lines-small";
|
||||
export { default as SvgThumbsDown } from "@opal/icons/thumbs-down";
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const OnyxLogo = ({
|
||||
width = 24,
|
||||
height = 24,
|
||||
className,
|
||||
...props
|
||||
}: IconProps) => (
|
||||
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={width}
|
||||
height={height}
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 56 56"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={className}
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
@@ -23,4 +17,4 @@ const OnyxLogo = ({
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default OnyxLogo;
|
||||
export default SvgOnyxLogo;
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgOpenAI = (props: IconProps) => {
|
||||
const SvgOpenAI = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
22
web/lib/opal/src/icons/terminal.tsx
Normal file
22
web/lib/opal/src/icons/terminal.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgTerminal = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2.66667 11.3333L6.66667 7.33331L2.66667 3.33331M8.00001 12.6666H13.3333"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgTerminal;
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgTwoLineSmall = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgTwoLineSmall = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgUserPlus = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgWallet = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgWallet = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -300,13 +300,7 @@ export default function Page() {
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Default Assistant"
|
||||
icon={
|
||||
<SvgOnyxLogo
|
||||
width={32}
|
||||
height={32}
|
||||
className="my-auto stroke-text-04"
|
||||
/>
|
||||
}
|
||||
icon={<SvgOnyxLogo size={32} className="my-auto stroke-text-04" />}
|
||||
/>
|
||||
<DefaultAssistantConfig />
|
||||
</>
|
||||
|
||||
@@ -31,6 +31,7 @@ import { fetchBedrockModels } from "../utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
|
||||
@@ -135,7 +136,7 @@ function BedrockFormInternals({
|
||||
!formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
<Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
|
||||
<SelectorFormField
|
||||
@@ -176,7 +177,7 @@ function BedrockFormInternals({
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4 w-full">
|
||||
<TextFormField
|
||||
name={FIELD_AWS_ACCESS_KEY_ID}
|
||||
label="AWS Access Key ID"
|
||||
@@ -191,7 +192,7 @@ function BedrockFormInternals({
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4 w-full">
|
||||
<PasswordInputTypeInField
|
||||
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
|
||||
label="AWS Bedrock Long-term API Key"
|
||||
|
||||
@@ -131,10 +131,15 @@ export function CustomForm({
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(
|
||||
values.custom_config_list
|
||||
),
|
||||
|
||||
@@ -39,6 +39,8 @@ interface OllamaFormValues extends BaseLLMFormValues {
|
||||
interface OllamaFormContentProps {
|
||||
formikProps: FormikProps<OllamaFormValues>;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: () => void;
|
||||
@@ -49,15 +51,14 @@ interface OllamaFormContentProps {
|
||||
function OllamaFormContent({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
testError,
|
||||
mutate,
|
||||
onClose,
|
||||
isFormValid,
|
||||
}: OllamaFormContentProps) {
|
||||
const [availableModels, setAvailableModels] = useState<ModelConfiguration[]>(
|
||||
existingLlmProvider?.model_configurations || []
|
||||
);
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -70,16 +71,25 @@ function OllamaFormContent({
|
||||
.then((data) => {
|
||||
if (data.error) {
|
||||
console.error("Error fetching models:", data.error);
|
||||
setAvailableModels([]);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setAvailableModels(data.models);
|
||||
setFetchedModels(data.models);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingModels(false);
|
||||
});
|
||||
}
|
||||
}, [formikProps.values.api_base]);
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
existingLlmProvider?.name,
|
||||
setFetchedModels,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
@@ -99,7 +109,7 @@ function OllamaFormContent({
|
||||
/>
|
||||
|
||||
<DisplayModels
|
||||
modelConfigurations={availableModels}
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
|
||||
isLoading={isLoadingModels}
|
||||
@@ -125,6 +135,8 @@ export function OllamaForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Ollama"
|
||||
@@ -189,7 +201,10 @@ export function OllamaForm({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
@@ -205,6 +220,8 @@ export function OllamaForm({
|
||||
<OllamaFormContent
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
testError={testError}
|
||||
mutate={mutate}
|
||||
|
||||
@@ -68,11 +68,7 @@ export const WebProviderSetupModal = memo(
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo
|
||||
width={24}
|
||||
height={24}
|
||||
className="text-text-04 shrink-0"
|
||||
/>
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1168,7 +1168,7 @@ export default function Page() {
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={16} height={16} />
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
@@ -1381,7 +1381,7 @@ export default function Page() {
|
||||
} logo`,
|
||||
fallback:
|
||||
selectedContentProviderType === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={24} height={24} className="text-text-05" />
|
||||
<SvgOnyxLogo size={24} className="text-text-05" />
|
||||
) : undefined,
|
||||
size: 24,
|
||||
containerSize: 28,
|
||||
|
||||
@@ -455,9 +455,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
/>
|
||||
)}
|
||||
|
||||
<BackButton
|
||||
behaviorOverride={() => router.push("/admin/indexing/status")}
|
||||
/>
|
||||
<BackButton />
|
||||
<div
|
||||
className="flex
|
||||
items-center
|
||||
|
||||
@@ -25,7 +25,6 @@ import { useDocumentSets } from "@/lib/hooks/useDocumentSets";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { ChatPopup } from "@/app/chat/components/ChatPopup";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_ID } from "@/app/chat/components/tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import NoAssistantModal from "@/components/modals/NoAssistantModal";
|
||||
import TextView from "@/components/chat/TextView";
|
||||
@@ -382,9 +381,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
if (liveAssistant) {
|
||||
return liveAssistant.tools.some(
|
||||
(tool) => tool.in_code_tool_id === SEARCH_TOOL_ID
|
||||
);
|
||||
return personaIncludesRetrieval(liveAssistant);
|
||||
}
|
||||
return false;
|
||||
}, [liveAssistant]);
|
||||
|
||||
@@ -643,6 +643,7 @@ export function useChatController({
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
let files = projectFilesToFileDescriptors(currentMessageFiles);
|
||||
let packets: Packet[] = [];
|
||||
let packetsVersion = 0;
|
||||
|
||||
let newUserMessageId: number | null = null;
|
||||
let newAssistantMessageId: number | null = null;
|
||||
@@ -729,7 +730,6 @@ export function useChatController({
|
||||
if (!packet) {
|
||||
continue;
|
||||
}
|
||||
console.debug("Packet:", JSON.stringify(packet));
|
||||
|
||||
// We've processed initial packets and are starting to stream content.
|
||||
// Transition from 'loading' to 'streaming'.
|
||||
@@ -800,8 +800,8 @@ export function useChatController({
|
||||
updateCanContinue(true, frozenSessionId);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "obj")) {
|
||||
console.debug("Object packet:", JSON.stringify(packet));
|
||||
packets.push(packet as Packet);
|
||||
packetsVersion++;
|
||||
|
||||
// Check if the packet contains document information
|
||||
const packetObj = (packet as Packet).obj;
|
||||
@@ -859,6 +859,8 @@ export function useChatController({
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetsVersion: packetsVersion,
|
||||
packetCount: packets.length,
|
||||
},
|
||||
],
|
||||
// Pass the latest map state
|
||||
@@ -885,6 +887,7 @@ export function useChatController({
|
||||
toolCall: null,
|
||||
parentNodeId: parentMessage?.nodeId || SYSTEM_NODE_ID,
|
||||
packets: [],
|
||||
packetCount: 0,
|
||||
},
|
||||
{
|
||||
nodeId: initialAssistantNode.nodeId,
|
||||
@@ -894,6 +897,7 @@ export function useChatController({
|
||||
toolCall: null,
|
||||
parentNodeId: initialUserNode.nodeId,
|
||||
packets: [],
|
||||
packetCount: 0,
|
||||
stackTrace: stackTrace,
|
||||
errorCode: errorCode,
|
||||
isRetryable: isRetryable,
|
||||
|
||||
@@ -139,6 +139,9 @@ export interface Message {
|
||||
|
||||
// new gen
|
||||
packets: Packet[];
|
||||
// Version counter for efficient memo comparison (increments with each packet)
|
||||
packetsVersion?: number;
|
||||
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
|
||||
|
||||
// cached values for easy access
|
||||
documents?: OnyxDocument[] | null;
|
||||
|
||||
@@ -74,6 +74,8 @@ export type RegenerationFactory = (regenerationRequest: {
|
||||
|
||||
export interface AIMessageProps {
|
||||
rawPackets: Packet[];
|
||||
// Version counter for efficient memo comparison (avoids array copying)
|
||||
packetsVersion?: number;
|
||||
chatState: FullChatState;
|
||||
nodeId: number;
|
||||
messageId?: number;
|
||||
@@ -88,8 +90,6 @@ export interface AIMessageProps {
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
// - `rawPackets.length` assumes packets are append-only. Could compare the last
|
||||
// packet or use a shallow comparison if packets can be modified in place.
|
||||
// - `chatState.docs`, `chatState.citations`, and `otherMessagesCanSwitchTo` use
|
||||
// reference equality. Shallow array/object comparison would be more robust if
|
||||
// these are recreated with the same values.
|
||||
@@ -98,7 +98,7 @@ function arePropsEqual(prev: AIMessageProps, next: AIMessageProps): boolean {
|
||||
prev.nodeId === next.nodeId &&
|
||||
prev.messageId === next.messageId &&
|
||||
prev.currentFeedback === next.currentFeedback &&
|
||||
prev.rawPackets.length === next.rawPackets.length &&
|
||||
prev.packetsVersion === next.packetsVersion &&
|
||||
prev.chatState.assistant?.id === next.chatState.assistant?.id &&
|
||||
prev.chatState.docs === next.chatState.docs &&
|
||||
prev.chatState.citations === next.chatState.citations &&
|
||||
|
||||
@@ -11,6 +11,7 @@ import { CitationMap } from "../../interfaces";
|
||||
export enum RenderType {
|
||||
HIGHLIGHT = "highlight",
|
||||
FULL = "full",
|
||||
COMPACT = "compact",
|
||||
}
|
||||
|
||||
export interface FullChatState {
|
||||
@@ -35,6 +36,9 @@ export interface RendererResult {
|
||||
// used for things that should just show text w/o an icon or header
|
||||
// e.g. ReasoningRenderer
|
||||
expandedText?: JSX.Element;
|
||||
|
||||
// Whether this renderer supports compact mode (collapse button shown only when true)
|
||||
supportsCompact?: boolean;
|
||||
}
|
||||
|
||||
export type MessageRenderer<
|
||||
@@ -48,5 +52,7 @@ export type MessageRenderer<
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
/** Whether this is the last step in the timeline (for connector line decisions) */
|
||||
isLastStep?: boolean;
|
||||
children: (result: RendererResult) => JSX.Element;
|
||||
}>;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user