Compare commits

...

42 Commits

Author SHA1 Message Date
Justin Tahara
ed46504a1a fix(gong): Respecting Retry Timeout Header (#8866) 2026-02-27 14:22:34 -08:00
Nikolas Garza
7a24b34516 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 17:27:31 -08:00
dependabot[bot]
7a7ffa9051 chore(deps): Bump mistune from 0.8.4 to 3.1.4 in /backend (#6407)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 17:27:31 -08:00
Jamison Lahman
3053ab518c chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 16:26:55 -08:00
justin-tahara
be38d3500f Fixing mypy 2026-02-09 15:48:53 -08:00
Justin Tahara
753a3bc093 fix(posthog): Chat metrics for Cloud (#8278) 2026-02-09 15:48:53 -08:00
Raunak Bhagat
2ba8fafe78 fix: Add explicit sizings to icons (#8018) 2026-02-06 18:15:47 -08:00
Raunak Bhagat
b77b580ebd Cherry-pick card fix 2026-02-06 18:15:40 -08:00
Nikolas Garza
3eee98b932 fix: make it more clear how to add channels to fed slack config form (#8227) 2026-02-06 16:35:46 -08:00
Nikolas Garza
a97eb02fef fix(db): null out document set and persona ownership on user deletion (#8219) 2026-02-06 16:35:46 -08:00
Justin Tahara
c5061495a2 fix(ui): Inconsistent LLM Provider Logo (#8220) 2026-02-06 13:56:57 -08:00
Justin Tahara
c20b0789ae fix(ui): Additional LLM Config update (#8174) 2026-02-06 13:56:49 -08:00
Justin Tahara
d99848717b fix(ui): Ollama Model Selection (#8091) 2026-02-06 13:53:52 -08:00
Evan Lohn
aaca55c415 fix(salesforce): cleanup logic (#8175) 2026-02-06 13:52:46 -08:00
Justin Tahara
9d7ffd1e4a fix(ui): Updating Dropdown Modal component (#8033) 2026-02-06 11:39:48 -08:00
Justin Tahara
a249161827 chore(chat): Cleaning Error Codes + Tests (#8186) 2026-02-06 11:39:36 -08:00
Justin Tahara
e126346a91 fix(agents): Removing Label Dependency (#8189) 2026-02-06 11:03:16 -08:00
Justin Tahara
a96682fa73 fix(ui): Agent Saving with other people files (#8095) 2026-02-02 10:30:46 -08:00
Justin Tahara
3920371d56 feat(desktop): Ensure that UI reflects Light/Dark Toggle (#7684) 2026-02-02 10:30:36 -08:00
Wenxi Onyx
e5a257345c 2nd dummy commit (noop README change) to fix beta tag on docker 2026-01-31 11:17:12 -08:00
Wenxi Onyx
a49df511e2 dummy commit (noop README change) to fix beta tag on docker 2026-01-31 11:09:41 -08:00
Justin Tahara
d5d2a8a1a6 fix(desktop): Remove Global Shortcuts (#7914) 2026-01-30 13:46:26 -08:00
Justin Tahara
b2f46b264c fix(asana): Workspace Team ID mismatch (#7674) 2026-01-30 13:19:07 -08:00
Jamison Lahman
c6ad363fbd chore(mypy): fix mypy cache issues switching between HEAD and release (#7732) 2026-01-27 15:52:53 -08:00
Jamison Lahman
e313119f9a fix(citations): enable citation sidebar w/ web_search-only assistants (#7888) 2026-01-27 14:50:00 -08:00
Wenxi
3a2a542a03 fix: connector details back button should nav back (#7869) 2026-01-27 14:35:15 -08:00
Yuhong Sun
413aeba4a1 fix: Project Creation (#7851) 2026-01-27 14:34:59 -08:00
Wenxi
46028aa2bb fix: user count check (#7811) 2026-01-27 14:34:29 -08:00
Justin Tahara
454943c4a6 fix(llm): Hide private models from Agent Creation (#7873) 2026-01-27 14:33:40 -08:00
Justin Tahara
87946266de fix(redis): Adding more TTLs (#7886) 2026-01-27 14:32:14 -08:00
Jamison Lahman
144030c5ca chore(vscode): add non-clean seeded db restore (#7795) 2026-01-26 08:55:19 -08:00
SubashMohan
a557d76041 feat(ui): add new icons and enhance FadeDiv, Modal, Tabs, ExpandableTextDisplay (#7563)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 10:26:09 +00:00
SubashMohan
605e808158 fix(layout): adjust footer margin and prevent page refresh on chatsession drop (#7759) 2026-01-26 04:45:40 +00:00
roshan
8fec88c90d chore(deployment): remove no auth option from setup script (#7784) 2026-01-26 04:42:45 +00:00
Yuhong Sun
e54969a693 fix: LiteLLM Azure models don't stream (#7761) 2026-01-25 07:46:51 +00:00
Raunak Bhagat
1da2b2f28f fix: Some new fixes that were discovered by AI reviewers during 2.9-hotfixing (#7757)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 04:44:30 +00:00
Nikolas Garza
eb7b91e08e fix(tests): use crawler-friendly search query in Exa integration test (#7746) 2026-01-24 20:58:02 +00:00
Yuhong Sun
3339000968 fix: Spacing issue on Feedback (#7747) 2026-01-24 12:59:00 -08:00
Nikolas Garza
d9db849e94 fix(chat): prevent streaming text from appearing in bursts after citations (#7745) 2026-01-24 11:48:34 -08:00
Yuhong Sun
046408359c fix: Azure OpenAI Tool Calls (#7727) 2026-01-24 01:47:03 +00:00
acaprau
4b8cca190f feat(opensearch): Implement complete retrieval filtering (#7691) 2026-01-23 23:27:42 +00:00
Justin Tahara
52a312a63b feat: onyx discord bot - supervisord and kube deployment (#7706) 2026-01-23 20:55:06 +00:00
164 changed files with 7880 additions and 1468 deletions

View File

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

39
.vscode/launch.json vendored
View File

@@ -149,6 +149,24 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",
@@ -587,6 +605,27 @@
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",

View File

@@ -97,10 +97,14 @@ def get_access_for_documents(
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set.
"""Returns a list of ACL entries that the user has access to.
This is meant to be used downstream to filter out documents that the user
does not have access to. The user should have access to a document if at
least one entry in the document's ACL matches one entry in the returned set.
NOTE: These strings must be formatted in the same way as the output of
DocumentAccess::to_acl.
"""
if user:
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}

View File

@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
"""Converts the access state to a set of formatted ACL strings.
NOTE: When querying for documents, the supplied ACL filter strings must
be formatted in the same way as this function.
"""
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:

View File

@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_TTL = FENCE_TTL
logger = setup_logger()
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
# Create the Celery task
celery_app.send_task(

View File

@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
# Track user message in PostHog for analytics
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(
distinct_id=user.email if user else tenant_id,
event="user_message_sent",
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=(
user.email
if user and not getattr(user, "is_anonymous", False)
else tenant_id
),
event=MilestoneRecordType.USER_MESSAGE_SENT,
properties={
"origin": new_msg_req.origin.value,
"has_files": len(new_msg_req.file_descriptors) > 0,
"has_project": chat_session.project_id is not None,
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
"deep_research": new_msg_req.deep_research,
"tenant_id": tenant_id,
},
)

View File

@@ -341,6 +341,7 @@ class MilestoneRecordType(str, Enum):
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
USER_MESSAGE_SENT = "user_message_sent"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"

View File

@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.workspace_id = asana_workspace_id.strip()
if asana_project_ids:
project_ids = [
project_id.strip()
for project_id in asana_project_ids.split(",")
if project_id.strip()
]
self.project_ids_to_index = project_ids or None
else:
self.project_ids_to_index = None
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(

View File

@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=5,
total=10,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
points to a list of values, each value generates a unique pair.
NOTE: Whatever formatting strategy is used here to generate a key-value
string must be replicated when constructing query filters.
Args:
metadata: The metadata dict to convert where values can be either a
string or a list of strings.

View File

@@ -6,6 +6,7 @@ import sys
import tempfile
import time
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from typing import Any
from typing import cast
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def _convert_to_metadata_value(value: Any) -> str | list[str]:
"""Convert a Salesforce field value to a valid metadata value.
Document metadata expects str | list[str], but Salesforce returns
various types (bool, float, int, etc.). This function ensures all
values are properly converted to strings.
"""
if isinstance(value, list):
return [str(item) for item in value]
return str(value)
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
# # gc.collect()
# return all_types
def _yield_doc_batches(
self,
sf_db: OnyxSalesforceSQLite,
type_to_processed: dict[str, int],
changed_ids_to_type: dict[str, str],
parent_types: set[str],
increment_parents_changed: Callable[[], None],
) -> GenerateDocumentsOutput:
""" """
docs_to_yield: list[Document] = []
docs_to_yield_bytes = 0
last_log_time = 0.0
for (
parent_type,
parent_id,
examined_ids,
) in sf_db.get_changed_parent_ids_by_type(
changed_ids=list(changed_ids_to_type.keys()),
parent_types=parent_types,
):
now = time.monotonic()
processed = examined_ids - 1
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
logger.info(
f"Processing stats: {type_to_processed} "
f"file_size={sf_db.file_size} "
f"processed={processed} "
f"remaining={len(changed_ids_to_type) - processed}"
)
last_log_time = now
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
parent_object = sf_db.get_record(parent_id, parent_type)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
# use the db to create a document we can yield
doc = convert_sf_object_to_doc(
sf_db,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
parent_object.data[sf_attribute]
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
increment_parents_changed()
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
def _full_sync(
self,
temp_dir: str,
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
docs_to_yield: list[Document] = []
changed_ids_to_type: dict[str, str] = {}
parents_changed = 0
examined_ids = 0
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
f"records={num_records}"
)
# yield an empty list to keep the connector alive
yield docs_to_yield
new_ids = sf_db.update_from_csv(
object_type=object_type,
csv_download_path=csv_path,
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
)
# Step 3 - extract and index docs
docs_to_yield_bytes = 0
last_log_time = 0.0
for (
parent_type,
parent_id,
examined_ids,
) in sf_db.get_changed_parent_ids_by_type(
changed_ids=list(changed_ids_to_type.keys()),
parent_types=ctx.parent_types,
):
now = time.monotonic()
processed = examined_ids - 1
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
logger.info(
f"Processing stats: {type_to_processed} "
f"file_size={sf_db.file_size} "
f"processed={processed} "
f"remaining={len(changed_ids_to_type) - processed}"
)
last_log_time = now
type_to_processed[parent_type] = (
type_to_processed.get(parent_type, 0) + 1
)
parent_object = sf_db.get_record(parent_id, parent_type)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
# use the db to create a document we can yield
doc = convert_sf_object_to_doc(
sf_db,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = parent_object.data[
sf_attribute
]
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
def increment_parents_changed() -> None:
nonlocal parents_changed
parents_changed += 1
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
yield from self._yield_doc_batches(
sf_db,
type_to_processed,
changed_ids_to_type,
ctx.parent_types,
increment_parents_changed,
)
except Exception:
logger.exception("Unexpected exception")
raise
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
if sf_attribute in record:
doc.metadata[canonical_attribute] = record[sf_attribute]
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
record[sf_attribute]
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
return return_context
def load_from_state(self) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._full_sync(temp_dir)
# nuke the db since we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
os.remove(sqlite_db_path)
return self._full_sync(BASE_DATA_PATH)
# Always use a temp directory for SQLite - the database is rebuilt
# from scratch each time via CSV downloads, so there's no caching benefit
# from persisting it. Using temp dirs also avoids collisions between
# multiple CC pairs and eliminates stale WAL/SHM file issues.
# TODO(evan): make this thing checkpointed and persist/load db from filestore
with tempfile.TemporaryDirectory() as temp_dir:
yield from self._full_sync(temp_dir)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
"""Poll source will synchronize updated parent objects one by one."""
if start == 0:
# nuke the db if we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
)
os.remove(sqlite_db_path)
return self._delta_sync(BASE_DATA_PATH, start, end)
# Always use a temp directory - see comment in load_from_state()
with tempfile.TemporaryDirectory() as temp_dir:
return self._delta_sync(temp_dir, start, end)
yield from self._delta_sync(temp_dir, start, end)
def retrieve_all_slim_docs_perm_sync(
self,

View File

@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
logger = setup_logger()
SQLITE_DISK_IO_ERROR = "disk I/O error"
class OnyxSalesforceSQLite:
"""Notes on context management using 'with self.conn':
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
def apply_schema(self) -> None:
"""Initialize the SQLite database with required tables if they don't exist.
Non-destructive operation.
Non-destructive operation. If a disk I/O error is encountered (often due
to stale WAL/SHM files from a previous crash), this method will attempt
to recover by removing the corrupted files and recreating the database.
"""
try:
self._apply_schema_impl()
except sqlite3.OperationalError as e:
if SQLITE_DISK_IO_ERROR not in str(e):
raise
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
self._recover_from_corruption()
self._apply_schema_impl()
def _recover_from_corruption(self) -> None:
"""Recover from SQLite corruption by removing all database files and reconnecting."""
logger.info(f"Removing corrupted SQLite files: {self.filename}")
# Close existing connection
self.close()
# Remove all SQLite files (main db, WAL, SHM)
remove_sqlite_db_files(self.filename)
# Reconnect - this will create a fresh database
self.connect()
logger.info("SQLite recovery complete, fresh database created")
def _apply_schema_impl(self) -> None:
"""Internal implementation of apply_schema."""
if self._conn is None:
raise RuntimeError("Database connection is closed")

View File

@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
return os.path.join(directory, "salesforce_db.sqlite")
def remove_sqlite_db_files(db_path: str) -> None:
"""Remove SQLite database and all associated files (WAL, SHM).
SQLite in WAL mode creates additional files:
- .sqlite-wal: Write-ahead log
- .sqlite-shm: Shared memory file
If these files become stale (e.g., after a crash), they can cause
'disk I/O error' when trying to open the database. This function
ensures all related files are removed.
"""
files_to_remove = [
db_path,
f"{db_path}-wal",
f"{db_path}-shm",
]
for file_path in files_to_remove:
if os.path.exists(file_path):
os.remove(file_path)
# NOTE: only used with shelves, deprecated at this point
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)

View File

@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
class IndexFilters(BaseFilters, UserFileFilters):
# NOTE: These strings must be formatted in the same way as the output of
# DocumentAccess::to_acl.
access_control_list: list[str] | None
tenant_id: str | None = None

View File

@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
"Persona",
secondary=Persona__PersonaLabel.__table__,
back_populates="labels",
cascade="all, delete-orphan",
single_parent=True,
)

View File

@@ -917,7 +917,9 @@ def upsert_persona(
existing_persona.icon_name = icon_name
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None

View File

@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
@@ -327,6 +329,15 @@ def delete_user_from_db(
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
# Null out ownership on document sets and personas so they're
# preserved for other users instead of being cascade-deleted
db_session.query(DocumentSet).filter(
DocumentSet.user_id == user_to_delete.id
).update({DocumentSet.user_id: None})
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
{Persona.user_id: None}
)
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()

View File

@@ -28,8 +28,8 @@ of "minimum value clipping".
## On time decay and boosting
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
Same logic applies to additive boosting.
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
being fetched and returned to the user. But there are other issues of including these:
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
contents. If there are lots of updates, this may miss
contents. If there are lots of updates, this may miss.
- There is not a good way to normalize this field, the best is to clip it on the bottom.
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing

View File

@@ -3,7 +3,9 @@ from typing import Any
import httpx
from onyx.access.models import DocumentAccess
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
@@ -68,6 +70,18 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger(__name__)
def generate_opensearch_filtered_access_control_list(
access: DocumentAccess,
) -> list[str]:
"""Generates an access control list with PUBLIC_DOC_PAT removed.
In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
"""
access_control_list = access.to_acl()
access_control_list.discard(PUBLIC_DOC_PAT)
return list(access_control_list)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
score: float | None,
@@ -152,10 +166,9 @@ def _convert_onyx_chunk_to_opensearch_document(
metadata_suffix=chunk.metadata_suffix_keyword,
last_updated=chunk.source_document.doc_updated_at,
public=chunk.access.is_public,
# TODO(andrei): When going over ACL look very carefully at
# access_control_list. Notice DocumentAccess::to_acl prepends every
# string with a type.
access_control_list=list(chunk.access.to_acl()),
access_control_list=generate_opensearch_filtered_access_control_list(
chunk.access
),
global_boost=chunk.boost,
semantic_identifier=chunk.source_document.semantic_identifier,
image_file_id=chunk.image_file_id,
@@ -578,8 +591,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
# here so we don't have to think about passing in the
# appropriate types into this dict.
if update_request.access is not None:
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
update_request.access.to_acl()
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
generate_opensearch_filtered_access_control_list(
update_request.access
)
)
if update_request.document_sets is not None:
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
@@ -625,13 +640,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
def id_based_retrieval(
self,
chunk_requests: list[DocumentSectionRequest],
# TODO(andrei): When going over ACL look very carefully at
# access_control_list. Notice DocumentAccess::to_acl prepends every
# string with a type.
filters: IndexFilters,
# TODO(andrei): Remove this from the new interface at some point; we
# should not be exposing this.
batch_retrieval: bool = False,
# TODO(andrei): Add a param for whether to retrieve hidden docs.
) -> list[InferenceChunk]:
"""
TODO(andrei): Consider implementing this method to retrieve on document
@@ -646,6 +659,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
query_body = DocumentQuery.get_from_document_id_query(
document_id=chunk_request.document_id,
tenant_state=self._tenant_state,
index_filters=filters,
include_hidden=False,
max_chunk_size=chunk_request.max_chunk_size,
min_chunk_index=chunk_request.min_chunk_ind,
max_chunk_index=chunk_request.max_chunk_ind,
@@ -672,9 +687,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
query_embedding: Embedding,
final_keywords: list[str] | None,
query_type: QueryType,
# TODO(andrei): When going over ACL look very carefully at
# access_control_list. Notice DocumentAccess::to_acl prepends every
# string with a type.
filters: IndexFilters,
num_to_retrieve: int,
offset: int = 0,
@@ -688,6 +700,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
num_candidates=1000, # TODO(andrei): Magic number.
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
body=query_body,

View File

@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
return serialized_exclude_none
@field_serializer("last_updated", mode="wrap")
def serialize_datetime_fields_to_epoch_millis(
def serialize_datetime_fields_to_epoch_seconds(
self, value: datetime | None, handler: SerializerFunctionWrapHandler
) -> int | None:
"""
Serializes datetime fields to milliseconds since the Unix epoch.
Serializes datetime fields to seconds since the Unix epoch.
If there is no datetime, returns None.
"""
if value is None:
return None
value = set_or_convert_timezone_to_utc(value)
# timestamp returns a float in seconds so convert to millis.
return int(value.timestamp() * 1000)
return int(value.timestamp())
@field_validator("last_updated", mode="before")
@classmethod
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
"""Parses milliseconds since the Unix epoch to a datetime object.
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
"""Parses seconds since the Unix epoch to a datetime object.
If the input is None, returns None.
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
raise ValueError(
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
)
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
return datetime.fromtimestamp(value, tz=timezone.utc)
@field_serializer("tenant_id", mode="wrap")
def serialize_tenant_state(
@@ -354,11 +353,9 @@ class DocumentSchema:
},
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
# seconds here not millis.
LAST_UPDATED_FIELD_NAME: {
"type": "date",
"format": "epoch_millis",
"format": "epoch_second",
# For some reason date defaults to False, even though it
# would make sense to sort by date.
"doc_values": True,
@@ -366,14 +363,21 @@ class DocumentSchema:
# Access control fields.
# Whether the doc is public. Could have fallen under access
# control list but is such a broad and critical filter that it
# is its own field.
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
# should have no effect on queries.
PUBLIC_FIELD_NAME: {"type": "boolean"},
# Access control list for the doc, excluding public access,
# which is covered above.
# If a user's access set contains at least one entry from this
# set, the user should be able to retrieve this document. This
# only applies if public is set to false; public non-hidden
# documents are always visible to anyone in a given tenancy
# regardless of this field.
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
# Whether the doc is hidden from search results. Should clobber
# all other search filters; up to search implementations to
# guarantee this.
# Whether the doc is hidden from search results.
# Should clobber all other access search filters, namely
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
# search implementations to guarantee this.
HIDDEN_FIELD_NAME: {"type": "boolean"},
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
# This field is only used for displaying a useful name for the
@@ -447,7 +451,6 @@ class DocumentSchema:
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
# The maximum number of tokens this chunk's content can hold.
# TODO(andrei): Can we generalize this to embedding type?
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
},
}

View File

@@ -1,21 +1,36 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from uuid import UUID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
@@ -91,6 +106,11 @@ assert (
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
class DocumentQuery:
"""
@@ -103,6 +123,8 @@ class DocumentQuery:
def get_from_document_id_query(
document_id: str,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
max_chunk_size: int,
min_chunk_index: int | None,
max_chunk_index: int | None,
@@ -120,6 +142,8 @@ class DocumentQuery:
document_id: Onyx document ID. Notably not an OpenSearch document
ID, which points to what Onyx would refer to as a chunk.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the document retrieval query.
include_hidden: Whether to include hidden documents.
max_chunk_size: Document chunks are categorized by the maximum
number of tokens they can hold. This parameter specifies the
maximum size category of document chunks to retrieve.
@@ -136,28 +160,21 @@ class DocumentQuery:
Returns:
A dictionary representing the final ID search query.
"""
filter_clauses: list[dict[str, Any]] = [
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
]
if tenant_state.multitenant:
# TODO(andrei): Fix tenant stuff.
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
if min_chunk_index is not None or max_chunk_index is not None:
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
if min_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
if max_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
filter_clauses.append(range_clause)
filter_clauses.append(
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
filter_clauses = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=min_chunk_index,
max_chunk_index=max_chunk_index,
max_chunk_size=max_chunk_size,
document_id=document_id,
)
final_get_ids_query: dict[str, Any] = {
"query": {"bool": {"filter": filter_clauses}},
# We include this to make sure OpenSearch does not revert to
@@ -195,15 +212,22 @@ class DocumentQuery:
Returns:
A dictionary representing the final delete query.
"""
filter_clauses: list[dict[str, Any]] = [
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
]
if tenant_state.multitenant:
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
filter_clauses = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
# Delete hidden docs too.
include_hidden=True,
access_control_list=None,
source_types=[],
tags=[],
document_sets=[],
user_file_ids=[],
project_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
max_chunk_size=None,
document_id=document_id,
)
final_delete_query: dict[str, Any] = {
"query": {"bool": {"filter": filter_clauses}},
}
@@ -217,19 +241,25 @@ class DocumentQuery:
num_candidates: int,
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final hybrid search query.
This query can be directly supplied to the OpenSearch client.
NOTE: This query can be directly supplied to the OpenSearch client, but
it MUST be supplied in addition to a search pipeline. The results from
hybrid search are not meaningful without that step.
Args:
query_text: The text to query for.
query_vector: The vector embedding of the text to query for.
num_candidates: The number of candidates to consider for vector
num_candidates: The number of neighbors to consider for vector
similarity search. Generally more candidates improves search
quality at the cost of performance.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the hybrid search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final hybrid search query.
@@ -243,31 +273,47 @@ class DocumentQuery:
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, num_candidates
)
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
)
match_highlights_configuration = (
DocumentQuery._get_match_highlights_configuration()
)
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
hybrid_search_query: dict[str, Any] = {
"bool": {
"must": [
{
"hybrid": {
"queries": hybrid_search_subqueries,
}
}
],
# TODO(andrei): When revisiting our hybrid query logic see if
# this needs to be nested one level down.
"filter": hybrid_search_filters,
"hybrid": {
"queries": hybrid_search_subqueries,
# Applied to all the sub-queries. Source:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
}
return final_hybrid_search_body
@staticmethod
@@ -294,7 +340,8 @@ class DocumentQuery:
pipeline.
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
in a single hybrid query.
in a single hybrid query. Source:
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
Args:
query_text: The text of the query to search for.
@@ -305,6 +352,7 @@ class DocumentQuery:
hybrid_search_queries: list[dict[str, Any]] = [
{
"knn": {
# Match on semantic similarity of the title.
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
@@ -313,6 +361,7 @@ class DocumentQuery:
},
{
"knn": {
# Match on semantic similarity of the content.
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
@@ -322,36 +371,273 @@ class DocumentQuery:
{
"multi_match": {
"query": query_text,
# TODO(andrei): Ask Yuhong do we want this?
# Either fuzzy match on the analyzed title (boosted 2x), or
# exact match on exact title keywords (no OpenSearch
# analysis done on the title). See
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
# Returns the score of the best match of the fields above.
# See
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
"type": "best_fields",
}
},
# Fuzzy match on the OpenSearch-analyzed content. See
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
# Exact match on the OpenSearch-analyzed content. See
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
]
return hybrid_search_queries
@staticmethod
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
"""Returns filters for hybrid search.
def _get_search_filters(
tenant_state: TenantState,
include_hidden: bool,
access_control_list: list[str] | None,
source_types: list[DocumentSource],
tags: list[Tag],
document_sets: list[str],
user_file_ids: list[UUID],
project_id: int | None,
time_cutoff: datetime | None,
min_chunk_index: int | None,
max_chunk_index: int | None,
max_chunk_size: int | None = None,
document_id: str | None = None,
) -> list[dict[str, Any]]:
"""Returns filters to be passed into the "filter" key of a search query.
For now only fetches public and not hidden documents.
The "filter" key applies a logical AND operator to its elements, so
every subfilter must evaluate to true in order for the document to be
retrieved. This function returns a list of such subfilters.
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
The return of this function is not sufficient to be directly supplied to
the OpenSearch client. See get_hybrid_search_query.
Args:
tenant_state: Tenant state containing the tenant ID.
include_hidden: Whether to include hidden documents.
access_control_list: Access control list for the documents to
retrieve. If None, there is no restriction on the documents that
can be retrieved. If not None, only public documents can be
retrieved, or non-public documents where at least one acl
provided here is present in the document's acl list.
source_types: If supplied, only documents of one of these source
types will be retrieved.
tags: If supplied, only documents with an entry in their metadata
list corresponding to a tag will be retrieved.
document_sets: If supplied, only documents with at least one
document set ID from this list will be retrieved.
user_file_ids: If supplied, only document IDs in this list will be
retrieved.
project_id: If not None, only documents with this project ID in user
projects will be retrieved.
time_cutoff: Time cutoff for the documents to retrieve. If not None,
Documents which were last updated before this date will not be
returned. For documents which do not have a value for their last
updated time, we assume some default age of
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
updated.
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
None, no minimum chunk index will be applied.
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
None, no maximum chunk index will be applied.
max_chunk_size: The type of chunk to retrieve, specified by the
maximum number of tokens it can hold. If None, no filter will be
applied for this. Defaults to None.
NOTE: See DocumentChunk.max_chunk_size.
document_id: The document ID to retrieve. If None, no filter will be
applied for this. Defaults to None.
WARNING: This filters on the same property as user_file_ids.
Although it would never make sense to supply both, note that if
user_file_ids is supplied and does not contain document_id, no
matches will be retrieved.
TODO(andrei): Add ACL filters and stuff.
Returns:
A list of filters to be passed into the "filter" key of a search
query.
"""
hybrid_search_filters: list[dict[str, Any]] = [
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
]
def _get_acl_visibility_filter(
access_control_list: list[str],
) -> dict[str, Any]:
# Logical OR operator on its elements.
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
acl_visibility_filter["bool"]["should"].append(
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
)
for acl in access_control_list:
acl_subclause: dict[str, Any] = {
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
}
acl_visibility_filter["bool"]["should"].append(acl_subclause)
return acl_visibility_filter
def _get_source_type_filter(
source_types: list[DocumentSource],
) -> dict[str, Any]:
# Logical OR operator on its elements.
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
for source_type in source_types:
source_type_filter["bool"]["should"].append(
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
)
return source_type_filter
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
# Logical OR operator on its elements.
tag_filter: dict[str, Any] = {"bool": {"should": []}}
for tag in tags:
# Kind of an abstraction leak, see
# convert_metadata_dict_to_list_of_strings for why metadata list
# entries are expected to look this way.
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
tag_filter["bool"]["should"].append(
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
)
return tag_filter
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
# Logical OR operator on its elements.
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
for document_set in document_sets:
document_set_filter["bool"]["should"].append(
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
)
return document_set_filter
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
# Logical OR operator on its elements.
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
for user_file_id in user_file_ids:
user_file_id_filter["bool"]["should"].append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
)
return user_file_id_filter
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
# Logical OR operator on its elements.
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
user_project_filter["bool"]["should"].append(
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
)
return user_project_filter
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
# Convert to UTC if not already so the cutoff is comparable to the
# document data.
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
# Logical OR operator on its elements.
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
time_cutoff_filter["bool"]["should"].append(
{
"range": {
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
}
}
)
if time_cutoff < datetime.now(timezone.utc) - timedelta(
days=ASSUMED_DOCUMENT_AGE_DAYS
):
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
# ago, we include documents which have no
# LAST_UPDATED_FIELD_NAME value.
time_cutoff_filter["bool"]["should"].append(
{
"bool": {
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
}
}
)
return time_cutoff_filter
def _get_chunk_index_filter(
min_chunk_index: int | None, max_chunk_index: int | None
) -> dict[str, Any]:
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
if min_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
if max_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
return range_clause
filter_clauses: list[dict[str, Any]] = []
if not include_hidden:
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
if access_control_list is not None:
# If an access control list is provided, the caller can only
# retrieve public documents, and non-public documents where at least
# one acl provided here is present in the document's acl list. If
# there is explicitly no list provided, we make no restrictions on
# the documents that can be retrieved.
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
if source_types:
# If at least one source type is provided, the caller will only
# retrieve documents whose source type is present in this input
# list.
filter_clauses.append(_get_source_type_filter(source_types))
if tags:
# If at least one tag is provided, the caller will only retrieve
# documents where at least one tag provided here is present in the
# document's metadata list.
filter_clauses.append(_get_tag_filter(tags))
if document_sets:
# If at least one document set is provided, the caller will only
# retrieve documents where at least one document set provided here
# is present in the document's document sets list.
filter_clauses.append(_get_document_set_filter(document_sets))
if user_file_ids:
# If at least one user file ID is provided, the caller will only
# retrieve documents where the document ID is in this input list of
# file IDs. Note that these IDs correspond to Onyx documents whereas
# the entries retrieved from the document index correspond to Onyx
# document chunks.
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
if project_id is not None:
# If a project ID is provided, the caller will only retrieve
# documents where the project ID provided here is present in the
# document's user projects list.
filter_clauses.append(_get_user_project_filter(project_id))
if time_cutoff is not None:
# If a time cutoff is provided, the caller will only retrieve
# documents where the document was last updated at or after the time
# cutoff. For documents which do not have a value for
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
# purposes of time cutoff.
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
if min_chunk_index is not None or max_chunk_index is not None:
filter_clauses.append(
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
)
if document_id is not None:
# WARNING: If user_file_ids has elements and if none of them are
# document_id, no matches will be retrieved.
filter_clauses.append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
)
if max_chunk_size is not None:
filter_clauses.append(
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
)
if tenant_state.multitenant:
hybrid_search_filters.append(
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
return hybrid_search_filters
return filter_clauses
@staticmethod
def _get_match_highlights_configuration() -> dict[str, Any]:
@@ -378,4 +664,5 @@ class DocumentQuery:
}
}
}
return match_highlights_configuration

View File

@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.function_call_arguments.delta":
content_part: Optional[str] = parsed_chunk.get("delta", None)
if content_part:
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.completed":
# Final event signaling all output items (including parallel tool calls) are done
# Check if we already received tool calls via streaming events
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
# response.completed event so we need to throw it out here or there are duplicate tool calls.
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
response_data = parsed_chunk.get("response", {})
# Determine finish reason based on response content
finish_reason = "stop"
if response_data.get("output"):
for item in response_data["output"]:
if isinstance(item, dict) and item.get("type") == "function_call":
finish_reason = "tool_calls"
break
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason=finish_reason,
usage=None,
output_items = response_data.get("output", [])
# Check if there are function_call items in the output
has_function_calls = any(
isinstance(item, dict) and item.get("type") == "function_call"
for item in output_items
)
if has_function_calls and not has_streamed_tool_calls:
# Azure's Responses API returns all tool calls in response.completed
# without streaming them incrementally. Extract them here.
from litellm.types.utils import (
Delta,
ModelResponseStream,
StreamingChoices,
)
tool_calls = []
for idx, item in enumerate(output_items):
if isinstance(item, dict) and item.get("type") == "function_call":
tool_calls.append(
ChatCompletionToolCallChunk(
id=item.get("call_id"),
index=idx,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=item.get("name"),
arguments=item.get("arguments", ""),
),
)
)
return ModelResponseStream(
choices=[
StreamingChoices(
index=0,
delta=Delta(tool_calls=tool_calls),
finish_reason="tool_calls",
)
]
)
elif has_function_calls:
# Tool calls were already streamed, just signal completion
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="tool_calls",
usage=None,
)
else:
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=None,
)
else:
pass
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
def _patch_azure_responses_should_fake_stream() -> None:
"""
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
not in its database. This causes Azure custom model deployments to buffer the entire
response before yielding, resulting in poor time-to-first-token.
Azure's Responses API supports native streaming, so we override this to always use
real streaming (SyncResponsesAPIStreamingIterator).
"""
from litellm.llms.azure.responses.transformation import (
AzureOpenAIResponsesAPIConfig,
)
if (
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
== "_patched_should_fake_stream"
):
return
def _patched_should_fake_stream(
self: Any,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
# Azure Responses API supports native streaming - never fake it
return False
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
_patch_openai_responses_transform_response()
_patch_azure_responses_should_fake_stream()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -0,0 +1,287 @@
# Discord Bot Multitenant Architecture
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
## Overview
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
```
┌─────────────────────────────────────────────────────────────────────┐
│ OnyxDiscordClient │
│ │
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
│ │ │ │ │ │
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
│ │ tenant_id → api_key │ │ message, │ │
│ │ │ │ api_key=<per-tenant>, │ │
│ └─────────────────────────┘ │ persona_id=... │ │
│ │ ) │ │
│ └─────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```
---
## Component Details
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
The `DiscordCacheManager` maintains two critical in-memory mappings:
```python
class DiscordCacheManager:
_guild_tenants: dict[int, str] # guild_id → tenant_id
_api_keys: dict[str, str] # tenant_id → api_key
_lock: asyncio.Lock # Concurrency control
```
#### Key Responsibilities
| Function | Purpose |
|----------|---------|
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
| `refresh_all()` | Full cache rebuild from database |
| `refresh_guild()` | Incremental update for single guild |
#### API Key Provisioning Strategy
API keys are **lazily provisioned** - only created when first needed:
```python
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
needs_key = tenant_id not in self._api_keys
with get_session_with_tenant(tenant_id) as db:
# Load guild configs
configs = get_discord_bot_configs(db)
guild_ids = [c.guild_id for c in configs if c.enabled]
# Only provision API key if not already cached
api_key = None
if needs_key:
api_key = get_or_create_discord_service_api_key(db, tenant_id)
return guild_ids, api_key
```
This optimization avoids repeated database calls for API key generation.
#### Concurrency Control
All write operations acquire an async lock to prevent race conditions:
```python
async def refresh_all(self) -> None:
async with self._lock:
# Safe to modify _guild_tenants and _api_keys
for tenant_id in get_all_tenant_ids():
guild_ids, api_key = await self._load_tenant_data(tenant_id)
# Update mappings...
```
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
---
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
#### Key Design: Per-Request API Key Injection
```python
class OnyxAPIClient:
async def send_chat_message(
self,
message: str,
api_key: str, # Injected per-request
persona_id: int | None,
...
) -> ChatFullResponse:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
}
# Make request...
```
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
```python
# Same client, different tenants
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
```
---
## Coordination Flow
### Message Processing Pipeline
When a Discord message arrives, the client coordinates cache and API client:
```python
async def on_message(self, message: Message) -> None:
guild_id = message.guild.id
# Step 1: Cache lookup - guild → tenant
tenant_id = self.cache.get_tenant(guild_id)
if not tenant_id:
return # Guild not registered
# Step 2: Cache lookup - tenant → API key
api_key = self.cache.get_api_key(tenant_id)
if not api_key:
logger.warning(f"No API key for tenant {tenant_id}")
return
# Step 3: API call with tenant-specific credentials
await process_chat_message(
message=message,
api_key=api_key, # Tenant-specific
persona_id=persona_id, # Tenant-specific
api_client=self.api_client,
)
```
### Startup Sequence
```python
async def setup_hook(self) -> None:
# 1. Initialize API client (create aiohttp session)
await self.api_client.initialize()
# 2. Populate cache with all tenants
await self.cache.refresh_all()
# 3. Start background refresh task
self._cache_refresh_task = self.loop.create_task(
self._periodic_cache_refresh() # Every 60 seconds
)
```
### Shutdown Sequence
```python
async def close(self) -> None:
# 1. Cancel background refresh
if self._cache_refresh_task:
self._cache_refresh_task.cancel()
# 2. Close Discord connection
await super().close()
# 3. Close API client session
await self.api_client.close()
# 4. Clear cache
self.cache.clear()
```
---
## Tenant Isolation Mechanisms
### 1. Per-Tenant API Keys
Each tenant has a dedicated service API key:
```python
# backend/onyx/db/discord_bot.py
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
existing = get_discord_service_api_key(db_session)
if existing:
return regenerate_key(existing)
# Create LIMITED role key (chat-only permissions)
return insert_api_key(
db_session=db_session,
api_key_args=APIKeyArgs(
name=DISCORD_SERVICE_API_KEY_NAME,
role=UserRole.LIMITED, # Minimal permissions
),
user_id=None, # Service account (system-owned)
).api_key
```
### 2. Database Context Variables
The cache uses context variables for proper tenant-scoped DB sessions:
```python
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
with get_session_with_tenant(tenant_id) as db:
# All DB operations scoped to this tenant
...
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
```
### 3. Enterprise Gating Support
Gated tenants are filtered during cache refresh:
```python
gated_tenants = fetch_ee_implementation_or_noop(
"onyx.server.tenants.product_gating",
"get_gated_tenants",
set(),
)()
for tenant_id in get_all_tenant_ids():
if tenant_id in gated_tenants:
continue # Skip gated tenants
```
---
## Cache Refresh Strategy
| Trigger | Method | Scope |
|---------|--------|-------|
| Startup | `refresh_all()` | All tenants |
| Periodic (60s) | `refresh_all()` | All tenants |
| Guild registration | `refresh_guild()` | Single tenant |
### Error Handling
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
- **Missing API key**: Bot silently ignores messages from that guild
- **Network errors**: Logged, cache continues with stale data until next refresh
---
## Key Design Insights
1. **Single Client, Multiple Tenants**: A single `OnyxAPIClient` and a single `DiscordCacheManager` instance serve all tenants via dynamic API key injection.
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
---
## File References
| Component | Path |
|-----------|------|
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |

View File

@@ -592,11 +592,8 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,12 +1,149 @@
from mistune import Markdown # type: ignore[import-untyped]
from mistune import Renderer
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements):
# <br> variants plus closing tags of paragraph/list/heading/table-row/etc. elements.
_HTML_NEWLINE_TAG_PATTERN = re.compile(
    r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
    re.IGNORECASE,
)
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>.
# The negative lookahead keeps angle-bracket URLs intact while still matching
# opening and closing tags (the optional leading "/").
_HTML_TAG_PATTERN = re.compile(
    r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them.
# Non-greedy so adjacent fences don't merge into one match.
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
    """Remove HTML markup from a text fragment.

    <br> and block-level closing tags become newlines so layout survives;
    every other tag is dropped outright. Autolinks such as <https://...>
    are left alone (the tag pattern excludes them).
    """
    with_breaks = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
    return _HTML_TAG_PATTERN.sub("", with_breaks)
def _transform_outside_code_blocks(
    message: str, transform: Callable[[str], str]
) -> str:
    """Run *transform* over every span of *message* not inside a ``` fence.

    Fenced code blocks are re-inserted verbatim between the transformed
    spans, preserving the original ordering of the text.
    """
    # re.split yields exactly one more plain segment than there are fences.
    plain_segments = _FENCED_CODE_BLOCK_PATTERN.split(message)
    fences = _FENCED_CODE_BLOCK_PATTERN.findall(message)

    transformed = [transform(segment) for segment in plain_segments]

    interleaved: list[str] = []
    for plain, fence in zip(transformed, fences):
        interleaved.append(plain)
        interleaved.append(fence)
    # Trailing plain segment(s) with no fence after them.
    interleaved.extend(transformed[len(fences):])
    return "".join(interleaved)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
    """Angle-bracket-wrap every markdown link URL: [text](url) -> [text](<url>).

    URLs containing unescaped parentheses, spaces, or other special
    characters break markdown link parsing; wrapping the destination in
    angle brackets tells the parser to treat it as one literal URL. Applies
    to every link in *message*, citations included.
    """
    if "](" not in message:
        # Fast path: no link syntax at all.
        return message

    pieces: list[str] = []
    pos = 0
    while True:
        link_match = _MARKDOWN_LINK_PATTERN.search(message, pos)
        if link_match is None:
            break
        dest_start = link_match.end()
        # Keep everything up to and including "[text](" unchanged.
        pieces.append(message[pos:dest_start])
        dest, close_idx = _extract_link_destination(message, dest_start)
        if close_idx is None:
            # Unterminated link: emit the remainder untouched and stop.
            pieces.append(message[dest_start:])
            return "".join(pieces)
        if dest and not (dest.startswith("<") and dest.endswith(">")):
            dest = f"<{dest}>"
        pieces.append(dest)
        pieces.append(")")
        pos = close_idx + 1
    pieces.append(message[pos:])
    return "".join(pieces)
def _convert_slack_links_to_markdown(message: str) -> str:
    """Rewrite Slack mrkdwn links <url|text> as markdown [text](url).

    LLMs sometimes emit Slack's link syntax directly. Mistune doesn't
    recognise it, so text() would escape the angle brackets and Slack would
    show the raw characters instead of a clickable link. Code fences are
    left untouched.
    """

    def replace_links(fragment: str) -> str:
        return _SLACK_LINK_PATTERN.sub(r"[\2](\1)", fragment)

    return _transform_outside_code_blocks(message, replace_links)
def format_slack_message(message: str | None) -> str:
    """Render a markdown *message* into Slack mrkdwn.

    Returns "" for None input. HTML outside fenced code blocks is
    sanitized, Slack-style <url|text> links are converted to markdown,
    and link destinations are angle-bracket wrapped before rendering with
    mistune's SlackRenderer (strikethrough plugin enabled). Trailing
    newlines are stripped from the result.
    """
    # NOTE: a stale pre-mistune-3.x `return Markdown(...)` line was left at
    # the top of this function, making everything below unreachable and
    # referencing the removed 0.x API; it has been deleted.
    if message is None:
        return ""
    message = _transform_outside_code_blocks(message, _sanitize_html)
    message = _convert_slack_links_to_markdown(message)
    normalized_message = _normalize_link_destinations(message)
    md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
    result = md(normalized_message)
    # With HTMLRenderer, result is always str (not AST list)
    assert isinstance(result, str)
    return result.rstrip("\n")
class SlackRenderer(Renderer):
class SlackRenderer(HTMLRenderer):
"""Renders markdown as Slack mrkdwn format instead of HTML.
Overrides all HTMLRenderer methods that produce HTML tags to ensure
no raw HTML ever appears in Slack messages.
"""
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def escape_special(self, text: str) -> str:
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
text = text.replace(special, replacement)
return text
def header(self, text: str, level: int, raw: str | None = None) -> str:
return f"*{text}*\n"
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
return f"*{text}*\n\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
def double_emphasis(self, text: str) -> str:
def strong(self, text: str) -> str:
return f"*{text}*"
def strikethrough(self, text: str) -> str:
return f"~{text}~"
def list(self, body: str, ordered: bool = True) -> str:
lines = body.split("\n")
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
lines = text.split("\n")
count = 0
for i, line in enumerate(lines):
if line.startswith("li: "):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
return "\n".join(lines) + "\n"
def list_item(self, text: str) -> str:
return f"li: {text}\n"
def link(self, link: str, title: str | None, content: str | None) -> str:
escaped_link = self.escape_special(link)
if content:
return f"<{escaped_link}|{content}>"
def link(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
if text:
return f"<{escaped_url}|{text}>"
if title:
return f"<{escaped_link}|{title}>"
return f"<{escaped_link}>"
return f"<{escaped_url}|{title}>"
return f"<{escaped_url}>"
def image(self, src: str, title: str | None, text: str | None) -> str:
escaped_src = self.escape_special(src)
def image(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
display_text = title or text
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
def codespan(self, text: str) -> str:
return f"`{text}`"
def block_code(self, text: str, lang: str | None) -> str:
return f"```\n{text}\n```\n"
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
return "\n"
def thematic_break(self) -> str:
return "---\n\n"
def block_quote(self, text: str) -> str:
lines = text.strip().split("\n")
quoted = "\n".join(f">{line}" for line in lines)
return quoted + "\n\n"
def block_html(self, html: str) -> str:
return _sanitize_html(html) + "\n\n"
def block_error(self, text: str) -> str:
return f"```\n{text}\n```\n\n"
def text(self, text: str) -> str:
# Only escape the three entities Slack recognizes: & < >
# HTMLRenderer.text() also escapes " to &quot; which Slack renders
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
def paragraph(self, text: str) -> str:
return f"{text}\n"
def autolink(self, link: str, is_email: bool) -> str:
return link if is_email else self.link(link, None, None)
return f"{text}\n\n"

View File

@@ -32,6 +32,7 @@ class RedisConnectorDelete:
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
TASKSET_TTL = FENCE_TTL
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
self.redis.sadd(self.taskset_key, custom_task_id)
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
# Priority on sync's triggered by new indexing should be medium
celery_app.send_task(

View File

@@ -45,6 +45,7 @@ class RedisConnectorPrune:
) # connectorpruning_generator_complete
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
TASKSET_TTL = FENCE_TTL
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
# used to signal the overall workflow is still active
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
# add to the tracking taskset in redis BEFORE creating the celery task.
self.redis.sadd(self.taskset_key, custom_task_id)
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(

View File

@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_PREFIX = PREFIX + "_taskset"
TASKSET_TTL = FENCE_TTL
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,

View File

@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_PREFIX = PREFIX + "_taskset"
TASKSET_TTL = FENCE_TTL
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,

View File

@@ -84,7 +84,8 @@ def patch_document_set(
user=user,
target_group_ids=document_set_update_request.groups,
object_is_public=document_set_update_request.is_public,
object_is_owned_by_user=user and document_set.user_id == user.id,
object_is_owned_by_user=user
and (document_set.user_id is None or document_set.user_id == user.id),
)
try:
update_document_set(
@@ -125,7 +126,8 @@ def delete_document_set(
db_session=db_session,
user=user,
object_is_public=document_set.is_public,
object_is_owned_by_user=user and document_set.user_id == user.id,
object_is_owned_by_user=user
and (document_set.user_id is None or document_set.user_id == user.id),
)
try:

View File

@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
@router.get("/", tags=PUBLIC_API_TAGS)
@router.get("", tags=PUBLIC_API_TAGS)
def get_projects(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),

View File

@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import user_needs_to_be_verified
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
from onyx.configs.constants import AuthType
from onyx.configs.constants import DEV_VERSION_PATTERN
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import STABLE_VERSION_PATTERN
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
async def get_auth_type() -> AuthTypeResponse:
user_count = await get_user_count()
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
# associated with until they auth.
has_users = True
if AUTH_TYPE != AuthType.CLOUD:
user_count = await get_user_count()
has_users = user_count > 0
return AuthTypeResponse(
auth_type=AUTH_TYPE,
requires_verification=user_needs_to_be_verified(),
anonymous_user_enabled=anonymous_user_enabled(),
password_min_length=PASSWORD_MIN_LENGTH,
has_users=user_count > 0,
has_users=has_users,
)

View File

@@ -410,26 +410,20 @@ def list_llm_provider_basics(
all_providers = fetch_existing_llm_providers(db_session)
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
is_admin = user and user.role == UserRole.ADMIN
is_admin = user is not None and user.role == UserRole.ADMIN
accessible_providers = []
for provider in all_providers:
# Include all public providers
if provider.is_public:
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
continue
# Include restricted providers user has access to via groups
if is_admin:
# Admins see all providers
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif provider.groups:
# User must be in at least one of the provider's groups
if user_group_ids.intersection({g.id for g in provider.groups}):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif not provider.personas:
# No restrictions = accessible
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes all public providers
# - Includes providers user can access via group membership
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
end_time = datetime.now(timezone.utc)

View File

@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import remove_chat_message_feedback
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -266,7 +267,35 @@ def get_chat_session(
include_deleted=include_deleted,
)
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
try:
# If we failed to get a chat session, try to retrieve the session with
# less restrictive filters in order to identify what exactly mismatched
# so we can bubble up an accurate error code and message.
existing_chat_session = get_chat_session_by_id(
chat_session_id=session_id,
user_id=None,
db_session=db_session,
is_shared=False,
include_deleted=True,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
if not include_deleted and existing_chat_session.deleted:
raise HTTPException(status_code=404, detail="Chat session has been deleted")
if is_shared:
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
raise HTTPException(
status_code=403, detail="Chat session is not shared"
)
elif user_id is not None and existing_chat_session.user_id not in (
user_id,
None,
):
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=404, detail="Chat session not found")
# for chat-seeding: if the session is unassigned, assign it now. This is done here
# to avoid another back and forth between FE -> BE before starting the first

View File

@@ -580,7 +580,7 @@ def translate_assistant_message_to_packets(
# Determine stop reason - check if message indicates user cancelled
stop_reason: str | None = None
if chat_message.message:
if "Generation was stopped" in chat_message.message:
if "generation was stopped" in chat_message.message.lower():
stop_reason = "user_cancelled"
# Add overall stop packet at the end

View File

@@ -573,7 +573,7 @@ mcp==1.25.0
# onyx
mdurl==0.1.2
# via markdown-it-py
mistune==0.8.4
mistune==3.2.0
# via onyx
more-itertools==10.8.0
# via

View File

@@ -298,7 +298,7 @@ numpy==2.4.1
# pandas-stubs
# shapely
# voyageai
onyx-devtools==0.4.0
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via

View File

@@ -191,6 +191,18 @@ autorestart=true
startretries=5
startsecs=60
# Listens for Discord messages and responds with answers
# for all guilds/channels that the OnyxBot has been added to.
# If not configured, will continue to probe every 3 minutes for a Discord bot token.
[program:discord_bot]
command=python onyx/onyxbot/discord/client.py
stdout_logfile=/var/log/discord_bot.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startretries=5
startsecs=60
# Pushes all logs from the above programs to stdout
# No log rotation here, since it's stdout it's handled by the Docker container logging
[program:log-redirect-handler]
@@ -206,6 +218,7 @@ command=tail -qF
/var/log/celery_worker_user_file_processing.log
/var/log/celery_worker_docfetching.log
/var/log/slack_bot.log
/var/log/discord_bot.log
/var/log/supervisord_watchdog_celery_beat.log
/var/log/mcp_server.log
/var/log/mcp_server.err.log

View File

@@ -8,14 +8,22 @@ import re
import uuid
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import pytest
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.opensearch_document_index import (
generate_opensearch_filtered_access_control_list,
)
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentSchema
@@ -42,14 +50,22 @@ def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) ->
def _create_test_document_chunk(
document_id: str,
chunk_index: int,
content: str,
tenant_state: TenantState,
chunk_index: int = 0,
content_vector: list[float] | None = None,
title: str | None = None,
title_vector: list[float] | None = None,
public: bool = True,
hidden: bool = False,
document_access: DocumentAccess = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
),
source_type: DocumentSource = DocumentSource.FILE,
last_updated: datetime | None = None,
) -> DocumentChunk:
if content_vector is None:
# Generate dummy vector - 128 dimensions for fast testing.
@@ -59,11 +75,6 @@ def _create_test_document_chunk(
if title is not None and title_vector is None:
title_vector = [0.2] * 128
now = datetime.now(timezone.utc)
# We only store millisecond precision, so to make sure asserts work in this
# test file manually lose some precision from datetime.now().
now = now.replace(microsecond=(now.microsecond // 1000) * 1000)
return DocumentChunk(
document_id=document_id,
chunk_index=chunk_index,
@@ -71,11 +82,13 @@ def _create_test_document_chunk(
title_vector=title_vector,
content=content,
content_vector=content_vector,
source_type="test_source",
source_type=source_type.value,
metadata_list=None,
last_updated=now,
public=public,
access_control_list=[],
last_updated=last_updated,
public=document_access.is_public,
access_control_list=generate_opensearch_filtered_access_control_list(
document_access
),
hidden=hidden,
global_boost=0,
semantic_identifier="Test semantic identifier",
@@ -331,6 +344,9 @@ class TestOpenSearchClient:
chunk_index=0,
content="Content to retrieve",
tenant_state=tenant_state,
# We only store second precision, so to make sure asserts work in
# this test we'll deliberately lose some precision.
last_updated=datetime.now(timezone.utc).replace(microsecond=0),
)
test_client.index_document(document=original_doc)
@@ -471,6 +487,8 @@ class TestOpenSearchClient:
search_query = DocumentQuery.get_from_document_id_query(
document_id="delete-me",
tenant_state=tenant_state,
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -483,6 +501,8 @@ class TestOpenSearchClient:
keep_query = DocumentQuery.get_from_document_id_query(
document_id="keep-me",
tenant_state=tenant_state,
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -510,7 +530,6 @@ class TestOpenSearchClient:
chunk_index=0,
content="Original content",
tenant_state=tenant_state,
public=True,
hidden=False,
)
test_client.index_document(document=doc)
@@ -561,10 +580,13 @@ class TestOpenSearchClient:
properties_to_update={"hidden": True},
)
def test_search_basic(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
def test_hybrid_search_with_pipeline(
self,
test_client: OpenSearchClient,
search_pipeline: None,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Tests basic search functionality."""
"""Tests hybrid search with a normalization pipeline."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
@@ -574,24 +596,24 @@ class TestOpenSearchClient:
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Index multiple documents with different content and vectors.
# Index documents.
docs = {
"search-doc-1": _create_test_document_chunk(
document_id="search-doc-1",
"doc-1": _create_test_document_chunk(
document_id="doc-1",
chunk_index=0,
content="Python programming language tutorial",
content_vector=_generate_test_vector(0.1),
tenant_state=tenant_state,
),
"search-doc-2": _create_test_document_chunk(
document_id="search-doc-2",
"doc-2": _create_test_document_chunk(
document_id="doc-2",
chunk_index=0,
content="How to make cheese",
content_vector=_generate_test_vector(0.2),
tenant_state=tenant_state,
),
"search-doc-3": _create_test_document_chunk(
document_id="search-doc-3",
"doc-3": _create_test_document_chunk(
document_id="doc-3",
chunk_index=0,
content="C++ for newborns",
content_vector=_generate_test_vector(0.15),
@@ -613,78 +635,10 @@ class TestOpenSearchClient:
num_candidates=10,
num_hits=5,
tenant_state=tenant_state,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
assert len(results) == 3
# Assert that all the chunks above are present.
assert all(
chunk.document_chunk.document_id
in ["search-doc-1", "search-doc-2", "search-doc-3"]
for chunk in results
)
# Make sure the chunk contents are preserved.
for chunk in results:
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert chunk.score
# Make sure there is some kind of match highlight for the first hit. We
# don't expect highlights for any other hit.
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
def test_search_with_pipeline(
self,
test_client: OpenSearchClient,
search_pipeline: None,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Tests search with a normalization pipeline."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents.
docs = {
"pipeline-doc-1": _create_test_document_chunk(
document_id="pipeline-doc-1",
chunk_index=0,
content="Machine learning algorithms for single-celled organisms",
content_vector=_generate_test_vector(0.3),
tenant_state=tenant_state,
),
"pipeline-doc-2": _create_test_document_chunk(
document_id="pipeline-doc-2",
chunk_index=0,
content="Deep learning shallow neural networks",
content_vector=_generate_test_vector(0.35),
tenant_state=tenant_state,
),
}
for doc in docs.values():
test_client.index_document(document=doc)
# Refresh index to make documents searchable
test_client.refresh_index()
# Search query.
query_text = "machine learning"
query_vector = _generate_test_vector(0.32)
search_body = DocumentQuery.get_hybrid_search_query(
query_text=query_text,
query_vector=query_vector,
num_candidates=10,
num_hits=5,
tenant_state=tenant_state,
# We're not worried about filtering here. tenant_id in this object
# is not relevant.
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
)
# Under test.
@@ -693,23 +647,26 @@ class TestOpenSearchClient:
)
# Postcondition.
assert len(results) == 2
assert len(results) == len(docs)
# Assert that all the chunks above are present.
assert all(
chunk.document_chunk.document_id in ["pipeline-doc-1", "pipeline-doc-2"]
for chunk in results
)
assert all(chunk.document_chunk.document_id in docs.keys() for chunk in results)
# Make sure the chunk contents are preserved.
for chunk in results:
for i, chunk in enumerate(results):
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert chunk.score
# Make sure there is some kind of match highlight.
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
# Make sure there is some kind of match highlight only for the first
# result. The other results are so bad they're not expected to have
# match highlights.
if i == 0:
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
def test_search_empty_index(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self,
test_client: OpenSearchClient,
search_pipeline: None,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Tests search on an empty index returns an empty list."""
# Precondition.
@@ -731,19 +688,28 @@ class TestOpenSearchClient:
num_candidates=10,
num_hits=5,
tenant_state=tenant_state,
# We're not worried about filtering here. tenant_id in this object
# is not relevant.
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
# Postcondition.
assert len(results) == 0
def test_search_filters(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
def test_hybrid_search_with_pipeline_and_filters(
self,
test_client: OpenSearchClient,
search_pipeline: None,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests search filters for public/hidden documents and tenant isolation.
Tests search filters for ACL, hidden documents, and tenant isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
@@ -757,29 +723,47 @@ class TestOpenSearchClient:
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc-1": _create_test_document_chunk(
document_id="public-doc-1",
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
public=True,
hidden=False,
tenant_state=tenant_x,
),
"hidden-doc-1": _create_test_document_chunk(
document_id="hidden-doc-1",
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
public=True,
hidden=True,
tenant_state=tenant_x,
),
"private-doc-1": _create_test_document_chunk(
document_id="private-doc-1",
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
public=False,
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
@@ -787,7 +771,6 @@ class TestOpenSearchClient:
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
public=True,
hidden=False,
tenant_state=tenant_y,
),
@@ -798,9 +781,6 @@ class TestOpenSearchClient:
# Refresh index to make documents searchable.
test_client.refresh_index()
# Search with default filters (public=True, hidden=False).
# The DocumentQuery.get_hybrid_search_query uses filters that should
# only return public, non-hidden documents.
query_text = "document content"
query_vector = _generate_test_vector(0.6)
search_body = DocumentQuery.get_hybrid_search_query(
@@ -809,24 +789,41 @@ class TestOpenSearchClient:
num_candidates=10,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
# Postcondition.
# Should only get the public, non-hidden document.
assert len(results) == 1
assert results[0].document_chunk.document_id == "public-doc-1"
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# NOTE: This test is not explicitly testing for how well results are
# ordered; we're just assuming which doc will be the first result here.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == docs["public-doc-1"]
assert results[0].document_chunk == docs["public-doc"]
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
# Make sure there is some kind of match highlight.
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == docs["private-doc-user-a"]
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
def test_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
self,
test_client: OpenSearchClient,
search_pipeline: None,
@@ -849,52 +846,54 @@ class TestOpenSearchClient:
# Vectors closer to query_vector (0.1) should rank higher.
docs = [
_create_test_document_chunk(
document_id="highly-relevant-1",
document_id="highly-relevant",
chunk_index=0,
content="Artificial intelligence and machine learning transform technology",
content_vector=_generate_test_vector(
0.1
), # Very close to query vector.
public=True,
hidden=False,
tenant_state=tenant_x,
),
_create_test_document_chunk(
document_id="somewhat-relevant-1",
document_id="somewhat-relevant",
chunk_index=0,
content="Computer programming with various languages",
content_vector=_generate_test_vector(0.5), # Far from query vector.
public=True,
hidden=False,
tenant_state=tenant_x,
),
_create_test_document_chunk(
document_id="not-very-relevant-1",
document_id="not-very-relevant",
chunk_index=0,
content="Cooking recipes for delicious meals",
content_vector=_generate_test_vector(
0.9
), # Very far from query vector.
public=True,
hidden=False,
tenant_state=tenant_x,
),
# These should be filtered out by public/hidden filters.
_create_test_document_chunk(
document_id="hidden-but-relevant-1",
document_id="hidden-but-relevant",
chunk_index=0,
content="Artificial intelligence research papers",
content_vector=_generate_test_vector(0.05), # Very close but hidden.
public=True,
hidden=True,
tenant_state=tenant_x,
),
_create_test_document_chunk(
document_id="private-but-relevant-1",
document_id="private-but-relevant",
chunk_index=0,
content="Artificial intelligence industry analysis",
content_vector=_generate_test_vector(0.08), # Very close but private.
public=False,
document_access=DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
hidden=False,
tenant_state=tenant_x,
),
@@ -905,7 +904,7 @@ class TestOpenSearchClient:
# Refresh index to make documents searchable.
test_client.refresh_index()
# Search query matching "highly-relevant-1" most closely.
# Search query matching "highly-relevant" most closely.
query_text = "artificial intelligence"
query_vector = _generate_test_vector(0.1)
search_body = DocumentQuery.get_hybrid_search_query(
@@ -914,6 +913,9 @@ class TestOpenSearchClient:
num_candidates=10,
num_hits=5,
tenant_state=tenant_x,
# Explicitly pass in an empty list to enforce private doc filtering.
index_filters=IndexFilters(access_control_list=[], tenant_id=None),
include_hidden=False,
)
# Under test.
@@ -925,15 +927,15 @@ class TestOpenSearchClient:
# Should only get public, non-hidden documents (3 out of 5).
assert len(results) == 3
result_ids = [chunk.document_chunk.document_id for chunk in results]
assert "highly-relevant-1" in result_ids
assert "somewhat-relevant-1" in result_ids
assert "not-very-relevant-1" in result_ids
assert "highly-relevant" in result_ids
assert "somewhat-relevant" in result_ids
assert "not-very-relevant" in result_ids
# Filtered out by public/hidden constraints.
assert "hidden-but-relevant-1" not in result_ids
assert "private-but-relevant-1" not in result_ids
assert "hidden-but-relevant" not in result_ids
assert "private-but-relevant" not in result_ids
# Most relevant document should be first due to normalization pipeline.
assert results[0].document_chunk.document_id == "highly-relevant-1"
# Most relevant document should be first.
assert results[0].document_chunk.document_id == "highly-relevant"
# Make sure there is some kind of match highlight for the most relevant
# result.
@@ -1014,6 +1016,8 @@ class TestOpenSearchClient:
verify_query_x = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-x",
tenant_state=tenant_x,
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -1026,6 +1030,8 @@ class TestOpenSearchClient:
verify_query_y = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-y",
tenant_state=tenant_y,
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -1113,6 +1119,8 @@ class TestOpenSearchClient:
query_body = DocumentQuery.get_from_document_id_query(
document_id="doc-1",
tenant_state=tenant_state,
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -1133,3 +1141,176 @@ class TestOpenSearchClient:
for chunk in doc1_chunks
}
assert set(chunk_ids) == expected_ids
def test_search_with_no_document_access_can_retrieve_all_documents(
    self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
    """
    Tests search with no document access can retrieve all documents, even
    private ones.

    access_control_list=None (as opposed to an empty list) is the "no ACL
    filtering" signal: the private doc below is still retrievable.
    """
    # Precondition.
    # Single-tenant mode; tenant isolation is not what is under test here.
    _patch_global_tenant_state(monkeypatch, False)
    tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
    mappings = DocumentSchema.get_document_schema(
        vector_dimension=128, multitenant=tenant_state.multitenant
    )
    settings = DocumentSchema.get_index_settings()
    test_client.create_index(mappings=mappings, settings=settings)
    # Index documents with different public/hidden and tenant states.
    docs = {
        "public-doc": _create_test_document_chunk(
            document_id="public-doc",
            chunk_index=0,
            content="Public document content",
            hidden=False,
            tenant_state=tenant_state,
        ),
        "hidden-doc": _create_test_document_chunk(
            document_id="hidden-doc",
            chunk_index=0,
            content="Hidden document content, spooky",
            hidden=True,
            tenant_state=tenant_state,
        ),
        # Restricted to user-a only; still retrievable below because the
        # query passes access_control_list=None.
        "private-doc-user-a": _create_test_document_chunk(
            document_id="private-doc-user-a",
            chunk_index=0,
            content="Private document content, btw my SSN is 123-45-6789",
            hidden=False,
            tenant_state=tenant_state,
            document_access=DocumentAccess.build(
                user_emails=["user-a@example.com"],
                user_groups=[],
                external_user_emails=[],
                external_user_group_ids=[],
                is_public=False,
            ),
        ),
    }
    for doc in docs.values():
        test_client.index_document(document=doc)
    # Refresh index to make documents searchable.
    test_client.refresh_index()
    # Build query for all documents.
    query_body = DocumentQuery.get_from_document_id_query(
        document_id="private-doc-user-a",
        tenant_state=tenant_state,
        # This is the input under test, notice None for acl.
        index_filters=IndexFilters(access_control_list=None, tenant_id=None),
        include_hidden=False,
        max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
        min_chunk_index=None,
        max_chunk_index=None,
        get_full_document=False,
    )
    # Under test.
    chunk_ids = test_client.search_for_document_ids(body=query_body)
    # Postcondition.
    # Even though this doc is private, because we supplied None for acl we
    # were able to retrieve it.
    assert len(chunk_ids) == 1
    # Since this is a chunk ID, it will have the doc ID in it plus other
    # stuff we don't care about in this test.
    assert chunk_ids[0].startswith("private-doc-user-a")
def test_time_cutoff_filter(
    self,
    test_client: OpenSearchClient,
    search_pipeline: None,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Tests the time cutoff filter works.

    Indexes docs with last_updated one day ago, one year ago, and unset,
    then queries with one-week and six-month cutoffs and checks which
    documents survive each filter.
    """
    # Precondition.
    _patch_global_tenant_state(monkeypatch, False)
    tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
    mappings = DocumentSchema.get_document_schema(
        vector_dimension=128, multitenant=tenant_state.multitenant
    )
    settings = DocumentSchema.get_index_settings()
    test_client.create_index(mappings=mappings, settings=settings)
    # Index docs with various ages.
    one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
    one_week_ago = datetime.now(timezone.utc) - timedelta(days=7)
    six_months_ago = datetime.now(timezone.utc) - timedelta(days=180)
    one_year_ago = datetime.now(timezone.utc) - timedelta(days=365)
    docs = [
        _create_test_document_chunk(
            document_id="one-day-ago",
            content="Good match",
            last_updated=one_day_ago,
            tenant_state=tenant_state,
        ),
        _create_test_document_chunk(
            document_id="one-year-ago",
            content="Good match",
            last_updated=one_year_ago,
            tenant_state=tenant_state,
        ),
        _create_test_document_chunk(
            document_id="no-last-updated",
            # Since we test for result ordering in the postconditions, let's
            # just make this content slightly less of a match with the query
            # so this test is not flaky from the ordering of the results.
            content="Still an ok match",
            last_updated=None,
            tenant_state=tenant_state,
        ),
    ]
    for doc in docs:
        test_client.index_document(document=doc)
    # Refresh index to make documents searchable.
    test_client.refresh_index()
    # Build query for documents updated in the last week.
    last_week_search_body = DocumentQuery.get_hybrid_search_query(
        query_text="Good match",
        query_vector=_generate_test_vector(0.1),
        num_candidates=10,
        num_hits=5,
        tenant_state=tenant_state,
        index_filters=IndexFilters(
            access_control_list=None, tenant_id=None, time_cutoff=one_week_ago
        ),
        include_hidden=False,
    )
    # Same query, but with a much older (six-month) cutoff.
    last_six_months_search_body = DocumentQuery.get_hybrid_search_query(
        query_text="Good match",
        query_vector=_generate_test_vector(0.1),
        num_candidates=10,
        num_hits=5,
        tenant_state=tenant_state,
        index_filters=IndexFilters(
            access_control_list=None, tenant_id=None, time_cutoff=six_months_ago
        ),
        include_hidden=False,
    )
    # Under test.
    last_week_results = test_client.search(
        body=last_week_search_body,
        search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
    )
    last_six_months_results = test_client.search(
        body=last_six_months_search_body,
        search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
    )
    # Postcondition.
    # We expect to only get one-day-ago.
    assert len(last_week_results) == 1
    assert last_week_results[0].document_chunk.document_id == "one-day-ago"
    # We expect to get one-day-ago and no-last-updated since six months >
    # ASSUMED_DOCUMENT_AGE_DAYS.
    assert len(last_six_months_results) == 2
    assert last_six_months_results[0].document_chunk.document_id == "one-day-ago"
    assert (
        last_six_months_results[1].document_chunk.document_id == "no-last-updated"
    )

View File

@@ -476,8 +476,8 @@ class ChatSessionManager:
else GENERAL_HEADERS
),
)
# Chat session should return 400 if it doesn't exist
return response.status_code == 400
# Chat session should return 404 if it doesn't exist or is deleted
return response.status_code == 404
@staticmethod
def verify_soft_deleted(

View File

@@ -31,7 +31,7 @@ class ProjectManager:
) -> List[UserProjectSnapshot]:
"""Get all projects for a user via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/",
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
@@ -56,7 +56,7 @@ class ProjectManager:
) -> bool:
"""Verify that a project has been deleted by ensuring it's not in list."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/",
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()

View File

@@ -0,0 +1,185 @@
from uuid import uuid4
import pytest
import requests
from requests import HTTPError
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
@pytest.fixture(scope="module", autouse=True)
def reset_for_module() -> None:
    """Reset all data once before running any tests in this module."""
    # autouse + module scope: runs exactly once, before the first test here.
    reset_all()
@pytest.fixture
def second_user(admin_user: DATestUser) -> DATestUser:
    """Provide a second BASIC-role user, creating it or logging in if it exists.

    Depends on `admin_user` purely for ordering: the admin must exist first
    so this user is created with the BASIC role rather than as first admin.
    """
    # Ensure admin exists so this new user is created with BASIC role.
    try:
        return UserManager.create(name="second_basic_user")
    except HTTPError as e:
        response = e.response
        if response is None:
            raise
        # Only conflict-style statuses can mean "user already exists";
        # anything else is a real failure.
        if response.status_code not in (400, 409):
            raise
        try:
            payload = response.json()
        except ValueError:
            # Non-JSON error body: not the duplicate-user case we handle.
            raise
        detail = payload.get("detail")
        if not _is_user_already_exists_detail(detail):
            raise
        print("Second basic user already exists; logging in instead.")
        return UserManager.login_as_user(
            DATestUser(
                id="",
                email=build_email("second_basic_user"),
                password=DEFAULT_PASSWORD,
                headers=GENERAL_HEADERS,
                role=UserRole.BASIC,
                is_active=True,
            )
        )
def _is_user_already_exists_detail(detail: object) -> bool:
if isinstance(detail, str):
normalized = detail.lower()
return (
"already exists" in normalized
or "register_user_already_exists" in normalized
)
if isinstance(detail, dict):
code = detail.get("code")
if isinstance(code, str) and code.lower() == "register_user_already_exists":
return True
message = detail.get("message")
if isinstance(message, str) and "already exists" in message.lower():
return True
return False
def _get_chat_session(
    chat_session_id: str,
    user: DATestUser,
    is_shared: bool | None = None,
    include_deleted: bool | None = None,
) -> requests.Response:
    """GET a chat session as `user`, omitting any query param left as None.

    Booleans are serialized as lowercase "true"/"false" to match the API's
    query-string conventions.
    """
    query_params = {
        name: str(value).lower()
        for name, value in (
            ("is_shared", is_shared),
            ("include_deleted", include_deleted),
        )
        if value is not None
    }
    return requests.get(
        f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
        params=query_params,
        headers=user.headers,
        cookies=user.cookies,
    )
def _set_sharing_status(
    chat_session_id: str, sharing_status: str, user: DATestUser
) -> requests.Response:
    """PATCH the session's sharing status (e.g. "public"/"private") as `user`."""
    url = f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}"
    payload = {"sharing_status": sharing_status}
    return requests.patch(
        url,
        json=payload,
        headers=user.headers,
        cookies=user.cookies,
    )
def test_private_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify private sessions are only accessible by the owner and never via share link."""
    # A freshly created session is private and owned by basic_user.
    session = ChatSessionManager.create(user_performing_action=basic_user)
    session_id = str(session.id)
    # The owner may read their own private session directly.
    assert _get_chat_session(session_id, basic_user).status_code == 200
    # Every other access path must be rejected: owner via share link,
    # non-owner directly, and non-owner via share link.
    assert _get_chat_session(session_id, basic_user, is_shared=True).status_code == 403
    assert _get_chat_session(session_id, second_user).status_code == 403
    assert _get_chat_session(session_id, second_user, is_shared=True).status_code == 403
def test_public_shared_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify shared sessions are accessible only via share link for non-owners."""
    # Start with a private session owned by basic_user, then flip it public.
    session = ChatSessionManager.create(user_performing_action=basic_user)
    session_id = str(session.id)
    assert _set_sharing_status(session_id, "public", basic_user).status_code == 200
    # Owner: both direct and share-link access succeed.
    assert _get_chat_session(session_id, basic_user).status_code == 200
    assert _get_chat_session(session_id, basic_user, is_shared=True).status_code == 200
    # Non-owner: direct access is still forbidden...
    assert _get_chat_session(session_id, second_user).status_code == 403
    # ...but the share link works once the session is public.
    assert _get_chat_session(session_id, second_user, is_shared=True).status_code == 200
def test_deleted_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify deleted sessions return 404, with include_deleted gated by access checks."""
    # Create a session and soft-delete it as the owner.
    session = ChatSessionManager.create(user_performing_action=basic_user)
    assert (
        ChatSessionManager.soft_delete(
            chat_session=session, user_performing_action=basic_user
        )
        is True
    )
    session_id = str(session.id)
    # A normal fetch of a deleted session behaves as if it does not exist.
    assert _get_chat_session(session_id, basic_user).status_code == 404
    # The owner can still retrieve it by opting in with include_deleted.
    owner_response = _get_chat_session(session_id, basic_user, include_deleted=True)
    assert owner_response.status_code == 200
    assert owner_response.json().get("deleted") is True
    # include_deleted must not bypass access control for other users.
    other_response = _get_chat_session(session_id, second_user, include_deleted=True)
    assert other_response.status_code == 403
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
    """Verify unknown IDs return 404."""
    # A random UUID should never collide with an existing session.
    missing_id = str(uuid4())
    assert _get_chat_session(missing_id, basic_user).status_code == 404

View File

@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
assert fallback_llm.config.model_name == default_provider.default_model_name
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
    users: tuple[DATestUser, DATestUser],
) -> None:
    """Test that the /llm/provider endpoint correctly excludes non-public providers
    with no group/persona restrictions.

    This tests the fix for the bug where non-public providers with no restrictions
    were incorrectly shown to all users instead of being admin-only.
    """
    admin_user, basic_user = users

    def _visible_provider_names(requesting_user: DATestUser) -> list[str]:
        # Names of providers the given user can see via the listing endpoint.
        response = requests.get(
            f"{API_SERVER_URL}/llm/provider",
            headers=requesting_user.headers,
        )
        assert response.status_code == 200
        return [provider["name"] for provider in response.json()]

    # Visible to everyone.
    public_provider = LLMProviderManager.create(
        name="public-provider",
        is_public=True,
        set_as_default=True,
        user_performing_action=admin_user,
    )
    # Non-public with no group/persona restrictions: should be admin-only.
    non_public_provider = LLMProviderManager.create(
        name="non-public-unrestricted",
        is_public=False,
        groups=[],
        personas=[],
        set_as_default=False,
        user_performing_action=admin_user,
    )
    # The basic user sees only the public provider.
    basic_names = _visible_provider_names(basic_user)
    assert public_provider.name in basic_names
    assert non_public_provider.name not in basic_names
    # The admin sees both.
    admin_names = _visible_provider_names(admin_user)
    assert public_provider.name in admin_names
    assert non_public_provider.name in admin_names
def test_provider_delete_clears_persona_references(reset: None) -> None:
"""Test that deleting a provider automatically clears persona references."""
admin_user = UserManager.create(name="admin_user")

View File

@@ -0,0 +1,65 @@
from uuid import uuid4
import requests
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.persona import PersonaLabelManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.test_models import DATestPersonaLabel
from tests.integration.common_utils.test_models import DATestUser
def test_update_persona_with_null_label_ids_preserves_labels(
    reset: None, admin_user: DATestUser
) -> None:
    """Updating a persona with label_ids=None must not strip its labels."""
    # Create a label and a persona that carries it.
    persona_label = PersonaLabelManager.create(
        label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
        user_performing_action=admin_user,
    )
    assert persona_label.id is not None
    persona = PersonaManager.create(
        label_ids=[persona_label.id],
        user_performing_action=admin_user,
    )
    # Change an unrelated field so the PATCH has a visible effect to assert on.
    updated_description = f"{persona.description}-updated"
    update_request = PersonaUpsertRequest(
        name=persona.name,
        description=updated_description,
        system_prompt=persona.system_prompt or "",
        task_prompt=persona.task_prompt or "",
        datetime_aware=persona.datetime_aware,
        document_set_ids=persona.document_set_ids,
        num_chunks=persona.num_chunks,
        is_public=persona.is_public,
        recency_bias=persona.recency_bias,
        llm_filter_extraction=persona.llm_filter_extraction,
        llm_relevance_filter=persona.llm_relevance_filter,
        llm_model_provider_override=persona.llm_model_provider_override,
        llm_model_version_override=persona.llm_model_version_override,
        tool_ids=persona.tool_ids,
        users=[],
        groups=[],
        # The input under test: None should mean "leave labels unchanged".
        label_ids=None,
    )
    response = requests.patch(
        f"{API_SERVER_URL}/persona/{persona.id}",
        # exclude_none=False so label_ids is serialized as an explicit null
        # rather than being dropped from the payload.
        json=update_request.model_dump(mode="json", exclude_none=False),
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    response.raise_for_status()
    fetched = requests.get(
        f"{API_SERVER_URL}/persona/{persona.id}",
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    fetched.raise_for_status()
    fetched_persona = fetched.json()
    # The update applied, and the original label survived it.
    assert fetched_persona["description"] == updated_description
    fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
    assert persona_label.id in fetched_label_ids

View File

@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
provider_id = _activate_exa_provider(admin_user)
assert isinstance(provider_id, int)
search_request = {"queries": ["latest ai research news"], "max_results": 3}
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
lite_response = requests.post(
f"{API_SERVER_URL}/web-search/search-lite",

View File

@@ -0,0 +1,50 @@
"""Tests for Asana connector configuration parsing."""
import pytest
from onyx.connectors.asana.connector import AsanaConnector
@pytest.mark.parametrize(
    "project_ids,expected",
    [
        (None, None),
        ("", None),
        (" ", None),
        (" 123 ", ["123"]),
        (" 123 , , 456 , ", ["123", "456"]),
    ],
)
def test_asana_connector_project_ids_normalization(
    project_ids: str | None, expected: list[str] | None
) -> None:
    """Blank/None project-id input normalizes to None; entries are trimmed
    and empty segments dropped. Workspace and team ids are padded here to
    confirm they are trimmed as well."""
    connector = AsanaConnector(
        asana_workspace_id=" 1153293530468850 ",
        asana_project_ids=project_ids,
        asana_team_id=" 1210918501948021 ",
    )
    observed = (
        connector.workspace_id,
        connector.project_ids_to_index,
        connector.asana_team_id,
    )
    assert observed == ("1153293530468850", expected, "1210918501948021")
@pytest.mark.parametrize(
    "team_id,expected",
    [
        (None, None),
        ("", None),
        (" ", None),
        (" 1210918501948021 ", "1210918501948021"),
    ],
)
def test_asana_connector_team_id_normalization(
    team_id: str | None, expected: str | None
) -> None:
    """Blank/None team ids normalize to None; padded ids are trimmed."""
    connector = AsanaConnector(
        asana_workspace_id="1153293530468850",
        asana_project_ids=None,
        asana_team_id=team_id,
    )
    observed = connector.asana_team_id
    assert observed == expected

View File

@@ -0,0 +1,506 @@
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
class TestConvertToMetadataValue:
    """Tests for the _convert_to_metadata_value helper function."""

    def test_string_value(self) -> None:
        """String values should be returned as-is."""
        for text in ("hello", ""):
            assert _convert_to_metadata_value(text) == text

    def test_boolean_true(self) -> None:
        """Boolean True should be converted to string 'True'."""
        assert _convert_to_metadata_value(True) == "True"

    def test_boolean_false(self) -> None:
        """Boolean False should be converted to string 'False'."""
        assert _convert_to_metadata_value(False) == "False"

    def test_integer_value(self) -> None:
        """Integer values should be converted to string."""
        for number, text in ((42, "42"), (0, "0"), (-100, "-100")):
            assert _convert_to_metadata_value(number) == text

    def test_float_value(self) -> None:
        """Float values should be converted to string."""
        for number, text in ((3.14, "3.14"), (0.0, "0.0"), (-2.5, "-2.5")):
            assert _convert_to_metadata_value(number) == text

    def test_list_of_strings(self) -> None:
        """List of strings should remain as list of strings."""
        assert _convert_to_metadata_value(["a", "b", "c"]) == ["a", "b", "c"]

    def test_list_of_mixed_types(self) -> None:
        """List with mixed types should have all items converted to strings."""
        mixed: list[Any] = [1, True, 3.14, "text"]
        assert _convert_to_metadata_value(mixed) == ["1", "True", "3.14", "text"]

    def test_empty_list(self) -> None:
        """Empty list should return empty list."""
        assert _convert_to_metadata_value([]) == []
class TestYieldDocBatches:
    """Tests for the _yield_doc_batches method of SalesforceConnector."""

    @pytest.fixture
    def connector(self) -> SalesforceConnector:
        """Create a SalesforceConnector instance with mocked sf_client."""
        connector = SalesforceConnector(
            batch_size=10,
            requested_objects=["Opportunity"],
        )
        # Mock the sf_client property
        # NOTE(review): assigning the backing field `_sf_client` presumably
        # bypasses the lazy property so no real Salesforce login happens —
        # confirm against the connector implementation.
        mock_sf_client = MagicMock()
        mock_sf_client.sf_instance = "test.salesforce.com"
        connector._sf_client = mock_sf_client
        return connector

    @pytest.fixture
    def mock_sf_db(self) -> MagicMock:
        """Create a mock OnyxSalesforceSQLite object."""
        return MagicMock()

    def _create_salesforce_object(
        self,
        object_id: str,
        object_type: str,
        data: dict[str, Any],
    ) -> SalesforceObject:
        """Helper to create a SalesforceObject with required fields."""
        # Ensure required fields are present
        # NOTE: setdefault mutates the caller's `data` dict in place.
        data.setdefault(ID_FIELD, object_id)
        data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
        data.setdefault(NAME_FIELD, f"Test {object_type}")
        return SalesforceObject(id=object_id, type=object_type, data=data)

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_metadata_type_conversion_for_opportunity(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that Opportunity metadata fields are properly type-converted."""
        parent_id = "006bm000006kyDpAAI"
        parent_type = "Opportunity"
        # Create a parent object with various data types in the fields
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "Test Opportunity",
            MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
            "Account": "Acme Corp",  # string - should become "account" metadata
            "FiscalQuarter": 2,  # int - should be converted to "2"
            "FiscalYear": 2024,  # int - should be converted to "2024"
            "IsClosed": False,  # bool - should be converted to "False"
            "StageName": "Prospecting",  # string
            "Type": "New Business",  # string
            "Amount": 50000.50,  # float - should be converted to "50000.50"
            "CloseDate": "2024-06-30",  # string
            "Probability": 75,  # int - should be converted to "75"
            "CreatedDate": "2024-01-01T00:00:00.000Z",  # string
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        # Setup mock sf_db
        # Tuples are (object_type, object_id, count); the third element is
        # presumably a running count — TODO confirm against the real query.
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        # Create a mock document that convert_sf_object_to_doc will return
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="Test Opportunity",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        # Track parent changes
        parents_changed = 0

        def increment() -> None:
            nonlocal parents_changed
            parents_changed += 1

        # Call _yield_doc_batches
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                increment,
            )
        )
        # Verify we got one batch with one document
        assert len(batches) == 1
        docs = batches[0]
        assert len(docs) == 1
        doc = docs[0]
        assert isinstance(doc, Document)
        # Verify metadata type conversions
        # All values should be strings (or list of strings)
        assert doc.metadata["object_type"] == "Opportunity"
        assert doc.metadata["account"] == "Acme Corp"  # string stays string
        assert doc.metadata["fiscal_quarter"] == "2"  # int -> str
        assert doc.metadata["fiscal_year"] == "2024"  # int -> str
        assert doc.metadata["is_closed"] == "False"  # bool -> str
        assert doc.metadata["stage_name"] == "Prospecting"  # string stays string
        assert doc.metadata["type"] == "New Business"  # string stays string
        assert (
            doc.metadata["amount"] == "50000.5"
        )  # float -> str (Python drops trailing zeros)
        assert doc.metadata["close_date"] == "2024-06-30"  # string stays string
        assert doc.metadata["probability"] == "75"  # int -> str
        assert doc.metadata["name"] == "Test Opportunity"  # NAME_FIELD
        # Verify parent was counted
        assert parents_changed == 1
        assert type_to_processed[parent_type] == 1

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_missing_optional_metadata_fields(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that missing optional metadata fields are not added."""
        parent_id = "006bm000006kyDqAAI"
        parent_type = "Opportunity"
        # Create parent object with only some fields
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "Minimal Opportunity",
            MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
            "StageName": "Closed Won",
            # Notably missing: Amount, Probability, FiscalQuarter, etc.
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="Minimal Opportunity",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,  # no-op parent-change callback
            )
        )
        doc = batches[0][0]
        assert isinstance(doc, Document)
        # Only present fields should be in metadata
        assert "stage_name" in doc.metadata
        assert doc.metadata["stage_name"] == "Closed Won"
        assert "name" in doc.metadata
        assert doc.metadata["name"] == "Minimal Opportunity"
        # Missing fields should not be in metadata
        assert "amount" not in doc.metadata
        assert "probability" not in doc.metadata
        assert "fiscal_quarter" not in doc.metadata
        assert "fiscal_year" not in doc.metadata
        assert "is_closed" not in doc.metadata

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_contact_metadata_fields(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test metadata conversion for Contact object type."""
        parent_id = "003bm00000EjHCjAAN"
        parent_type = "Contact"
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "John Doe",
            MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
            "Account": "Globex Corp",
            "CreatedDate": "2024-01-01T00:00:00.000Z",
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="John Doe",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        doc = batches[0][0]
        assert isinstance(doc, Document)
        # Verify Contact-specific metadata
        assert doc.metadata["object_type"] == "Contact"
        assert doc.metadata["account"] == "Globex Corp"
        assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
        # last_modified_date is derived from MODIFIED_FIELD
        assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_no_default_attributes_for_unknown_type(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that unknown object types only get object_type metadata."""
        parent_id = "001bm00000fd9Z3AAI"
        parent_type = "CustomObject__c"
        parent_data = {
            ID_FIELD: parent_id,
            NAME_FIELD: "Custom Record",
            MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
            "CustomField__c": "custom value",
            "NumberField__c": 123,
        }
        parent_object = self._create_salesforce_object(
            parent_id, parent_type, parent_data
        )
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = parent_object
        mock_sf_db.file_size = 1024
        mock_doc = Document(
            id=f"SALESFORCE_{parent_id}",
            sections=[],
            source=DocumentSource.SALESFORCE,
            semantic_identifier="Custom Record",
            metadata={},
        )
        mock_convert.return_value = mock_doc
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        doc = batches[0][0]
        assert isinstance(doc, Document)
        # Only object_type should be set for unknown types
        assert doc.metadata["object_type"] == "CustomObject__c"
        # Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
        assert "CustomField__c" not in doc.metadata
        assert "NumberField__c" not in doc.metadata

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_skips_missing_parent_objects(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that missing parent objects are skipped gracefully."""
        parent_id = "006bm000006kyDrAAI"
        parent_type = "Opportunity"
        # get_record returns None for missing object
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, parent_id, 1)]
        )
        mock_sf_db.get_record.return_value = None
        mock_sf_db.file_size = 1024
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {parent_id: parent_type}
        parent_types = {parent_type}
        parents_changed = 0

        def increment() -> None:
            nonlocal parents_changed
            parents_changed += 1

        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                increment,
            )
        )
        # Should yield one empty batch
        assert len(batches) == 1
        assert len(batches[0]) == 0
        # convert_sf_object_to_doc should not have been called
        mock_convert.assert_not_called()
        # Parents changed should still be 0
        assert parents_changed == 0

    @patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
    def test_multiple_documents_batching(
        self,
        mock_convert: MagicMock,
        connector: SalesforceConnector,
        mock_sf_db: MagicMock,
    ) -> None:
        """Test that multiple documents are correctly batched."""
        # Create 3 parent objects
        parent_ids = [
            "006bm000006kyDsAAI",
            "006bm000006kyDtAAI",
            "006bm000006kyDuAAI",
        ]
        parent_type = "Opportunity"
        parent_objects = [
            self._create_salesforce_object(
                pid,
                parent_type,
                {
                    ID_FIELD: pid,
                    NAME_FIELD: f"Opportunity {i}",
                    MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
                    "IsClosed": i % 2 == 0,  # alternating bool values
                    "Amount": 1000.0 * (i + 1),
                },
            )
            for i, pid in enumerate(parent_ids)
        ]
        # Setup mock to return all three
        mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
            [(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
        )
        # side_effect list: get_record returns the objects in order, so the
        # order here must match the id order fed to get_changed_parent_ids_by_type.
        mock_sf_db.get_record.side_effect = parent_objects
        mock_sf_db.file_size = 1024
        # Create mock documents
        mock_docs = [
            Document(
                id=f"SALESFORCE_{pid}",
                sections=[],
                source=DocumentSource.SALESFORCE,
                semantic_identifier=f"Opportunity {i}",
                metadata={},
            )
            for i, pid in enumerate(parent_ids)
        ]
        mock_convert.side_effect = mock_docs
        type_to_processed: dict[str, int] = {}
        changed_ids_to_type = {pid: parent_type for pid in parent_ids}
        parent_types = {parent_type}
        batches = list(
            connector._yield_doc_batches(
                mock_sf_db,
                type_to_processed,
                changed_ids_to_type,
                parent_types,
                lambda: None,
            )
        )
        # With batch_size=10, all 3 docs should be in one batch
        assert len(batches) == 1
        assert len(batches[0]) == 3
        # Verify each document has correct metadata
        for i, doc in enumerate(batches[0]):
            assert isinstance(doc, Document)
            assert doc.metadata["object_type"] == "Opportunity"
            assert doc.metadata["is_closed"] == str(i % 2 == 0)
            assert doc.metadata["amount"] == str(1000.0 * (i + 1))
        assert type_to_processed[parent_type] == 3

View File

@@ -0,0 +1,135 @@
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User__UserGroup
from onyx.db.users import delete_user_from_db
def _mock_user(
user_id: UUID | None = None, email: str = "test@example.com"
) -> MagicMock:
user = MagicMock()
user.id = user_id or uuid4()
user.email = email
user.oauth_accounts = []
return user
def _make_query_chain() -> MagicMock:
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
chain = MagicMock()
chain.filter.return_value = chain
return chain
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_nulls_out_document_set_ownership(
    _mock_ee: Any, _mock_remove_invited: Any
) -> None:
    """Deleting a user keeps their document sets and personas, just unowned."""
    user = _mock_user()
    db_session = MagicMock()
    chains: dict[type, MagicMock] = {}
    # Hand out one persistent chain per queried model so calls can be inspected.
    db_session.query.side_effect = lambda model: chains.setdefault(
        model, _make_query_chain()
    )

    delete_user_from_db(user, db_session)

    # DocumentSet ownership is nulled via update(), not removed via delete()
    doc_set_chain = chains[DocumentSet]
    doc_set_chain.filter.assert_called()
    doc_set_chain.filter.return_value.update.assert_called_once_with(
        {DocumentSet.user_id: None}
    )
    # Persona ownership is likewise nulled out rather than deleted
    persona_chain = chains[Persona]
    persona_chain.filter.assert_called()
    persona_chain.filter.return_value.update.assert_called_once_with(
        {Persona.user_id: None}
    )
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_cleans_up_join_tables(
    _mock_ee: Any, _mock_remove_invited: Any
) -> None:
    """Rows linking the user to groups/doc-sets/personas/SAML are hard-deleted."""
    user = _mock_user()
    db_session = MagicMock()
    chains: dict[type, MagicMock] = {}
    db_session.query.side_effect = lambda model: chains.setdefault(
        model, _make_query_chain()
    )

    delete_user_from_db(user, db_session)

    # Join tables should be deleted (not updated)
    for join_model in (DocumentSet__User, Persona__User, User__UserGroup, SamlAccount):
        chains[join_model].filter.return_value.delete.assert_called_once()
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_commits_and_removes_invited(
    _mock_ee: Any, mock_remove_invited: Any
) -> None:
    """The user row is deleted, the session committed, and the invite list pruned."""
    doomed_user = _mock_user(email="deleted@example.com")
    session = MagicMock()
    session.query.return_value = _make_query_chain()

    delete_user_from_db(doomed_user, session)

    session.delete.assert_called_once_with(doomed_user)
    session.commit.assert_called_once()
    # The invite cleanup is keyed off the user's email address
    mock_remove_invited.assert_called_once_with("deleted@example.com")
@patch("onyx.db.users.remove_user_from_invited_users")
@patch(
    "onyx.db.users.fetch_ee_implementation_or_noop",
    return_value=lambda **_kwargs: None,
)
def test_delete_user_deletes_oauth_accounts(
    _mock_ee: Any, _mock_remove_invited: Any
) -> None:
    """Every linked OAuth account is deleted along with the user."""
    user = _mock_user()
    linked_accounts = [MagicMock(), MagicMock()]
    user.oauth_accounts = linked_accounts
    session = MagicMock()
    session.query.return_value = _make_query_chain()

    delete_user_from_db(user, session)

    for account in linked_accounts:
        session.delete.assert_any_call(account)

View File

@@ -0,0 +1,106 @@
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
from onyx.onyxbot.slack.formatting import _sanitize_html
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.utils.text_processing import decode_escapes
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
    """URLs containing parentheses get wrapped in <...> so they parse as one link."""
    raw = (
        "See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
    )
    expected = (
        "See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
    )
    assert _normalize_link_destinations(raw) == expected
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
    """Destinations already wrapped in angle brackets are left untouched."""
    raw = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
    assert _normalize_link_destinations(raw) == raw
def test_normalize_citation_link_handles_multiple_links() -> None:
    """Each citation link in a message is wrapped independently."""
    raw = (
        "[[1]](https://example.com/(USA)%20Guide.pdf) "
        "[[2]](https://example.com/Plan(s)%20Overview.pdf)"
    )
    result = _normalize_link_destinations(raw)
    for wrapped in (
        "[[1]](<https://example.com/(USA)%20Guide.pdf>)",
        "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)",
    ):
        assert wrapped in result
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
    """End-to-end formatting must not split URLs at their internal parentheses."""
    raw = (
        "Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
    )
    rendered = decode_escapes(remove_slack_text_interactions(format_slack_message(raw)))
    assert (
        "<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
        in rendered
    )
    # A truncated link would leave the URL tail dangling after the label.
    assert "|[1]>%20Access%20ID%20Card" not in rendered
def test_slack_style_links_converted_to_clickable_links() -> None:
    """Slack-style <url|label> links survive formatting without HTML-escaping."""
    formatted = format_slack_message(
        "Visit <https://example.com/page|Example Page> for details."
    )
    assert "<https://example.com/page|Example Page>" in formatted
    assert "&lt;" not in formatted
def test_slack_style_links_preserved_inside_code_blocks() -> None:
    """Link conversion leaves fenced code blocks alone."""
    converted = _convert_slack_links_to_markdown(
        "```\n<https://example.com|click>\n```"
    )
    assert "<https://example.com|click>" in converted
def test_html_tags_stripped_outside_code_blocks() -> None:
    """HTML is sanitized in prose but preserved verbatim inside code fences."""
    sanitized = _transform_outside_code_blocks(
        "Hello<br/>world ```<div>code</div>``` after", _sanitize_html
    )
    assert "<br" not in sanitized
    assert "<div>code</div>" in sanitized
def test_format_slack_message_block_spacing() -> None:
    """Paragraph breaks (blank lines) pass through formatting unchanged."""
    raw = "Paragraph one.\n\nParagraph two."
    assert format_slack_message(raw) == raw
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
    """Formatting must not append a blank line after a closing code fence."""
    formatted = format_slack_message("```python\nprint('hi')\n```")
    assert formatted.endswith("print('hi')\n```")
def test_format_slack_message_ampersand_not_double_escaped() -> None:
    """Ampersands are escaped exactly once; quotes are not escaped at all."""
    formatted = format_slack_message('She said "hello" & goodbye.')
    assert "&amp;" in formatted
    assert "&quot;" not in formatted

View File

@@ -0,0 +1,57 @@
from typing import Any
from unittest.mock import Mock
from onyx.configs.constants import MilestoneRecordType
from onyx.utils import telemetry as telemetry_utils
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
    """Outside multi-tenant mode, mt_cloud_telemetry must not even look up the impl."""
    fetch_mock = Mock()
    monkeypatch.setattr(
        telemetry_utils,
        "fetch_versioned_implementation_with_fallback",
        fetch_mock,
    )
    # mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
    monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)

    telemetry_utils.mt_cloud_telemetry(
        tenant_id="tenant-1",
        distinct_id="user@example.com",
        event=MilestoneRecordType.USER_MESSAGE_SENT,
        properties={"origin": "web"},
    )

    fetch_mock.assert_not_called()
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
    monkeypatch: Any,
) -> None:
    """In multi-tenant mode the versioned event_telemetry impl is fetched and called."""
    event_telemetry_mock = Mock()
    fetch_mock = Mock(return_value=event_telemetry_mock)
    monkeypatch.setattr(
        telemetry_utils,
        "fetch_versioned_implementation_with_fallback",
        fetch_mock,
    )
    # mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
    monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)

    telemetry_utils.mt_cloud_telemetry(
        tenant_id="tenant-1",
        distinct_id="user@example.com",
        event=MilestoneRecordType.USER_MESSAGE_SENT,
        properties={"origin": "web"},
    )

    fetch_mock.assert_called_once_with(
        module="onyx.utils.telemetry",
        attribute="event_telemetry",
        fallback=telemetry_utils.noop_fallback,
    )
    # tenant_id is folded into the forwarded properties dict.
    event_telemetry_mock.assert_called_once_with(
        "user@example.com",
        MilestoneRecordType.USER_MESSAGE_SENT,
        {"origin": "web", "tenant_id": "tenant-1"},
    )

View File

@@ -221,6 +221,13 @@ services:
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- ONYX_BOT_MAX_QPM=${ONYX_BOT_MAX_QPM:-}
- ONYX_BOT_MAX_WAIT_TIME=${ONYX_BOT_MAX_WAIT_TIME:-}
# Discord Bot Configuration (runs via supervisord, requires DISCORD_BOT_TOKEN to be set)
# IMPORTANT: Only one Discord bot instance can run per token - do not scale background workers
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
# Logging
# Leave this on pretty please? Nothing sensitive is collected!
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}

View File

@@ -63,6 +63,11 @@ services:
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
env_file:
- path: .env
required: false

View File

@@ -82,6 +82,11 @@ services:
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
env_file:
- path: .env
required: false

View File

@@ -129,6 +129,11 @@ services:
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
# PRODUCTION: Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro

View File

@@ -77,6 +77,13 @@ MINIO_ROOT_PASSWORD=minioadmin
## CORS origins for MCP clients (comma-separated list)
# MCP_SERVER_CORS_ORIGINS=
## Discord Bot Configuration
## The Discord bot allows users to interact with Onyx from Discord servers
## Bot token from Discord Developer Portal (required to enable the bot)
# DISCORD_BOT_TOKEN=
## Command prefix for bot commands (default: "!")
# DISCORD_BOT_INVOKE_CHAR=!
## Celery Configuration
# CELERY_BROKER_POOL_LIMIT=
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=

View File

@@ -582,29 +582,33 @@ else
fi
# Ask for authentication schema
echo ""
print_info "Which authentication schema would you like to set up?"
echo ""
echo "1) Basic - Username/password authentication"
echo "2) No Auth - Open access (development/testing)"
echo ""
read -p "Choose an option (1-2) [default 1]: " -r AUTH_CHOICE
echo ""
# echo ""
# print_info "Which authentication schema would you like to set up?"
# echo ""
# echo "1) Basic - Username/password authentication"
# echo "2) No Auth - Open access (development/testing)"
# echo ""
# read -p "Choose an option (1) [default 1]: " -r AUTH_CHOICE
# echo ""
case "${AUTH_CHOICE:-1}" in
1)
AUTH_SCHEMA="basic"
print_info "Selected: Basic authentication"
;;
2)
AUTH_SCHEMA="disabled"
print_info "Selected: No authentication"
;;
*)
AUTH_SCHEMA="basic"
print_info "Invalid choice, using basic authentication"
;;
esac
# case "${AUTH_CHOICE:-1}" in
# 1)
# AUTH_SCHEMA="basic"
# print_info "Selected: Basic authentication"
# ;;
# # 2)
# # AUTH_SCHEMA="disabled"
# # print_info "Selected: No authentication"
# # ;;
# *)
# AUTH_SCHEMA="basic"
# print_info "Invalid choice, using basic authentication"
# ;;
# esac
# TODO (jessica): Uncomment this once no auth users still have an account
# Use basic auth by default
AUTH_SCHEMA="basic"
# Create .env file from template
print_info "Creating .env file with your selections..."

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.19
version: 0.4.20
appVersion: latest
annotations:
category: Productivity

View File

@@ -0,0 +1,98 @@
{{- if .Values.discordbot.enabled }}
# Discord bot MUST run as a single replica - Discord only allows one client connection per bot token.
# Do NOT enable HPA or increase replicas. Message processing is offloaded to scalable API pods via HTTP.
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "onyx.fullname" . }}-discordbot
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.discordbot.deploymentLabels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
# CRITICAL: Discord bots cannot be horizontally scaled - only one WebSocket connection per token is allowed
replicas: 1
strategy:
type: Recreate # Ensure old pod is terminated before new one starts to avoid duplicate connections
selector:
matchLabels:
{{- include "onyx.selectorLabels" . | nindent 6 }}
{{- if .Values.discordbot.deploymentLabels }}
{{- toYaml .Values.discordbot.deploymentLabels | nindent 6 }}
{{- end }}
template:
metadata:
annotations:
checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
{{- with .Values.discordbot.podAnnotations }}
{{- toYaml . | nindent 8 }}
{{- end }}
labels:
{{- include "onyx.labels" . | nindent 8 }}
{{- with .Values.discordbot.deploymentLabels }}
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.discordbot.podLabels }}
{{- toYaml . | nindent 8 }}
{{- end }}
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "onyx.serviceAccountName" . }}
securityContext:
{{- toYaml .Values.discordbot.podSecurityContext | nindent 8 }}
{{- with .Values.discordbot.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.discordbot.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.discordbot.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: discordbot
securityContext:
{{- toYaml .Values.discordbot.securityContext | nindent 12 }}
image: "{{ .Values.discordbot.image.repository }}:{{ .Values.discordbot.image.tag | default .Values.global.version }}"
imagePullPolicy: {{ .Values.global.pullPolicy }}
command: ["python", "onyx/onyxbot/discord/client.py"]
resources:
{{- toYaml .Values.discordbot.resources | nindent 12 }}
envFrom:
- configMapRef:
name: {{ .Values.config.envConfigMapName }}
env:
{{- include "onyx.envSecrets" . | nindent 12}}
# Discord bot token - required for bot to connect
{{- if .Values.discordbot.botToken }}
- name: DISCORD_BOT_TOKEN
value: {{ .Values.discordbot.botToken | quote }}
{{- end }}
{{- if .Values.discordbot.botTokenSecretName }}
- name: DISCORD_BOT_TOKEN
valueFrom:
secretKeyRef:
name: {{ .Values.discordbot.botTokenSecretName }}
key: {{ .Values.discordbot.botTokenSecretKey | default "token" }}
{{- end }}
# Command prefix for bot commands (default: "!")
{{- if .Values.discordbot.invokeChar }}
- name: DISCORD_BOT_INVOKE_CHAR
value: {{ .Values.discordbot.invokeChar | quote }}
{{- end }}
{{- with .Values.discordbot.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.discordbot.volumes }}
volumes:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- end }}

View File

@@ -701,6 +701,44 @@ celery_worker_user_file_processing:
tolerations: []
affinity: {}
# Discord bot for Onyx
# The bot offloads message processing to scalable API pods via HTTP requests.
discordbot:
enabled: false # Disabled by default - requires bot token configuration
# Bot token can be provided directly or via a Kubernetes secret
# Option 1: Direct token (not recommended for production)
botToken: ""
# Option 2: Reference a Kubernetes secret (recommended)
botTokenSecretName: "" # Name of the secret containing the bot token
botTokenSecretKey: "token" # Key within the secret (default: "token")
# Command prefix for bot commands (default: "!")
invokeChar: "!"
image:
repository: onyxdotapp/onyx-backend
tag: "" # Overrides the image tag whose default is the chart appVersion.
podAnnotations: {}
podLabels:
scope: onyx-backend
app: discord-bot
deploymentLabels:
app: discord-bot
podSecurityContext:
{}
securityContext:
{}
resources:
requests:
cpu: "500m"
memory: "512Mi"
limits:
cpu: "1000m"
memory: "2000Mi"
volumes: []
volumeMounts: []
nodeSelector: {}
tolerations: []
affinity: {}
slackbot:
enabled: true
replicaCount: 1
@@ -1159,6 +1197,8 @@ configMap:
ONYX_BOT_DISPLAY_ERROR_MSGS: ""
ONYX_BOT_RESPOND_EVERY_CHANNEL: ""
NOTIFY_SLACKBOT_NO_ANSWER: ""
DISCORD_BOT_TOKEN: ""
DISCORD_BOT_INVOKE_CHAR: ""
# Logging
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
DISABLE_TELEMETRY: ""

3
desktop/.gitignore vendored
View File

@@ -22,3 +22,6 @@ npm-debug.log*
# Local env files
.env
.env.local
# Generated files
src-tauri/gen/schemas/acl-manifests.json

View File

@@ -706,16 +706,6 @@ dependencies = [
"typeid",
]
[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
]
[[package]]
name = "fdeflate"
version = "0.3.7"
@@ -993,16 +983,6 @@ dependencies = [
"version_check",
]
[[package]]
name = "gethostname"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
dependencies = [
"rustix",
"windows-link 0.2.1",
]
[[package]]
name = "getrandom"
version = "0.1.16"
@@ -1122,24 +1102,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]]
name = "global-hotkey"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
dependencies = [
"crossbeam-channel",
"keyboard-types",
"objc2 0.6.3",
"objc2-app-kit 0.3.2",
"once_cell",
"serde",
"thiserror 2.0.17",
"windows-sys 0.59.0",
"x11rb",
"xkeysym",
]
[[package]]
name = "gobject-sys"
version = "0.18.0"
@@ -1713,12 +1675,6 @@ dependencies = [
"libc",
]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "litemap"
version = "0.8.1"
@@ -2248,7 +2204,6 @@ dependencies = [
"serde_json",
"tauri",
"tauri-build",
"tauri-plugin-global-shortcut",
"tauri-plugin-shell",
"tauri-plugin-window-state",
"tokio",
@@ -2878,19 +2833,6 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustversion"
version = "1.0.22"
@@ -3605,21 +3547,6 @@ dependencies = [
"walkdir",
]
[[package]]
name = "tauri-plugin-global-shortcut"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
dependencies = [
"global-hotkey",
"log",
"serde",
"serde_json",
"tauri",
"tauri-plugin",
"thiserror 2.0.17",
]
[[package]]
name = "tauri-plugin-shell"
version = "2.3.3"
@@ -5021,29 +4948,6 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "x11rb"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
dependencies = [
"gethostname",
"rustix",
"x11rb-protocol",
]
[[package]]
name = "x11rb-protocol"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
[[package]]
name = "xkeysym"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
[[package]]
name = "yoke"
version = "0.8.1"

View File

@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
[dependencies]
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
tauri-plugin-shell = "2.0"
tauri-plugin-global-shortcut = "2.0"
tauri-plugin-window-state = "2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

File diff suppressed because one or more lines are too long

View File

@@ -2354,72 +2354,6 @@
"const": "core:window:deny-unminimize",
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
},
{
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
"type": "string",
"const": "global-shortcut:default",
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
},
{
"description": "Enables the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-is-registered",
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
},
{
"description": "Enables the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register",
"markdownDescription": "Enables the register command without any pre-configured scope."
},
{
"description": "Enables the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register-all",
"markdownDescription": "Enables the register_all command without any pre-configured scope."
},
{
"description": "Enables the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister",
"markdownDescription": "Enables the unregister command without any pre-configured scope."
},
{
"description": "Enables the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister-all",
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
},
{
"description": "Denies the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-is-registered",
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
},
{
"description": "Denies the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register",
"markdownDescription": "Denies the register command without any pre-configured scope."
},
{
"description": "Denies the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register-all",
"markdownDescription": "Denies the register_all command without any pre-configured scope."
},
{
"description": "Denies the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister",
"markdownDescription": "Denies the unregister command without any pre-configured scope."
},
{
"description": "Denies the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister-all",
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
},
{
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
"type": "string",

View File

@@ -2354,72 +2354,6 @@
"const": "core:window:deny-unminimize",
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
},
{
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
"type": "string",
"const": "global-shortcut:default",
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
},
{
"description": "Enables the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-is-registered",
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
},
{
"description": "Enables the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register",
"markdownDescription": "Enables the register command without any pre-configured scope."
},
{
"description": "Enables the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-register-all",
"markdownDescription": "Enables the register_all command without any pre-configured scope."
},
{
"description": "Enables the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister",
"markdownDescription": "Enables the unregister command without any pre-configured scope."
},
{
"description": "Enables the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:allow-unregister-all",
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
},
{
"description": "Denies the is_registered command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-is-registered",
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
},
{
"description": "Denies the register command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register",
"markdownDescription": "Denies the register command without any pre-configured scope."
},
{
"description": "Denies the register_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-register-all",
"markdownDescription": "Denies the register_all command without any pre-configured scope."
},
{
"description": "Denies the unregister command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister",
"markdownDescription": "Denies the unregister command without any pre-configured scope."
},
{
"description": "Denies the unregister_all command without any pre-configured scope.",
"type": "string",
"const": "global-shortcut:deny-unregister-all",
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
},
{
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
"type": "string",

View File

@@ -20,7 +20,6 @@ use tauri::Wry;
use tauri::{
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
};
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
use url::Url;
#[cfg(target_os = "macos")]
use tokio::time::sleep;
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
window.start_dragging().map_err(|e| e.to_string())
}
// ============================================================================
// Shortcuts Setup
// ============================================================================
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
let app_handle = app.clone();
// Avoid hijacking the system-wide Cmd+R on macOS.
#[cfg(target_os = "macos")]
let shortcuts = [
new_chat,
back,
forward,
new_window_shortcut,
show_app,
open_settings_shortcut,
];
#[cfg(not(target_os = "macos"))]
let shortcuts = [
new_chat,
reload,
back,
forward,
new_window_shortcut,
show_app,
open_settings_shortcut,
];
app.global_shortcut().on_shortcuts(
shortcuts,
move |_app, shortcut, _event| {
if shortcut == &new_chat {
trigger_new_chat(&app_handle);
}
if let Some(window) = app_handle.get_webview_window("main") {
if shortcut == &reload {
let _ = window.eval("window.location.reload()");
} else if shortcut == &back {
let _ = window.eval("window.history.back()");
} else if shortcut == &forward {
let _ = window.eval("window.history.forward()");
} else if shortcut == &open_settings_shortcut {
open_settings(&app_handle);
}
}
if shortcut == &new_window_shortcut {
trigger_new_window(&app_handle);
} else if shortcut == &show_app {
focus_main_window(&app_handle);
}
},
)?;
Ok(())
}
// ============================================================================
// Menu Setup
// ============================================================================
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
TRAY_MENU_OPEN_APP_ID,
"Open Onyx",
true,
Some("CmdOrCtrl+Shift+Space"),
None::<&str>,
)?;
let open_chat = MenuItem::with_id(
app,
@@ -666,7 +598,6 @@ fn main() {
tauri::Builder::default()
.plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
.plugin(tauri_plugin_window_state::Builder::default().build())
.manage(ConfigState {
config: RwLock::new(config),
@@ -698,11 +629,6 @@ fn main() {
.setup(move |app| {
let app_handle = app.handle();
// Setup global shortcuts
if let Err(e) = setup_shortcuts(&app_handle) {
eprintln!("Failed to setup shortcuts: {}", e);
}
if let Err(e) = setup_app_menu(&app_handle) {
eprintln!("Failed to setup menu: {}", e);
}

View File

@@ -22,6 +22,17 @@
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
}
.dark {
--background-900: #1a1a1a;
--background-800: #262626;
--text-light-05: rgba(255, 255, 255, 0.95);
--text-light-03: rgba(255, 255, 255, 0.6);
--white-10: rgba(255, 255, 255, 0.08);
--white-15: rgba(255, 255, 255, 0.12);
--white-20: rgba(255, 255, 255, 0.15);
--white-30: rgba(255, 255, 255, 0.25);
}
* {
box-sizing: border-box;
margin: 0;
@@ -30,7 +41,11 @@
body {
font-family: var(--font-hanken-grotesk);
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
background: linear-gradient(
135deg,
var(--background-900) 0%,
var(--background-800) 100%
);
min-height: 100vh;
color: var(--text-light-05);
display: flex;
@@ -39,6 +54,9 @@
padding: 20px;
-webkit-user-select: none;
user-select: none;
transition:
background 0.3s ease,
color 0.3s ease;
}
.titlebar {
@@ -69,16 +87,19 @@
}
.settings-panel {
background: linear-gradient(
to bottom,
rgba(255, 255, 255, 0.95),
rgba(245, 245, 245, 0.95)
);
background: var(--background-800);
backdrop-filter: blur(24px);
border-radius: 16px;
border: 1px solid var(--white-10);
overflow: hidden;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
transition:
background 0.3s ease,
border 0.3s ease;
}
.dark .settings-panel {
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
}
.settings-header {
@@ -93,17 +114,19 @@
width: 40px;
height: 40px;
border-radius: 12px;
background: white;
background: var(--background-900);
display: flex;
align-items: center;
justify-content: center;
overflow: hidden;
transition: background 0.3s ease;
}
.settings-icon svg {
width: 24px;
height: 24px;
color: #000;
color: var(--text-light-05);
transition: color 0.3s ease;
}
.settings-title {
@@ -134,9 +157,10 @@
}
.settings-group {
background: rgba(0, 0, 0, 0.03);
background: var(--background-900);
border-radius: 16px;
padding: 4px;
transition: background 0.3s ease;
}
.setting-row {
@@ -176,7 +200,7 @@
border: 1px solid var(--white-10);
border-radius: 8px;
font-size: 14px;
background: rgba(0, 0, 0, 0.05);
background: var(--background-800);
color: var(--text-light-05);
font-family: var(--font-hanken-grotesk);
transition: all 0.2s;
@@ -186,8 +210,8 @@
.input-field:focus {
outline: none;
border-color: var(--white-30);
background: rgba(0, 0, 0, 0.08);
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
background: var(--background-900);
box-shadow: 0 0 0 2px var(--white-10);
}
.input-field::placeholder {
@@ -231,7 +255,7 @@
left: 0;
right: 0;
bottom: 0;
background-color: rgba(0, 0, 0, 0.15);
background-color: var(--white-15);
transition: 0.3s;
border-radius: 24px;
}
@@ -243,14 +267,18 @@
width: 18px;
left: 3px;
bottom: 3px;
background-color: white;
background-color: var(--background-800);
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
transition: 0.3s;
border-radius: 50%;
}
.dark .toggle-slider:before {
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
}
input:checked + .toggle-slider {
background-color: rgba(0, 0, 0, 0.3);
background-color: var(--white-30);
}
input:checked + .toggle-slider:before {
@@ -288,14 +316,15 @@
}
kbd {
background: rgba(0, 0, 0, 0.1);
border: 1px solid var(--white-10);
background: var(--white-10);
border: 1px solid var(--white-15);
border-radius: 4px;
padding: 2px 6px;
font-family: monospace;
font-weight: 500;
color: var(--text-light-05);
font-size: 11px;
transition: all 0.3s ease;
}
</style>
</head>
@@ -372,10 +401,34 @@
const errorMessage = document.getElementById("errorMessage");
const saveBtn = document.getElementById("saveBtn");
// Theme detection based on system preferences
function applySystemTheme() {
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
function updateTheme(e) {
if (e.matches) {
document.documentElement.classList.add("dark");
document.body.classList.add("dark");
} else {
document.documentElement.classList.remove("dark");
document.body.classList.remove("dark");
}
}
// Apply initial theme
updateTheme(darkModeQuery);
// Listen for changes
darkModeQuery.addEventListener("change", updateTheme);
}
function showSettings() {
document.body.classList.add("show-settings");
}
// Apply system theme immediately
applySystemTheme();
// Initialize the app
async function init() {
try {

View File

@@ -113,6 +113,23 @@
document.head.appendChild(style);
}
function updateTitleBarTheme(isDark) {
const titleBar = document.getElementById(TITLEBAR_ID);
if (!titleBar) return;
if (isDark) {
titleBar.style.background =
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
} else {
titleBar.style.background =
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
}
}
function buildTitleBar() {
const titleBar = document.createElement("div");
titleBar.id = TITLEBAR_ID;
@@ -134,6 +151,11 @@
}
});
// Apply initial styles matching current theme
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
// Apply styles matching Onyx design system with translucent glass effect
titleBar.style.cssText = `
position: fixed;
@@ -156,8 +178,12 @@
-webkit-backdrop-filter: blur(18px) saturate(180%);
-webkit-app-region: drag;
padding: 0 12px;
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
`;
// Apply correct theme
updateTitleBarTheme(isDark);
return titleBar;
}
@@ -168,6 +194,11 @@
const existing = document.getElementById(TITLEBAR_ID);
if (existing?.parentElement === document.body) {
// Update theme on existing titlebar
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
updateTitleBarTheme(isDark);
return;
}
@@ -178,6 +209,14 @@
const titleBar = buildTitleBar();
document.body.insertBefore(titleBar, document.body.firstChild);
injectStyles();
// Ensure theme is applied immediately after mount
setTimeout(() => {
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
updateTitleBarTheme(isDark);
}, 0);
}
function syncViewportHeight() {
@@ -194,9 +233,66 @@
}
}
function observeThemeChanges() {
let lastKnownTheme = null;
function checkAndUpdateTheme() {
// Check both html and body for dark class (some apps use body)
const htmlHasDark = document.documentElement.classList.contains("dark");
const bodyHasDark = document.body?.classList.contains("dark");
const isDark = htmlHasDark || bodyHasDark;
if (lastKnownTheme !== isDark) {
lastKnownTheme = isDark;
updateTitleBarTheme(isDark);
}
}
// Immediate check on setup
checkAndUpdateTheme();
// Watch for theme changes on the HTML element
const themeObserver = new MutationObserver(() => {
checkAndUpdateTheme();
});
themeObserver.observe(document.documentElement, {
attributes: true,
attributeFilter: ["class"],
});
// Also observe body if it exists
if (document.body) {
const bodyObserver = new MutationObserver(() => {
checkAndUpdateTheme();
});
bodyObserver.observe(document.body, {
attributes: true,
attributeFilter: ["class"],
});
}
// Also check periodically in case classList is manipulated directly
// or the theme loads asynchronously after page load
const intervalId = setInterval(() => {
checkAndUpdateTheme();
}, 300);
// Clean up after 30 seconds once theme should be stable
setTimeout(() => {
clearInterval(intervalId);
// But keep checking every 2 seconds for manual theme changes
setInterval(() => {
checkAndUpdateTheme();
}, 2000);
}, 30000);
}
function init() {
mountTitleBar();
syncViewportHeight();
observeThemeChanges();
window.addEventListener("resize", syncViewportHeight, { passive: true });
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
passive: true,

View File

@@ -119,7 +119,7 @@ backend = [
"shapely==2.0.6",
"stripe==10.12.0",
"urllib3==2.6.3",
"mistune==0.8.4",
"mistune==3.2.0",
"sendgrid==6.12.5",
"exa_py==1.15.4",
"braintrust==0.3.9",
@@ -142,7 +142,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.4.0",
"onyx-devtools==0.6.2",
"openapi-generator-cli==7.17.0",
"pandas-stubs~=2.3.3",
"pre-commit==3.2.2",

26
uv.lock generated
View File

@@ -3897,11 +3897,11 @@ wheels = [
[[package]]
name = "mistune"
version = "0.8.4"
version = "3.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2d/a4/509f6e7783ddd35482feda27bc7f72e65b5e7dc910eca4ab2164daf9c577/mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", size = 58322, upload-time = "2018-10-11T06:59:27.908Z" }
sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/09/ec/4b43dae793655b7d8a25f76119624350b4d65eb663459eb9603d7f1f0345/mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4", size = 16220, upload-time = "2018-10-11T06:59:26.044Z" },
{ url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
]
[[package]]
@@ -4766,7 +4766,7 @@ requires-dist = [
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.25.0" },
{ name = "mistune", marker = "extra == 'backend'", specifier = "==0.8.4" },
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
@@ -4775,7 +4775,7 @@ requires-dist = [
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.5.9" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.4.0" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.2" },
{ name = "openai", specifier = "==2.14.0" },
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
@@ -4878,20 +4878,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
[[package]]
name = "onyx-devtools"
version = "0.4.0"
version = "0.6.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "fastapi" },
{ name = "openapi-generator-cli" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/d8/f68d15c12d27d4525d10697ac7e2d67d6122fb59ccab219afb2973bc33ad/onyx_devtools-0.4.0-py3-none-any.whl", hash = "sha256:3eb821bce7ec8651d57e937d4d8483e1c2c4bc51df8cbab2dbcc05e3740ec96c", size = 2870841, upload-time = "2026-01-23T04:44:32.206Z" },
{ url = "https://files.pythonhosted.org/packages/28/04/6376342389494b51fd89e554dfdaf0d3809b8d1473bc9b72abd2d7dba21e/onyx_devtools-0.4.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:144e518abad3031ffef189445a69356fca1da2a4fb40c7b8431550133bfc4eef", size = 2890308, upload-time = "2026-01-23T04:44:37.674Z" },
{ url = "https://files.pythonhosted.org/packages/cc/c1/859b32fb3eff7e67179d971ace36313ae64e7fc9a242b45e606138b0041f/onyx_devtools-0.4.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0cc74d561f08a9c894adf8de79855b4fc72eb70e823a75e29db7f625ad366bd7", size = 2696160, upload-time = "2026-01-23T04:44:30.647Z" },
{ url = "https://files.pythonhosted.org/packages/59/1b/f1e3f574e9917779d22e3fcb28f8ac1888c250e7452a523f64a6ab8a1759/onyx_devtools-0.4.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:d69de76a97d7f9ff8c473afffbf544a65265645d726f3d70cc12dbbd7e364222", size = 2602134, upload-time = "2026-01-23T04:44:31.716Z" },
{ url = "https://files.pythonhosted.org/packages/79/4a/a5d11640fdc23c9bf0e8617ce13793a587e49a64be2d20badf7e9b045e0a/onyx_devtools-0.4.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:fa84980ce8830e35432831aadc19ff465dbc723605aa80c50e0debc58457b70f", size = 2870864, upload-time = "2026-01-23T04:44:31.5Z" },
{ url = "https://files.pythonhosted.org/packages/fc/9f/6a7e02fbf47bcaea4d02b0ed92bea6e2c09408be7654fb3b57a1ba9863f2/onyx_devtools-0.4.0-py3-none-win_amd64.whl", hash = "sha256:8451efe3e137157696decf8b60a19fb3f0c52ae9f2d9b7c5bc6e667900e7c61e", size = 2953545, upload-time = "2026-01-23T04:44:38.11Z" },
{ url = "https://files.pythonhosted.org/packages/52/42/f7a5b99ade06d215fb99de41181d51a9a984f83afb15afa15ce79ecab635/onyx_devtools-0.4.0-py3-none-win_arm64.whl", hash = "sha256:53a5942c922d7049650e934c43f9c057d046f8d53bc68935ebf7e93baa29afc3", size = 2665984, upload-time = "2026-01-23T04:44:29.399Z" },
{ url = "https://files.pythonhosted.org/packages/cc/20/d9f6089616044b0fb6e097cbae82122de24f3acd97820be4868d5c28ee3f/onyx_devtools-0.6.2-py3-none-any.whl", hash = "sha256:e48d14695d39d62ec3247a4c76ea56604bc5fb635af84c4ff3e9628bcc67b4fb", size = 3785941, upload-time = "2026-02-25T22:33:43.585Z" },
{ url = "https://files.pythonhosted.org/packages/d6/f5/f754a717f6b011050eb52ef09895cfa2f048f567f4aa3d5e0f773657dea4/onyx_devtools-0.6.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:505f9910a04868ab62d99bb483dc37c9f4ad94fa80e6ac0e6a10b86351c31420", size = 3832182, upload-time = "2026-02-25T22:33:43.283Z" },
{ url = "https://files.pythonhosted.org/packages/6a/35/6e653398c62078e87ebb0d03dc944df6691d92ca427c92867309d2d803b7/onyx_devtools-0.6.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:edec98e3acc0fa22cf9102c2070409ea7bcf99d7ded72bd8cb184ece8171c36a", size = 3576948, upload-time = "2026-02-25T22:33:42.962Z" },
{ url = "https://files.pythonhosted.org/packages/3c/97/cff707c5c3d2acd714365b1023f0100676abc99816a29558319e8ef01d5f/onyx_devtools-0.6.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97abab61216866cdccd8c0a7e27af328776083756ce4fb57c4bd723030449e3b", size = 3439359, upload-time = "2026-02-25T22:33:44.684Z" },
{ url = "https://files.pythonhosted.org/packages/fc/98/3b768d18e5599178834b966b447075626d224e048d6eb264d89d19abacb4/onyx_devtools-0.6.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:681b038ab6f1457409d14b2490782c7a8014fc0f0f1b9cd69bb2b7199f99aef1", size = 3785959, upload-time = "2026-02-25T22:33:44.342Z" },
{ url = "https://files.pythonhosted.org/packages/d6/38/9b047f9e61c14ccf22b8f386c7a57da3965f90737453f3a577a97da45cdf/onyx_devtools-0.6.2-py3-none-win_amd64.whl", hash = "sha256:a2063be6be104b50a7538cf0d26c7f7ab9159d53327dd6f3e91db05d793c95f3", size = 3878776, upload-time = "2026-02-25T22:33:45.229Z" },
{ url = "https://files.pythonhosted.org/packages/9d/0f/742f644bae84f5f8f7b500094a2f58da3ff8027fc739944622577e2e2850/onyx_devtools-0.6.2-py3-none-win_arm64.whl", hash = "sha256:00fb90a49a15c932b5cacf818b1b4918e5b5c574bde243dc1828b57690dd5046", size = 3501112, upload-time = "2026-02-25T22:33:41.512Z" },
]
[[package]]

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowDownDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowDownDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 9 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowLeftDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowLeftDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 14 9"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowRightDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowRightDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 14 9"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgArrowUpDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgArrowUpDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 9 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,7 +1,9 @@
import type { SVGProps } from "react";
import type { IconProps } from "@opal/types";
const SvgBracketCurly = (props: SVGProps<SVGSVGElement>) => (
const SvgBracketCurly = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 15 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgBranch = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M4.75001 5C5.71651 5 6.50001 4.2165 6.50001 3.25C6.50001 2.2835 5.7165 1.5 4.75 1.5C3.78351 1.5 3.00001 2.2835 3.00001 3.25C3.00001 4.2165 3.78351 5 4.75001 5ZM4.75001 5L4.75001 6.24999M4.75 11C3.7835 11 3 11.7835 3 12.75C3 13.7165 3.7835 14.5 4.75 14.5C5.7165 14.5 6.5 13.7165 6.5 12.75C6.5 11.7835 5.71649 11 4.75 11ZM4.75 11L4.75001 6.24999M10.5 8.74997C10.5 9.71646 11.2835 10.5 12.25 10.5C13.2165 10.5 14 9.71646 14 8.74997C14 7.78347 13.2165 7 12.25 7C11.2835 7 10.5 7.78347 10.5 8.74997ZM10.5 8.74997L7.25001 8.74999C5.8693 8.74999 4.75001 7.6307 4.75001 6.24999"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgBranch;

View File

@@ -0,0 +1,16 @@
import type { IconProps } from "@opal/types";
const SvgCircle = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<circle cx="8" cy="8" r="6" strokeWidth={1.5} />
</svg>
);
export default SvgCircle;

View File

@@ -1,10 +1,12 @@
import React from "react";
import type { IconProps } from "@opal/types";
const SvgClaude = (props: IconProps) => {
const SvgClaude = ({ size, ...props }: IconProps) => {
const clipId = React.useId();
return (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgClipboard = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgClipboard = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgCornerRightUpDot = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgCornerRightUpDot = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 9 14"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";

/**
 * Download icon (16x16 viewBox): arrow-into-tray stroked in
 * currentColor. `size` controls width and height; all other props
 * spread onto the <svg> root for caller overrides.
 */
const SvgDownload = (props: IconProps) => {
  const { size, ...rest } = props;
  return (
    <svg
      width={size}
      height={size}
      viewBox="0 0 16 16"
      fill="none"
      xmlns="http://www.w3.org/2000/svg"
      stroke="currentColor"
      {...rest}
    >
      {/* Path data is generated from the design asset — do not hand-edit. */}
      <path
        d="M14 10V12.6667C14 13.3929 13.3929 14 12.6667 14H3.33333C2.60711 14 2 13.3929 2 12.6667V10M4.66667 6.66667L8 10M8 10L11.3333 6.66667M8 10L8 2"
        strokeWidth={1.5}
        strokeLinecap="round"
        strokeLinejoin="round"
      />
    </svg>
  );
};

export default SvgDownload;

View File

@@ -24,6 +24,7 @@ export { default as SvgBookOpen } from "@opal/icons/book-open";
export { default as SvgBooksLineSmall } from "@opal/icons/books-line-small";
export { default as SvgBooksStackSmall } from "@opal/icons/books-stack-small";
export { default as SvgBracketCurly } from "@opal/icons/bracket-curly";
export { default as SvgBranch } from "@opal/icons/branch";
export { default as SvgBubbleText } from "@opal/icons/bubble-text";
export { default as SvgCalendar } from "@opal/icons/calendar";
export { default as SvgCheck } from "@opal/icons/check";
@@ -36,6 +37,7 @@ export { default as SvgChevronLeft } from "@opal/icons/chevron-left";
export { default as SvgChevronRight } from "@opal/icons/chevron-right";
export { default as SvgChevronUp } from "@opal/icons/chevron-up";
export { default as SvgChevronUpSmall } from "@opal/icons/chevron-up-small";
export { default as SvgCircle } from "@opal/icons/circle";
export { default as SvgClaude } from "@opal/icons/claude";
export { default as SvgClipboard } from "@opal/icons/clipboard";
export { default as SvgClock } from "@opal/icons/clock";
@@ -46,6 +48,7 @@ export { default as SvgCopy } from "@opal/icons/copy";
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
export { default as SvgCpu } from "@opal/icons/cpu";
export { default as SvgDevKit } from "@opal/icons/dev-kit";
export { default as SvgDownload } from "@opal/icons/download";
export { default as SvgDiscordMono } from "@opal/icons/DiscordMono";
export { default as SvgDownloadCloud } from "@opal/icons/download-cloud";
export { default as SvgEdit } from "@opal/icons/edit";
@@ -135,6 +138,7 @@ export { default as SvgStep3End } from "@opal/icons/step3-end";
export { default as SvgStop } from "@opal/icons/stop";
export { default as SvgStopCircle } from "@opal/icons/stop-circle";
export { default as SvgSun } from "@opal/icons/sun";
export { default as SvgTerminal } from "@opal/icons/terminal";
export { default as SvgTerminalSmall } from "@opal/icons/terminal-small";
export { default as SvgTextLinesSmall } from "@opal/icons/text-lines-small";
export { default as SvgThumbsDown } from "@opal/icons/thumbs-down";

View File

@@ -1,17 +1,11 @@
import type { IconProps } from "@opal/types";
const OnyxLogo = ({
width = 24,
height = 24,
className,
...props
}: IconProps) => (
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
<svg
width={width}
height={height}
width={size}
height={size}
viewBox="0 0 56 56"
xmlns="http://www.w3.org/2000/svg"
className={className}
stroke="currentColor"
{...props}
>
@@ -23,4 +17,4 @@ const OnyxLogo = ({
/>
</svg>
);
export default OnyxLogo;
export default SvgOnyxLogo;

View File

@@ -1,10 +1,12 @@
import React from "react";
import type { IconProps } from "@opal/types";
const SvgOpenAI = (props: IconProps) => {
const SvgOpenAI = ({ size, ...props }: IconProps) => {
const clipId = React.useId();
return (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -0,0 +1,22 @@
import type { IconProps } from "@opal/types";

/**
 * Terminal icon (16x16 viewBox): chevron prompt plus underscore,
 * stroked in currentColor. `size` maps to width/height; remaining
 * props are forwarded to the root <svg>.
 */
const SvgTerminal = (props: IconProps) => {
  const { size, ...rest } = props;
  return (
    <svg
      width={size}
      height={size}
      viewBox="0 0 16 16"
      fill="none"
      xmlns="http://www.w3.org/2000/svg"
      stroke="currentColor"
      {...rest}
    >
      {/* Path data is generated from the design asset — do not hand-edit. */}
      <path
        d="M2.66667 11.3333L6.66667 7.33331L2.66667 3.33331M8.00001 12.6666H13.3333"
        strokeWidth={1.5}
        strokeLinecap="round"
        strokeLinejoin="round"
      />
    </svg>
  );
};

export default SvgTerminal;

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgTwoLineSmall = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgTwoLineSmall = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { IconProps } from "@opal/types";
const SvgUserPlus = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,6 +1,8 @@
import type { SVGProps } from "react";
const SvgWallet = (props: SVGProps<SVGSVGElement>) => (
import type { IconProps } from "@opal/types";
const SvgWallet = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -300,13 +300,7 @@ export default function Page() {
<>
<AdminPageTitle
title="Default Assistant"
icon={
<SvgOnyxLogo
width={32}
height={32}
className="my-auto stroke-text-04"
/>
}
icon={<SvgOnyxLogo size={32} className="my-auto stroke-text-04" />}
/>
<DefaultAssistantConfig />
</>

View File

@@ -31,6 +31,7 @@ import { fetchBedrockModels } from "../utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import Tabs from "@/refresh-components/Tabs";
import { cn } from "@/lib/utils";
export const BEDROCK_PROVIDER_NAME = "bedrock";
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
@@ -135,7 +136,7 @@ function BedrockFormInternals({
!formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;
return (
<Form className={LLM_FORM_CLASS_NAME}>
<Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
<DisplayNameField disabled={!!existingLlmProvider} />
<SelectorFormField
@@ -176,7 +177,7 @@ function BedrockFormInternals({
</Tabs.Content>
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4 w-full">
<TextFormField
name={FIELD_AWS_ACCESS_KEY_ID}
label="AWS Access Key ID"
@@ -191,7 +192,7 @@ function BedrockFormInternals({
</Tabs.Content>
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4 w-full">
<PasswordInputTypeInField
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
label="AWS Bedrock Long-term API Key"

View File

@@ -131,10 +131,15 @@ export function CustomForm({
return;
}
const selectedModelNames = modelConfigurations.map(
(config) => config.name
);
await submitLLMProvider({
providerName: values.provider,
values: {
...values,
selected_model_names: selectedModelNames,
custom_config: customConfigProcessing(
values.custom_config_list
),

View File

@@ -39,6 +39,8 @@ interface OllamaFormValues extends BaseLLMFormValues {
interface OllamaFormContentProps {
formikProps: FormikProps<OllamaFormValues>;
existingLlmProvider?: LLMProviderView;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
isTesting: boolean;
testError: string;
mutate: () => void;
@@ -49,15 +51,14 @@ interface OllamaFormContentProps {
function OllamaFormContent({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
isTesting,
testError,
mutate,
onClose,
isFormValid,
}: OllamaFormContentProps) {
const [availableModels, setAvailableModels] = useState<ModelConfiguration[]>(
existingLlmProvider?.model_configurations || []
);
const [isLoadingModels, setIsLoadingModels] = useState(true);
useEffect(() => {
@@ -70,16 +71,25 @@ function OllamaFormContent({
.then((data) => {
if (data.error) {
console.error("Error fetching models:", data.error);
setAvailableModels([]);
setFetchedModels([]);
return;
}
setAvailableModels(data.models);
setFetchedModels(data.models);
})
.finally(() => {
setIsLoadingModels(false);
});
}
}, [formikProps.values.api_base]);
}, [
formikProps.values.api_base,
existingLlmProvider?.name,
setFetchedModels,
]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || [];
return (
<Form className={LLM_FORM_CLASS_NAME}>
@@ -99,7 +109,7 @@ function OllamaFormContent({
/>
<DisplayModels
modelConfigurations={availableModels}
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
isLoading={isLoadingModels}
@@ -125,6 +135,8 @@ export function OllamaForm({
existingLlmProvider,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
return (
<ProviderFormEntrypointWrapper
providerName="Ollama"
@@ -189,7 +201,10 @@ export function OllamaForm({
providerName: OLLAMA_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations,
modelConfigurations:
fetchedModels.length > 0
? fetchedModels
: modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
@@ -205,6 +220,8 @@ export function OllamaForm({
<OllamaFormContent
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
isTesting={isTesting}
testError={testError}
mutate={mutate}

View File

@@ -68,11 +68,7 @@ export const WebProviderSetupModal = memo(
<SvgArrowExchange className="size-3 text-text-04" />
</div>
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
<SvgOnyxLogo
width={24}
height={24}
className="text-text-04 shrink-0"
/>
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
</div>
</div>
);

View File

@@ -1168,7 +1168,7 @@ export default function Page() {
alt: `${label} logo`,
fallback:
provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo width={16} height={16} />
<SvgOnyxLogo size={16} />
) : undefined,
size: 16,
isHighlighted: isCurrentCrawler,
@@ -1381,7 +1381,7 @@ export default function Page() {
} logo`,
fallback:
selectedContentProviderType === "onyx_web_crawler" ? (
<SvgOnyxLogo width={24} height={24} className="text-text-05" />
<SvgOnyxLogo size={24} className="text-text-05" />
) : undefined,
size: 24,
containerSize: 28,

View File

@@ -455,9 +455,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
/>
)}
<BackButton
behaviorOverride={() => router.push("/admin/indexing/status")}
/>
<BackButton />
<div
className="flex
items-center

View File

@@ -25,7 +25,6 @@ import { useDocumentSets } from "@/lib/hooks/useDocumentSets";
import { useAgents } from "@/hooks/useAgents";
import { ChatPopup } from "@/app/chat/components/ChatPopup";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { SEARCH_TOOL_ID } from "@/app/chat/components/tools/constants";
import { useUser } from "@/components/user/UserProvider";
import NoAssistantModal from "@/components/modals/NoAssistantModal";
import TextView from "@/components/chat/TextView";
@@ -382,9 +381,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
const retrievalEnabled = useMemo(() => {
if (liveAssistant) {
return liveAssistant.tools.some(
(tool) => tool.in_code_tool_id === SEARCH_TOOL_ID
);
return personaIncludesRetrieval(liveAssistant);
}
return false;
}, [liveAssistant]);

View File

@@ -643,6 +643,7 @@ export function useChatController({
let toolCall: ToolCallMetadata | null = null;
let files = projectFilesToFileDescriptors(currentMessageFiles);
let packets: Packet[] = [];
let packetsVersion = 0;
let newUserMessageId: number | null = null;
let newAssistantMessageId: number | null = null;
@@ -729,7 +730,6 @@ export function useChatController({
if (!packet) {
continue;
}
console.debug("Packet:", JSON.stringify(packet));
// We've processed initial packets and are starting to stream content.
// Transition from 'loading' to 'streaming'.
@@ -800,8 +800,8 @@ export function useChatController({
updateCanContinue(true, frozenSessionId);
}
} else if (Object.hasOwn(packet, "obj")) {
console.debug("Object packet:", JSON.stringify(packet));
packets.push(packet as Packet);
packetsVersion++;
// Check if the packet contains document information
const packetObj = (packet as Packet).obj;
@@ -859,6 +859,8 @@ export function useChatController({
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
packets: packets,
packetsVersion: packetsVersion,
packetCount: packets.length,
},
],
// Pass the latest map state
@@ -885,6 +887,7 @@ export function useChatController({
toolCall: null,
parentNodeId: parentMessage?.nodeId || SYSTEM_NODE_ID,
packets: [],
packetCount: 0,
},
{
nodeId: initialAssistantNode.nodeId,
@@ -894,6 +897,7 @@ export function useChatController({
toolCall: null,
parentNodeId: initialUserNode.nodeId,
packets: [],
packetCount: 0,
stackTrace: stackTrace,
errorCode: errorCode,
isRetryable: isRetryable,

View File

@@ -139,6 +139,9 @@ export interface Message {
// new gen
packets: Packet[];
// Version counter for efficient memo comparison (increments with each packet)
packetsVersion?: number;
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
// cached values for easy access
documents?: OnyxDocument[] | null;

View File

@@ -74,6 +74,8 @@ export type RegenerationFactory = (regenerationRequest: {
export interface AIMessageProps {
rawPackets: Packet[];
// Version counter for efficient memo comparison (avoids array copying)
packetsVersion?: number;
chatState: FullChatState;
nodeId: number;
messageId?: number;
@@ -88,8 +90,6 @@ export interface AIMessageProps {
}
// TODO: Consider more robust comparisons:
// - `rawPackets.length` assumes packets are append-only. Could compare the last
// packet or use a shallow comparison if packets can be modified in place.
// - `chatState.docs`, `chatState.citations`, and `otherMessagesCanSwitchTo` use
// reference equality. Shallow array/object comparison would be more robust if
// these are recreated with the same values.
@@ -98,7 +98,7 @@ function arePropsEqual(prev: AIMessageProps, next: AIMessageProps): boolean {
prev.nodeId === next.nodeId &&
prev.messageId === next.messageId &&
prev.currentFeedback === next.currentFeedback &&
prev.rawPackets.length === next.rawPackets.length &&
prev.packetsVersion === next.packetsVersion &&
prev.chatState.assistant?.id === next.chatState.assistant?.id &&
prev.chatState.docs === next.chatState.docs &&
prev.chatState.citations === next.chatState.citations &&

View File

@@ -11,6 +11,7 @@ import { CitationMap } from "../../interfaces";
export enum RenderType {
HIGHLIGHT = "highlight",
FULL = "full",
COMPACT = "compact",
}
export interface FullChatState {
@@ -35,6 +36,9 @@ export interface RendererResult {
// used for things that should just show text w/o an icon or header
// e.g. ReasoningRenderer
expandedText?: JSX.Element;
// Whether this renderer supports compact mode (collapse button shown only when true)
supportsCompact?: boolean;
}
export type MessageRenderer<
@@ -48,5 +52,7 @@ export type MessageRenderer<
animate: boolean;
stopPacketSeen: boolean;
stopReason?: StopReason;
/** Whether this is the last step in the timeline (for connector line decisions) */
isLastStep?: boolean;
children: (result: RendererResult) => JSX.Element;
}>;

Some files were not shown because too many files have changed in this diff. Show More