mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-28 13:15:44 +00:00
Compare commits
30 Commits
litellm_pr
...
v2.11.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed46504a1a | ||
|
|
7a24b34516 | ||
|
|
7a7ffa9051 | ||
|
|
3053ab518c | ||
|
|
be38d3500f | ||
|
|
753a3bc093 | ||
|
|
2ba8fafe78 | ||
|
|
b77b580ebd | ||
|
|
3eee98b932 | ||
|
|
a97eb02fef | ||
|
|
c5061495a2 | ||
|
|
c20b0789ae | ||
|
|
d99848717b | ||
|
|
aaca55c415 | ||
|
|
9d7ffd1e4a | ||
|
|
a249161827 | ||
|
|
e126346a91 | ||
|
|
a96682fa73 | ||
|
|
3920371d56 | ||
|
|
e5a257345c | ||
|
|
a49df511e2 | ||
|
|
d5d2a8a1a6 | ||
|
|
b2f46b264c | ||
|
|
c6ad363fbd | ||
|
|
e313119f9a | ||
|
|
3a2a542a03 | ||
|
|
413aeba4a1 | ||
|
|
46028aa2bb | ||
|
|
454943c4a6 | ||
|
|
87946266de |
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
DOCUMENT_SYNC_PREFIX = "documentsync"
|
||||
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
|
||||
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
|
||||
r.delete(DOCUMENT_SYNC_FENCE_KEY)
|
||||
return
|
||||
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
|
||||
|
||||
# Add to the tracking taskset in Redis BEFORE creating the celery task
|
||||
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
|
||||
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
|
||||
|
||||
# Create the Celery task
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -341,6 +341,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_to_metadata_value(value: Any) -> str | list[str]:
|
||||
"""Convert a Salesforce field value to a valid metadata value.
|
||||
|
||||
Document metadata expects str | list[str], but Salesforce returns
|
||||
various types (bool, float, int, etc.). This function ensures all
|
||||
values are properly converted to strings.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
return str(value)
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
# # gc.collect()
|
||||
# return all_types
|
||||
|
||||
def _yield_doc_batches(
|
||||
self,
|
||||
sf_db: OnyxSalesforceSQLite,
|
||||
type_to_processed: dict[str, int],
|
||||
changed_ids_to_type: dict[str, str],
|
||||
parent_types: set[str],
|
||||
increment_parents_changed: Callable[[], None],
|
||||
) -> GenerateDocumentsOutput:
|
||||
""" """
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
parent_object.data[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
increment_parents_changed()
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def _full_sync(
|
||||
self,
|
||||
temp_dir: str,
|
||||
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
examined_ids = 0
|
||||
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
f"records={num_records}"
|
||||
)
|
||||
|
||||
# yield an empty list to keep the connector alive
|
||||
yield docs_to_yield
|
||||
|
||||
new_ids = sf_db.update_from_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=ctx.parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = (
|
||||
type_to_processed.get(parent_type, 0) + 1
|
||||
)
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
def increment_parents_changed() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
yield from self._yield_doc_batches(
|
||||
sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
ctx.parent_types,
|
||||
increment_parents_changed,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
|
||||
if sf_attribute in record:
|
||||
doc.metadata[canonical_attribute] = record[sf_attribute]
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
record[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
return return_context
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if MULTI_TENANT:
|
||||
# if multi tenant, we cannot expect the sqlite db to be cached/present
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._full_sync(temp_dir)
|
||||
|
||||
# nuke the db since we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
|
||||
os.remove(sqlite_db_path)
|
||||
return self._full_sync(BASE_DATA_PATH)
|
||||
# Always use a temp directory for SQLite - the database is rebuilt
|
||||
# from scratch each time via CSV downloads, so there's no caching benefit
|
||||
# from persisting it. Using temp dirs also avoids collisions between
|
||||
# multiple CC pairs and eliminates stale WAL/SHM file issues.
|
||||
# TODO(evan): make this thing checkpointed and persist/load db from filestore
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield from self._full_sync(temp_dir)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source will synchronize updated parent objects one by one."""
|
||||
|
||||
if start == 0:
|
||||
# nuke the db if we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(
|
||||
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
|
||||
)
|
||||
os.remove(sqlite_db_path)
|
||||
|
||||
return self._delta_sync(BASE_DATA_PATH, start, end)
|
||||
|
||||
# Always use a temp directory - see comment in load_from_state()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
yield from self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
SQLITE_DISK_IO_ERROR = "disk I/O error"
|
||||
|
||||
|
||||
class OnyxSalesforceSQLite:
|
||||
"""Notes on context management using 'with self.conn':
|
||||
|
||||
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
|
||||
def apply_schema(self) -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist.
|
||||
|
||||
Non-destructive operation.
|
||||
Non-destructive operation. If a disk I/O error is encountered (often due
|
||||
to stale WAL/SHM files from a previous crash), this method will attempt
|
||||
to recover by removing the corrupted files and recreating the database.
|
||||
"""
|
||||
try:
|
||||
self._apply_schema_impl()
|
||||
except sqlite3.OperationalError as e:
|
||||
if SQLITE_DISK_IO_ERROR not in str(e):
|
||||
raise
|
||||
|
||||
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
|
||||
self._recover_from_corruption()
|
||||
self._apply_schema_impl()
|
||||
|
||||
def _recover_from_corruption(self) -> None:
|
||||
"""Recover from SQLite corruption by removing all database files and reconnecting."""
|
||||
logger.info(f"Removing corrupted SQLite files: {self.filename}")
|
||||
|
||||
# Close existing connection
|
||||
self.close()
|
||||
|
||||
# Remove all SQLite files (main db, WAL, SHM)
|
||||
remove_sqlite_db_files(self.filename)
|
||||
|
||||
# Reconnect - this will create a fresh database
|
||||
self.connect()
|
||||
|
||||
logger.info("SQLite recovery complete, fresh database created")
|
||||
|
||||
def _apply_schema_impl(self) -> None:
|
||||
"""Internal implementation of apply_schema."""
|
||||
if self._conn is None:
|
||||
raise RuntimeError("Database connection is closed")
|
||||
|
||||
|
||||
@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
|
||||
return os.path.join(directory, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def remove_sqlite_db_files(db_path: str) -> None:
|
||||
"""Remove SQLite database and all associated files (WAL, SHM).
|
||||
|
||||
SQLite in WAL mode creates additional files:
|
||||
- .sqlite-wal: Write-ahead log
|
||||
- .sqlite-shm: Shared memory file
|
||||
|
||||
If these files become stale (e.g., after a crash), they can cause
|
||||
'disk I/O error' when trying to open the database. This function
|
||||
ensures all related files are removed.
|
||||
"""
|
||||
files_to_remove = [
|
||||
db_path,
|
||||
f"{db_path}-wal",
|
||||
f"{db_path}-shm",
|
||||
]
|
||||
for file_path in files_to_remove:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
# NOTE: only used with shelves, deprecated at this point
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
|
||||
@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
@@ -327,6 +329,15 @@ def delete_user_from_db(
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
# Null out ownership on document sets and personas so they're
|
||||
# preserved for other users instead of being cascade-deleted
|
||||
db_session.query(DocumentSet).filter(
|
||||
DocumentSet.user_id == user_to_delete.id
|
||||
).update({DocumentSet.user_id: None})
|
||||
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,12 +1,149 @@
|
||||
from mistune import Markdown # type: ignore[import-untyped]
|
||||
from mistune import Renderer
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
return Markdown(renderer=SlackRenderer()).render(message)
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(Renderer):
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def header(self, text: str, level: int, raw: str | None = None) -> str:
|
||||
return f"*{text}*\n"
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
|
||||
def double_emphasis(self, text: str) -> str:
|
||||
def strong(self, text: str) -> str:
|
||||
return f"*{text}*"
|
||||
|
||||
def strikethrough(self, text: str) -> str:
|
||||
return f"~{text}~"
|
||||
|
||||
def list(self, body: str, ordered: bool = True) -> str:
|
||||
lines = body.split("\n")
|
||||
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
|
||||
lines = text.split("\n")
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("li: "):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
|
||||
def link(self, link: str, title: str | None, content: str | None) -> str:
|
||||
escaped_link = self.escape_special(link)
|
||||
if content:
|
||||
return f"<{escaped_link}|{content}>"
|
||||
def link(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
if text:
|
||||
return f"<{escaped_url}|{text}>"
|
||||
if title:
|
||||
return f"<{escaped_link}|{title}>"
|
||||
return f"<{escaped_link}>"
|
||||
return f"<{escaped_url}|{title}>"
|
||||
return f"<{escaped_url}>"
|
||||
|
||||
def image(self, src: str, title: str | None, text: str | None) -> str:
|
||||
escaped_src = self.escape_special(src)
|
||||
def image(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
display_text = title or text
|
||||
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
|
||||
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
|
||||
|
||||
def codespan(self, text: str) -> str:
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, text: str, lang: str | None) -> str:
|
||||
return f"```\n{text}\n```\n"
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
|
||||
def autolink(self, link: str, is_email: bool) -> str:
|
||||
return link if is_email else self.link(link, None, None)
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -32,6 +32,7 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -45,6 +45,7 @@ class RedisConnectorPrune:
|
||||
) # connectorpruning_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
|
||||
@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -84,7 +84,8 @@ def patch_document_set(
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
object_is_public=document_set_update_request.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
try:
|
||||
update_document_set(
|
||||
@@ -125,7 +126,8 @@ def delete_document_set(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
object_is_public=document_set.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
@router.get("/", tags=PUBLIC_API_TAGS)
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import user_needs_to_be_verified
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
|
||||
|
||||
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
|
||||
async def get_auth_type() -> AuthTypeResponse:
|
||||
user_count = await get_user_count()
|
||||
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
|
||||
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
|
||||
# associated with until they auth.
|
||||
has_users = True
|
||||
if AUTH_TYPE != AuthType.CLOUD:
|
||||
user_count = await get_user_count()
|
||||
has_users = user_count > 0
|
||||
|
||||
return AuthTypeResponse(
|
||||
auth_type=AUTH_TYPE,
|
||||
requires_verification=user_needs_to_be_verified(),
|
||||
anonymous_user_enabled=anonymous_user_enabled(),
|
||||
password_min_length=PASSWORD_MIN_LENGTH,
|
||||
has_users=user_count > 0,
|
||||
has_users=has_users,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import remove_chat_message_feedback
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -266,7 +267,35 @@ def get_chat_session(
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
try:
|
||||
# If we failed to get a chat session, try to retrieve the session with
|
||||
# less restrictive filters in order to identify what exactly mismatched
|
||||
# so we can bubble up an accurate error code andmessage.
|
||||
existing_chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
is_shared=False,
|
||||
include_deleted=True,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
if not include_deleted and existing_chat_session.deleted:
|
||||
raise HTTPException(status_code=404, detail="Chat session has been deleted")
|
||||
|
||||
if is_shared:
|
||||
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Chat session is not shared"
|
||||
)
|
||||
elif user_id is not None and existing_chat_session.user_id not in (
|
||||
user_id,
|
||||
None,
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
# for chat-seeding: if the session is unassigned, assign it now. This is done here
|
||||
# to avoid another back and forth between FE -> BE before starting the first
|
||||
|
||||
@@ -573,7 +573,7 @@ mcp==1.25.0
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==0.8.4
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
|
||||
@@ -298,7 +298,7 @@ numpy==2.4.1
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.4.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -476,8 +476,8 @@ class ChatSessionManager:
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
|
||||
@@ -31,7 +31,7 @@ class ProjectManager:
|
||||
) -> List[UserProjectSnapshot]:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -56,7 +56,7 @@ class ProjectManager:
|
||||
) -> bool:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def second_user(admin_user: DATestUser) -> DATestUser:
|
||||
# Ensure admin exists so this new user is created with BASIC role.
|
||||
try:
|
||||
return UserManager.create(name="second_basic_user")
|
||||
except HTTPError as e:
|
||||
response = e.response
|
||||
if response is None:
|
||||
raise
|
||||
if response.status_code not in (400, 409):
|
||||
raise
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
raise
|
||||
detail = payload.get("detail")
|
||||
if not _is_user_already_exists_detail(detail):
|
||||
raise
|
||||
print("Second basic user already exists; logging in instead.")
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("second_basic_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_user_already_exists_detail(detail: object) -> bool:
|
||||
if isinstance(detail, str):
|
||||
normalized = detail.lower()
|
||||
return (
|
||||
"already exists" in normalized
|
||||
or "register_user_already_exists" in normalized
|
||||
)
|
||||
if isinstance(detail, dict):
|
||||
code = detail.get("code")
|
||||
if isinstance(code, str) and code.lower() == "register_user_already_exists":
|
||||
return True
|
||||
message = detail.get("message")
|
||||
if isinstance(message, str) and "already exists" in message.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_chat_session(
|
||||
chat_session_id: str,
|
||||
user: DATestUser,
|
||||
is_shared: bool | None = None,
|
||||
include_deleted: bool | None = None,
|
||||
) -> requests.Response:
|
||||
params: dict[str, str] = {}
|
||||
if is_shared is not None:
|
||||
params["is_shared"] = str(is_shared).lower()
|
||||
if include_deleted is not None:
|
||||
params["include_deleted"] = str(include_deleted).lower()
|
||||
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
|
||||
params=params,
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def _set_sharing_status(
|
||||
chat_session_id: str, sharing_status: str, user: DATestUser
|
||||
) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}",
|
||||
json={"sharing_status": sharing_status},
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_private_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify private sessions are only accessible by the owner and never via share link."""
|
||||
# Create a private chat session owned by basic_user.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can access the private session normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Share link should be forbidden when the session is private.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users cannot access private sessions directly.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users also cannot access private sessions via share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_public_shared_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify shared sessions are accessible only via share link for non-owners."""
|
||||
# Create a private session, then mark it public.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
response = _set_sharing_status(str(chat_session.id), "public", basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can access normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can also access via share link.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Non-owner cannot access without share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Non-owner can access with share link for public sessions.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_deleted_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify deleted sessions return 404, with include_deleted gated by access checks."""
|
||||
# Create and soft-delete a session.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=chat_session, user_performing_action=basic_user
|
||||
)
|
||||
assert deletion_success is True
|
||||
|
||||
# Deleted sessions are not accessible normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Owner can fetch deleted session only with include_deleted.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("deleted") is True
|
||||
|
||||
# Non-owner should be blocked even with include_deleted.
|
||||
response = _get_chat_session(
|
||||
str(chat_session.id), second_user, include_deleted=True
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaLabelManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
reset: None, admin_user: DATestUser
|
||||
) -> None:
|
||||
persona_label = PersonaLabelManager.create(
|
||||
label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert persona_label.id is not None
|
||||
persona = PersonaManager.create(
|
||||
label_ids=[persona_label.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
updated_description = f"{persona.description}-updated"
|
||||
update_request = PersonaUpsertRequest(
|
||||
name=persona.name,
|
||||
description=updated_description,
|
||||
system_prompt=persona.system_prompt or "",
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
users=[],
|
||||
groups=[],
|
||||
label_ids=None,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=update_request.model_dump(mode="json", exclude_none=False),
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
fetched = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
fetched.raise_for_status()
|
||||
fetched_persona = fetched.json()
|
||||
|
||||
assert fetched_persona["description"] == updated_description
|
||||
fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
|
||||
assert persona_label.id in fetched_label_ids
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for Asana connector configuration parsing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_ids,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 123 ", ["123"]),
|
||||
(" 123 , , 456 , ", ["123", "456"]),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_project_ids_normalization(
|
||||
project_ids: str | None, expected: list[str] | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id=" 1153293530468850 ",
|
||||
asana_project_ids=project_ids,
|
||||
asana_team_id=" 1210918501948021 ",
|
||||
)
|
||||
|
||||
assert connector.workspace_id == "1153293530468850"
|
||||
assert connector.project_ids_to_index == expected
|
||||
assert connector.asana_team_id == "1210918501948021"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 1210918501948021 ", "1210918501948021"),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_team_id_normalization(
|
||||
team_id: str | None, expected: str | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id="1153293530468850",
|
||||
asana_project_ids=None,
|
||||
asana_team_id=team_id,
|
||||
)
|
||||
|
||||
assert connector.asana_team_id == expected
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
|
||||
class TestConvertToMetadataValue:
|
||||
"""Tests for the _convert_to_metadata_value helper function."""
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""String values should be returned as-is."""
|
||||
assert _convert_to_metadata_value("hello") == "hello"
|
||||
assert _convert_to_metadata_value("") == ""
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
"""Boolean True should be converted to string 'True'."""
|
||||
assert _convert_to_metadata_value(True) == "True"
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
"""Boolean False should be converted to string 'False'."""
|
||||
assert _convert_to_metadata_value(False) == "False"
|
||||
|
||||
def test_integer_value(self) -> None:
|
||||
"""Integer values should be converted to string."""
|
||||
assert _convert_to_metadata_value(42) == "42"
|
||||
assert _convert_to_metadata_value(0) == "0"
|
||||
assert _convert_to_metadata_value(-100) == "-100"
|
||||
|
||||
def test_float_value(self) -> None:
|
||||
"""Float values should be converted to string."""
|
||||
assert _convert_to_metadata_value(3.14) == "3.14"
|
||||
assert _convert_to_metadata_value(0.0) == "0.0"
|
||||
assert _convert_to_metadata_value(-2.5) == "-2.5"
|
||||
|
||||
def test_list_of_strings(self) -> None:
|
||||
"""List of strings should remain as list of strings."""
|
||||
result = _convert_to_metadata_value(["a", "b", "c"])
|
||||
assert result == ["a", "b", "c"]
|
||||
|
||||
def test_list_of_mixed_types(self) -> None:
|
||||
"""List with mixed types should have all items converted to strings."""
|
||||
result = _convert_to_metadata_value([1, True, 3.14, "text"])
|
||||
assert result == ["1", "True", "3.14", "text"]
|
||||
|
||||
def test_empty_list(self) -> None:
|
||||
"""Empty list should return empty list."""
|
||||
assert _convert_to_metadata_value([]) == []
|
||||
|
||||
|
||||
class TestYieldDocBatches:
|
||||
"""Tests for the _yield_doc_batches method of SalesforceConnector."""
|
||||
|
||||
@pytest.fixture
|
||||
def connector(self) -> SalesforceConnector:
|
||||
"""Create a SalesforceConnector instance with mocked sf_client."""
|
||||
connector = SalesforceConnector(
|
||||
batch_size=10,
|
||||
requested_objects=["Opportunity"],
|
||||
)
|
||||
# Mock the sf_client property
|
||||
mock_sf_client = MagicMock()
|
||||
mock_sf_client.sf_instance = "test.salesforce.com"
|
||||
connector._sf_client = mock_sf_client
|
||||
return connector
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sf_db(self) -> MagicMock:
|
||||
"""Create a mock OnyxSalesforceSQLite object."""
|
||||
return MagicMock()
|
||||
|
||||
def _create_salesforce_object(
|
||||
self,
|
||||
object_id: str,
|
||||
object_type: str,
|
||||
data: dict[str, Any],
|
||||
) -> SalesforceObject:
|
||||
"""Helper to create a SalesforceObject with required fields."""
|
||||
# Ensure required fields are present
|
||||
data.setdefault(ID_FIELD, object_id)
|
||||
data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
|
||||
data.setdefault(NAME_FIELD, f"Test {object_type}")
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_metadata_type_conversion_for_opportunity(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that Opportunity metadata fields are properly type-converted."""
|
||||
parent_id = "006bm000006kyDpAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create a parent object with various data types in the fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Test Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"Account": "Acme Corp", # string - should become "account" metadata
|
||||
"FiscalQuarter": 2, # int - should be converted to "2"
|
||||
"FiscalYear": 2024, # int - should be converted to "2024"
|
||||
"IsClosed": False, # bool - should be converted to "False"
|
||||
"StageName": "Prospecting", # string
|
||||
"Type": "New Business", # string
|
||||
"Amount": 50000.50, # float - should be converted to "50000.50"
|
||||
"CloseDate": "2024-06-30", # string
|
||||
"Probability": 75, # int - should be converted to "75"
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z", # string
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
# Setup mock sf_db
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create a mock document that convert_sf_object_to_doc will return
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Test Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
# Track parent changes
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# Call _yield_doc_batches
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify we got one batch with one document
|
||||
assert len(batches) == 1
|
||||
docs = batches[0]
|
||||
assert len(docs) == 1
|
||||
|
||||
doc = docs[0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify metadata type conversions
|
||||
# All values should be strings (or list of strings)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["account"] == "Acme Corp" # string stays string
|
||||
assert doc.metadata["fiscal_quarter"] == "2" # int -> str
|
||||
assert doc.metadata["fiscal_year"] == "2024" # int -> str
|
||||
assert doc.metadata["is_closed"] == "False" # bool -> str
|
||||
assert doc.metadata["stage_name"] == "Prospecting" # string stays string
|
||||
assert doc.metadata["type"] == "New Business" # string stays string
|
||||
assert (
|
||||
doc.metadata["amount"] == "50000.5"
|
||||
) # float -> str (Python drops trailing zeros)
|
||||
assert doc.metadata["close_date"] == "2024-06-30" # string stays string
|
||||
assert doc.metadata["probability"] == "75" # int -> str
|
||||
assert doc.metadata["name"] == "Test Opportunity" # NAME_FIELD
|
||||
|
||||
# Verify parent was counted
|
||||
assert parents_changed == 1
|
||||
assert type_to_processed[parent_type] == 1
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_missing_optional_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing optional metadata fields are not added."""
|
||||
parent_id = "006bm000006kyDqAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create parent object with only some fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Minimal Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"StageName": "Closed Won",
|
||||
# Notably missing: Amount, Probability, FiscalQuarter, etc.
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Minimal Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only present fields should be in metadata
|
||||
assert "stage_name" in doc.metadata
|
||||
assert doc.metadata["stage_name"] == "Closed Won"
|
||||
assert "name" in doc.metadata
|
||||
assert doc.metadata["name"] == "Minimal Opportunity"
|
||||
|
||||
# Missing fields should not be in metadata
|
||||
assert "amount" not in doc.metadata
|
||||
assert "probability" not in doc.metadata
|
||||
assert "fiscal_quarter" not in doc.metadata
|
||||
assert "fiscal_year" not in doc.metadata
|
||||
assert "is_closed" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_contact_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test metadata conversion for Contact object type."""
|
||||
parent_id = "003bm00000EjHCjAAN"
|
||||
parent_type = "Contact"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "John Doe",
|
||||
MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
|
||||
"Account": "Globex Corp",
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z",
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="John Doe",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify Contact-specific metadata
|
||||
assert doc.metadata["object_type"] == "Contact"
|
||||
assert doc.metadata["account"] == "Globex Corp"
|
||||
assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
|
||||
assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_no_default_attributes_for_unknown_type(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that unknown object types only get object_type metadata."""
|
||||
parent_id = "001bm00000fd9Z3AAI"
|
||||
parent_type = "CustomObject__c"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Custom Record",
|
||||
MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
|
||||
"CustomField__c": "custom value",
|
||||
"NumberField__c": 123,
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Custom Record",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only object_type should be set for unknown types
|
||||
assert doc.metadata["object_type"] == "CustomObject__c"
|
||||
# Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
|
||||
assert "CustomField__c" not in doc.metadata
|
||||
assert "NumberField__c" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_skips_missing_parent_objects(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing parent objects are skipped gracefully."""
|
||||
parent_id = "006bm000006kyDrAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# get_record returns None for missing object
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = None
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Should yield one empty batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 0
|
||||
|
||||
# convert_sf_object_to_doc should not have been called
|
||||
mock_convert.assert_not_called()
|
||||
|
||||
# Parents changed should still be 0
|
||||
assert parents_changed == 0
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_multiple_documents_batching(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that multiple documents are correctly batched."""
|
||||
# Create 3 parent objects
|
||||
parent_ids = [
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
]
|
||||
parent_type = "Opportunity"
|
||||
|
||||
parent_objects = [
|
||||
self._create_salesforce_object(
|
||||
pid,
|
||||
parent_type,
|
||||
{
|
||||
ID_FIELD: pid,
|
||||
NAME_FIELD: f"Opportunity {i}",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"IsClosed": i % 2 == 0, # alternating bool values
|
||||
"Amount": 1000.0 * (i + 1),
|
||||
},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
|
||||
# Setup mock to return all three
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
|
||||
)
|
||||
mock_sf_db.get_record.side_effect = parent_objects
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create mock documents
|
||||
mock_docs = [
|
||||
Document(
|
||||
id=f"SALESFORCE_{pid}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=f"Opportunity {i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
mock_convert.side_effect = mock_docs
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {pid: parent_type for pid in parent_ids}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
# With batch_size=10, all 3 docs should be in one batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 3
|
||||
|
||||
# Verify each document has correct metadata
|
||||
for i, doc in enumerate(batches[0]):
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["is_closed"] == str(i % 2 == 0)
|
||||
assert doc.metadata["amount"] == str(1000.0 * (i + 1))
|
||||
|
||||
assert type_to_processed[parent_type] == 3
|
||||
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.users import delete_user_from_db
|
||||
|
||||
|
||||
def _mock_user(
|
||||
user_id: UUID | None = None, email: str = "test@example.com"
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.oauth_accounts = []
|
||||
return user
|
||||
|
||||
|
||||
def _make_query_chain() -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
return chain
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_nulls_out_document_set_ownership(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Verify DocumentSet.user_id is nulled out (update, not delete)
|
||||
doc_set_chain = query_chains[DocumentSet]
|
||||
doc_set_chain.filter.assert_called()
|
||||
doc_set_chain.filter.return_value.update.assert_called_once_with(
|
||||
{DocumentSet.user_id: None}
|
||||
)
|
||||
|
||||
# Verify Persona.user_id is nulled out (update, not delete)
|
||||
persona_chain = query_chains[Persona]
|
||||
persona_chain.filter.assert_called()
|
||||
persona_chain.filter.return_value.update.assert_called_once_with(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_cleans_up_join_tables(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Join tables should be deleted (not updated)
|
||||
for model in [DocumentSet__User, Persona__User, User__UserGroup, SamlAccount]:
|
||||
chain = query_chains[model]
|
||||
chain.filter.return_value.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_commits_and_removes_invited(
|
||||
_mock_ee: Any, mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user(email="deleted@example.com")
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_called_once_with(user)
|
||||
db_session.commit.assert_called_once()
|
||||
mock_remove_invited.assert_called_once_with("deleted@example.com")
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_deletes_oauth_accounts(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
oauth1 = MagicMock()
|
||||
oauth2 = MagicMock()
|
||||
user.oauth_accounts = [oauth1, oauth2]
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_any_call(oauth1)
|
||||
db_session.delete.assert_any_call(oauth2)
|
||||
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.utils import telemetry as telemetry_utils
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
|
||||
fetch_impl = Mock()
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_not_called()
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
event_telemetry = Mock()
|
||||
fetch_impl = Mock(return_value=event_telemetry)
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_called_once_with(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=telemetry_utils.noop_fallback,
|
||||
)
|
||||
event_telemetry.assert_called_once_with(
|
||||
"user@example.com",
|
||||
MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
{"origin": "web", "tenant_id": "tenant-1"},
|
||||
)
|
||||
3
desktop/.gitignore
vendored
3
desktop/.gitignore
vendored
@@ -22,3 +22,6 @@ npm-debug.log*
|
||||
# Local env files
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Generated files
|
||||
src-tauri/gen/schemas/acl-manifests.json
|
||||
|
||||
96
desktop/src-tauri/Cargo.lock
generated
96
desktop/src-tauri/Cargo.lock
generated
@@ -706,16 +706,6 @@ dependencies = [
|
||||
"typeid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
@@ -993,16 +983,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gethostname"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
@@ -1122,24 +1102,6 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "global-hotkey"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"keyboard-types",
|
||||
"objc2 0.6.3",
|
||||
"objc2-app-kit 0.3.2",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"thiserror 2.0.17",
|
||||
"windows-sys 0.59.0",
|
||||
"x11rb",
|
||||
"xkeysym",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gobject-sys"
|
||||
version = "0.18.0"
|
||||
@@ -1713,12 +1675,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.1"
|
||||
@@ -2248,7 +2204,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
"tauri-plugin-global-shortcut",
|
||||
"tauri-plugin-shell",
|
||||
"tauri-plugin-window-state",
|
||||
"tokio",
|
||||
@@ -2878,19 +2833,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -3605,21 +3547,6 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-global-shortcut"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
|
||||
dependencies = [
|
||||
"global-hotkey",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-plugin",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-shell"
|
||||
version = "2.3.3"
|
||||
@@ -5021,29 +4948,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
|
||||
dependencies = [
|
||||
"gethostname",
|
||||
"rustix",
|
||||
"x11rb-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb-protocol"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
|
||||
|
||||
[[package]]
|
||||
name = "xkeysym"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
|
||||
@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
|
||||
[dependencies]
|
||||
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.0"
|
||||
tauri-plugin-global-shortcut = "2.0"
|
||||
tauri-plugin-window-state = "2.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -20,7 +20,6 @@ use tauri::Wry;
|
||||
use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Shortcuts Setup
|
||||
// ============================================================================
|
||||
|
||||
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
|
||||
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
|
||||
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
|
||||
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
|
||||
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
|
||||
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
|
||||
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
|
||||
let app_handle = app.clone();
|
||||
|
||||
// Avoid hijacking the system-wide Cmd+R on macOS.
|
||||
#[cfg(target_os = "macos")]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
reload,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
app.global_shortcut().on_shortcuts(
|
||||
shortcuts,
|
||||
move |_app, shortcut, _event| {
|
||||
if shortcut == &new_chat {
|
||||
trigger_new_chat(&app_handle);
|
||||
}
|
||||
|
||||
if let Some(window) = app_handle.get_webview_window("main") {
|
||||
if shortcut == &reload {
|
||||
let _ = window.eval("window.location.reload()");
|
||||
} else if shortcut == &back {
|
||||
let _ = window.eval("window.history.back()");
|
||||
} else if shortcut == &forward {
|
||||
let _ = window.eval("window.history.forward()");
|
||||
} else if shortcut == &open_settings_shortcut {
|
||||
open_settings(&app_handle);
|
||||
}
|
||||
}
|
||||
|
||||
if shortcut == &new_window_shortcut {
|
||||
trigger_new_window(&app_handle);
|
||||
} else if shortcut == &show_app {
|
||||
focus_main_window(&app_handle);
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
|
||||
TRAY_MENU_OPEN_APP_ID,
|
||||
"Open Onyx",
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+Space"),
|
||||
None::<&str>,
|
||||
)?;
|
||||
let open_chat = MenuItem::with_id(
|
||||
app,
|
||||
@@ -666,7 +598,6 @@ fn main() {
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
@@ -698,11 +629,6 @@ fn main() {
|
||||
.setup(move |app| {
|
||||
let app_handle = app.handle();
|
||||
|
||||
// Setup global shortcuts
|
||||
if let Err(e) = setup_shortcuts(&app_handle) {
|
||||
eprintln!("Failed to setup shortcuts: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = setup_app_menu(&app_handle) {
|
||||
eprintln!("Failed to setup menu: {}", e);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@
|
||||
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background-900: #1a1a1a;
|
||||
--background-800: #262626;
|
||||
--text-light-05: rgba(255, 255, 255, 0.95);
|
||||
--text-light-03: rgba(255, 255, 255, 0.6);
|
||||
--white-10: rgba(255, 255, 255, 0.08);
|
||||
--white-15: rgba(255, 255, 255, 0.12);
|
||||
--white-20: rgba(255, 255, 255, 0.15);
|
||||
--white-30: rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
@@ -30,7 +41,11 @@
|
||||
|
||||
body {
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--background-900) 0%,
|
||||
var(--background-800) 100%
|
||||
);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
display: flex;
|
||||
@@ -39,6 +54,9 @@
|
||||
padding: 20px;
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
color 0.3s ease;
|
||||
}
|
||||
|
||||
.titlebar {
|
||||
@@ -69,16 +87,19 @@
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
background: var(--background-800);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
border 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .settings-panel {
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
@@ -93,17 +114,19 @@
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
background: var(--background-900);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
overflow: hidden;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-icon svg {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
color: #000;
|
||||
color: var(--text-light-05);
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
@@ -134,9 +157,10 @@
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
background: var(--background-900);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
@@ -176,7 +200,7 @@
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-800);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
@@ -186,8 +210,8 @@
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-900);
|
||||
box-shadow: 0 0 0 2px var(--white-10);
|
||||
}
|
||||
|
||||
.input-field::placeholder {
|
||||
@@ -231,7 +255,7 @@
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
background-color: var(--white-15);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
@@ -243,14 +267,18 @@
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
background-color: var(--background-800);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.dark .toggle-slider:before {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
background-color: var(--white-30);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
@@ -288,14 +316,15 @@
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
background: var(--white-10);
|
||||
border: 1px solid var(--white-15);
|
||||
border-radius: 4px;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
@@ -372,10 +401,34 @@
|
||||
const errorMessage = document.getElementById("errorMessage");
|
||||
const saveBtn = document.getElementById("saveBtn");
|
||||
|
||||
// Theme detection based on system preferences
|
||||
function applySystemTheme() {
|
||||
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
|
||||
|
||||
function updateTheme(e) {
|
||||
if (e.matches) {
|
||||
document.documentElement.classList.add("dark");
|
||||
document.body.classList.add("dark");
|
||||
} else {
|
||||
document.documentElement.classList.remove("dark");
|
||||
document.body.classList.remove("dark");
|
||||
}
|
||||
}
|
||||
|
||||
// Apply initial theme
|
||||
updateTheme(darkModeQuery);
|
||||
|
||||
// Listen for changes
|
||||
darkModeQuery.addEventListener("change", updateTheme);
|
||||
}
|
||||
|
||||
function showSettings() {
|
||||
document.body.classList.add("show-settings");
|
||||
}
|
||||
|
||||
// Apply system theme immediately
|
||||
applySystemTheme();
|
||||
|
||||
// Initialize the app
|
||||
async function init() {
|
||||
try {
|
||||
|
||||
@@ -113,6 +113,23 @@
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
function updateTitleBarTheme(isDark) {
|
||||
const titleBar = document.getElementById(TITLEBAR_ID);
|
||||
if (!titleBar) return;
|
||||
|
||||
if (isDark) {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
|
||||
} else {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
|
||||
}
|
||||
}
|
||||
|
||||
function buildTitleBar() {
|
||||
const titleBar = document.createElement("div");
|
||||
titleBar.id = TITLEBAR_ID;
|
||||
@@ -134,6 +151,11 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Apply initial styles matching current theme
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
// Apply styles matching Onyx design system with translucent glass effect
|
||||
titleBar.style.cssText = `
|
||||
position: fixed;
|
||||
@@ -156,8 +178,12 @@
|
||||
-webkit-backdrop-filter: blur(18px) saturate(180%);
|
||||
-webkit-app-region: drag;
|
||||
padding: 0 12px;
|
||||
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
|
||||
`;
|
||||
|
||||
// Apply correct theme
|
||||
updateTitleBarTheme(isDark);
|
||||
|
||||
return titleBar;
|
||||
}
|
||||
|
||||
@@ -168,6 +194,11 @@
|
||||
|
||||
const existing = document.getElementById(TITLEBAR_ID);
|
||||
if (existing?.parentElement === document.body) {
|
||||
// Update theme on existing titlebar
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -178,6 +209,14 @@
|
||||
const titleBar = buildTitleBar();
|
||||
document.body.insertBefore(titleBar, document.body.firstChild);
|
||||
injectStyles();
|
||||
|
||||
// Ensure theme is applied immediately after mount
|
||||
setTimeout(() => {
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function syncViewportHeight() {
|
||||
@@ -194,9 +233,66 @@
|
||||
}
|
||||
}
|
||||
|
||||
function observeThemeChanges() {
|
||||
let lastKnownTheme = null;
|
||||
|
||||
function checkAndUpdateTheme() {
|
||||
// Check both html and body for dark class (some apps use body)
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
if (lastKnownTheme !== isDark) {
|
||||
lastKnownTheme = isDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}
|
||||
}
|
||||
|
||||
// Immediate check on setup
|
||||
checkAndUpdateTheme();
|
||||
|
||||
// Watch for theme changes on the HTML element
|
||||
const themeObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
|
||||
themeObserver.observe(document.documentElement, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
|
||||
// Also observe body if it exists
|
||||
if (document.body) {
|
||||
const bodyObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
bodyObserver.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
}
|
||||
|
||||
// Also check periodically in case classList is manipulated directly
|
||||
// or the theme loads asynchronously after page load
|
||||
const intervalId = setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 300);
|
||||
|
||||
// Clean up after 30 seconds once theme should be stable
|
||||
setTimeout(() => {
|
||||
clearInterval(intervalId);
|
||||
// But keep checking every 2 seconds for manual theme changes
|
||||
setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 2000);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
function init() {
|
||||
mountTitleBar();
|
||||
syncViewportHeight();
|
||||
observeThemeChanges();
|
||||
|
||||
window.addEventListener("resize", syncViewportHeight, { passive: true });
|
||||
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
|
||||
passive: true,
|
||||
|
||||
@@ -119,7 +119,7 @@ backend = [
|
||||
"shapely==2.0.6",
|
||||
"stripe==10.12.0",
|
||||
"urllib3==2.6.3",
|
||||
"mistune==0.8.4",
|
||||
"mistune==3.2.0",
|
||||
"sendgrid==6.12.5",
|
||||
"exa_py==1.15.4",
|
||||
"braintrust==0.3.9",
|
||||
@@ -142,7 +142,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.4.0",
|
||||
"onyx-devtools==0.6.2",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
26
uv.lock
generated
26
uv.lock
generated
@@ -3897,11 +3897,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mistune"
|
||||
version = "0.8.4"
|
||||
version = "3.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2d/a4/509f6e7783ddd35482feda27bc7f72e65b5e7dc910eca4ab2164daf9c577/mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e", size = 58322, upload-time = "2018-10-11T06:59:27.908Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/09/ec/4b43dae793655b7d8a25f76119624350b4d65eb663459eb9603d7f1f0345/mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4", size = 16220, upload-time = "2018-10-11T06:59:26.044Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4766,7 +4766,7 @@ requires-dist = [
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
|
||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
|
||||
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.25.0" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==0.8.4" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
|
||||
@@ -4775,7 +4775,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.5.9" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.4.0" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.2" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4878,20 +4878,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.4.0"
|
||||
version = "0.6.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/d8/f68d15c12d27d4525d10697ac7e2d67d6122fb59ccab219afb2973bc33ad/onyx_devtools-0.4.0-py3-none-any.whl", hash = "sha256:3eb821bce7ec8651d57e937d4d8483e1c2c4bc51df8cbab2dbcc05e3740ec96c", size = 2870841, upload-time = "2026-01-23T04:44:32.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/04/6376342389494b51fd89e554dfdaf0d3809b8d1473bc9b72abd2d7dba21e/onyx_devtools-0.4.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:144e518abad3031ffef189445a69356fca1da2a4fb40c7b8431550133bfc4eef", size = 2890308, upload-time = "2026-01-23T04:44:37.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/c1/859b32fb3eff7e67179d971ace36313ae64e7fc9a242b45e606138b0041f/onyx_devtools-0.4.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0cc74d561f08a9c894adf8de79855b4fc72eb70e823a75e29db7f625ad366bd7", size = 2696160, upload-time = "2026-01-23T04:44:30.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/1b/f1e3f574e9917779d22e3fcb28f8ac1888c250e7452a523f64a6ab8a1759/onyx_devtools-0.4.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:d69de76a97d7f9ff8c473afffbf544a65265645d726f3d70cc12dbbd7e364222", size = 2602134, upload-time = "2026-01-23T04:44:31.716Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/4a/a5d11640fdc23c9bf0e8617ce13793a587e49a64be2d20badf7e9b045e0a/onyx_devtools-0.4.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:fa84980ce8830e35432831aadc19ff465dbc723605aa80c50e0debc58457b70f", size = 2870864, upload-time = "2026-01-23T04:44:31.5Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/9f/6a7e02fbf47bcaea4d02b0ed92bea6e2c09408be7654fb3b57a1ba9863f2/onyx_devtools-0.4.0-py3-none-win_amd64.whl", hash = "sha256:8451efe3e137157696decf8b60a19fb3f0c52ae9f2d9b7c5bc6e667900e7c61e", size = 2953545, upload-time = "2026-01-23T04:44:38.11Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/42/f7a5b99ade06d215fb99de41181d51a9a984f83afb15afa15ce79ecab635/onyx_devtools-0.4.0-py3-none-win_arm64.whl", hash = "sha256:53a5942c922d7049650e934c43f9c057d046f8d53bc68935ebf7e93baa29afc3", size = 2665984, upload-time = "2026-01-23T04:44:29.399Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/20/d9f6089616044b0fb6e097cbae82122de24f3acd97820be4868d5c28ee3f/onyx_devtools-0.6.2-py3-none-any.whl", hash = "sha256:e48d14695d39d62ec3247a4c76ea56604bc5fb635af84c4ff3e9628bcc67b4fb", size = 3785941, upload-time = "2026-02-25T22:33:43.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/f5/f754a717f6b011050eb52ef09895cfa2f048f567f4aa3d5e0f773657dea4/onyx_devtools-0.6.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:505f9910a04868ab62d99bb483dc37c9f4ad94fa80e6ac0e6a10b86351c31420", size = 3832182, upload-time = "2026-02-25T22:33:43.283Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/35/6e653398c62078e87ebb0d03dc944df6691d92ca427c92867309d2d803b7/onyx_devtools-0.6.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:edec98e3acc0fa22cf9102c2070409ea7bcf99d7ded72bd8cb184ece8171c36a", size = 3576948, upload-time = "2026-02-25T22:33:42.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/97/cff707c5c3d2acd714365b1023f0100676abc99816a29558319e8ef01d5f/onyx_devtools-0.6.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97abab61216866cdccd8c0a7e27af328776083756ce4fb57c4bd723030449e3b", size = 3439359, upload-time = "2026-02-25T22:33:44.684Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/98/3b768d18e5599178834b966b447075626d224e048d6eb264d89d19abacb4/onyx_devtools-0.6.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:681b038ab6f1457409d14b2490782c7a8014fc0f0f1b9cd69bb2b7199f99aef1", size = 3785959, upload-time = "2026-02-25T22:33:44.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/38/9b047f9e61c14ccf22b8f386c7a57da3965f90737453f3a577a97da45cdf/onyx_devtools-0.6.2-py3-none-win_amd64.whl", hash = "sha256:a2063be6be104b50a7538cf0d26c7f7ab9159d53327dd6f3e91db05d793c95f3", size = 3878776, upload-time = "2026-02-25T22:33:45.229Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/0f/742f644bae84f5f8f7b500094a2f58da3ff8027fc739944622577e2e2850/onyx_devtools-0.6.2-py3-none-win_arm64.whl", hash = "sha256:00fb90a49a15c932b5cacf818b1b4918e5b5c574bde243dc1828b57690dd5046", size = 3501112, upload-time = "2026-02-25T22:33:41.512Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowDownDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowDownDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowLeftDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowLeftDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowRightDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowRightDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import type { SVGProps } from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBracketCurly = (props: SVGProps<SVGSVGElement>) => (
|
||||
const SvgBracketCurly = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 15 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgClaude = (props: IconProps) => {
|
||||
const SvgClaude = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgClipboard = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgClipboard = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgCornerRightUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgCornerRightUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const OnyxLogo = ({
|
||||
width = 24,
|
||||
height = 24,
|
||||
className,
|
||||
...props
|
||||
}: IconProps) => (
|
||||
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={width}
|
||||
height={height}
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 56 56"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={className}
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
@@ -23,4 +17,4 @@ const OnyxLogo = ({
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default OnyxLogo;
|
||||
export default SvgOnyxLogo;
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgOpenAI = (props: IconProps) => {
|
||||
const SvgOpenAI = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgTwoLineSmall = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgTwoLineSmall = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgUserPlus = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgWallet = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgWallet = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -300,13 +300,7 @@ export default function Page() {
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Default Assistant"
|
||||
icon={
|
||||
<SvgOnyxLogo
|
||||
width={32}
|
||||
height={32}
|
||||
className="my-auto stroke-text-04"
|
||||
/>
|
||||
}
|
||||
icon={<SvgOnyxLogo size={32} className="my-auto stroke-text-04" />}
|
||||
/>
|
||||
<DefaultAssistantConfig />
|
||||
</>
|
||||
|
||||
@@ -131,10 +131,15 @@ export function CustomForm({
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(
|
||||
values.custom_config_list
|
||||
),
|
||||
|
||||
@@ -39,6 +39,8 @@ interface OllamaFormValues extends BaseLLMFormValues {
|
||||
interface OllamaFormContentProps {
|
||||
formikProps: FormikProps<OllamaFormValues>;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: () => void;
|
||||
@@ -49,15 +51,14 @@ interface OllamaFormContentProps {
|
||||
function OllamaFormContent({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
testError,
|
||||
mutate,
|
||||
onClose,
|
||||
isFormValid,
|
||||
}: OllamaFormContentProps) {
|
||||
const [availableModels, setAvailableModels] = useState<ModelConfiguration[]>(
|
||||
existingLlmProvider?.model_configurations || []
|
||||
);
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -70,16 +71,25 @@ function OllamaFormContent({
|
||||
.then((data) => {
|
||||
if (data.error) {
|
||||
console.error("Error fetching models:", data.error);
|
||||
setAvailableModels([]);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setAvailableModels(data.models);
|
||||
setFetchedModels(data.models);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingModels(false);
|
||||
});
|
||||
}
|
||||
}, [formikProps.values.api_base]);
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
existingLlmProvider?.name,
|
||||
setFetchedModels,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
@@ -99,7 +109,7 @@ function OllamaFormContent({
|
||||
/>
|
||||
|
||||
<DisplayModels
|
||||
modelConfigurations={availableModels}
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
|
||||
isLoading={isLoadingModels}
|
||||
@@ -125,6 +135,8 @@ export function OllamaForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Ollama"
|
||||
@@ -189,7 +201,10 @@ export function OllamaForm({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
@@ -205,6 +220,8 @@ export function OllamaForm({
|
||||
<OllamaFormContent
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
testError={testError}
|
||||
mutate={mutate}
|
||||
|
||||
@@ -68,11 +68,7 @@ export const WebProviderSetupModal = memo(
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo
|
||||
width={24}
|
||||
height={24}
|
||||
className="text-text-04 shrink-0"
|
||||
/>
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1168,7 +1168,7 @@ export default function Page() {
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={16} height={16} />
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
@@ -1381,7 +1381,7 @@ export default function Page() {
|
||||
} logo`,
|
||||
fallback:
|
||||
selectedContentProviderType === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={24} height={24} className="text-text-05" />
|
||||
<SvgOnyxLogo size={24} className="text-text-05" />
|
||||
) : undefined,
|
||||
size: 24,
|
||||
containerSize: 28,
|
||||
|
||||
@@ -455,9 +455,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
/>
|
||||
)}
|
||||
|
||||
<BackButton
|
||||
behaviorOverride={() => router.push("/admin/indexing/status")}
|
||||
/>
|
||||
<BackButton />
|
||||
<div
|
||||
className="flex
|
||||
items-center
|
||||
|
||||
@@ -25,7 +25,6 @@ import { useDocumentSets } from "@/lib/hooks/useDocumentSets";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { ChatPopup } from "@/app/chat/components/ChatPopup";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_ID } from "@/app/chat/components/tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import NoAssistantModal from "@/components/modals/NoAssistantModal";
|
||||
import TextView from "@/components/chat/TextView";
|
||||
@@ -382,9 +381,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
if (liveAssistant) {
|
||||
return liveAssistant.tools.some(
|
||||
(tool) => tool.in_code_tool_id === SEARCH_TOOL_ID
|
||||
);
|
||||
return personaIncludesRetrieval(liveAssistant);
|
||||
}
|
||||
return false;
|
||||
}, [liveAssistant]);
|
||||
|
||||
@@ -63,7 +63,7 @@ export type ProjectDetails = {
|
||||
};
|
||||
|
||||
export async function fetchProjects(): Promise<Project[]> {
|
||||
const response = await fetch("/api/user/projects/");
|
||||
const response = await fetch("/api/user/projects");
|
||||
if (!response.ok) {
|
||||
handleRequestError("Fetch projects", response);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
.card {
|
||||
@apply rounded-16 w-full;
|
||||
@apply rounded-16 w-full overflow-clip;
|
||||
}
|
||||
|
||||
.card[data-variant="primary"] {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { User } from "@/lib/types";
|
||||
import { FiPlus, FiX } from "react-icons/fi";
|
||||
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
|
||||
import { UsersIcon } from "@/components/icons/icons";
|
||||
import { FiX } from "react-icons/fi";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
|
||||
interface UserEditorProps {
|
||||
@@ -51,35 +50,27 @@ export const UserEditor = ({
|
||||
</div>
|
||||
|
||||
<div className="flex">
|
||||
<SearchMultiSelectDropdown
|
||||
<InputComboBox
|
||||
placeholder="Search..."
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={(selectedValue) => {
|
||||
setSelectedUserIds([
|
||||
...Array.from(new Set([...selectedUserIds, selectedValue])),
|
||||
]);
|
||||
}}
|
||||
options={allUsers
|
||||
.filter(
|
||||
(user) =>
|
||||
!selectedUserIds.includes(user.id) &&
|
||||
!existingUsers.map((user) => user.id).includes(user.id)
|
||||
)
|
||||
.map((user) => {
|
||||
return {
|
||||
name: user.email,
|
||||
value: user.id,
|
||||
};
|
||||
})}
|
||||
onSelect={(option) => {
|
||||
setSelectedUserIds([
|
||||
...Array.from(
|
||||
new Set([...selectedUserIds, option.value as string])
|
||||
),
|
||||
]);
|
||||
}}
|
||||
itemComponent={({ option }) => (
|
||||
<div className="flex px-4 py-2.5 cursor-pointer hover:bg-accent-background-hovered">
|
||||
<UsersIcon className="mr-2 my-auto" />
|
||||
{option.name}
|
||||
<div className="ml-auto my-auto">
|
||||
<FiPlus />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
.map((user) => ({
|
||||
label: user.email,
|
||||
value: user.id,
|
||||
}))}
|
||||
strict
|
||||
leftSearchIcon
|
||||
/>
|
||||
{onSubmit && (
|
||||
<Button
|
||||
|
||||
@@ -1,21 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
ChangeEvent,
|
||||
FC,
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
JSX,
|
||||
} from "react";
|
||||
import { ChevronDownIcon } from "./icons/icons";
|
||||
import { forwardRef, useEffect, useRef, useState, JSX } from "react";
|
||||
import { FiCheck, FiChevronDown, FiInfo } from "react-icons/fi";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgPlus } from "@opal/icons";
|
||||
export interface Option<T> {
|
||||
name: string;
|
||||
value: T;
|
||||
@@ -28,230 +16,6 @@ export interface Option<T> {
|
||||
|
||||
export type StringOrNumberOption = Option<string | number>;
|
||||
|
||||
function StandardDropdownOption<T>({
|
||||
index,
|
||||
option,
|
||||
handleSelect,
|
||||
}: {
|
||||
index: number;
|
||||
option: Option<T>;
|
||||
handleSelect: (option: Option<T>) => void;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={() => handleSelect(option)}
|
||||
className={`w-full text-left block px-4 py-2.5 text-sm bg-white dark:bg-neutral-800 hover:bg-neutral-50 dark:hover:bg-neutral-700 ${
|
||||
index !== 0 ? "border-t border-neutral-200 dark:border-neutral-700" : ""
|
||||
}`}
|
||||
role="menuitem"
|
||||
>
|
||||
<p className="font-medium text-xs text-neutral-900 dark:text-neutral-100">
|
||||
{option.name}
|
||||
</p>
|
||||
{option.description && (
|
||||
<p className="text-xs text-neutral-500 dark:text-neutral-400">
|
||||
{option.description}
|
||||
</p>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
export function SearchMultiSelectDropdown({
|
||||
options,
|
||||
onSelect,
|
||||
itemComponent,
|
||||
onCreate,
|
||||
onDelete,
|
||||
onSearchTermChange,
|
||||
initialSearchTerm = "",
|
||||
allowCustomValues = false,
|
||||
}: {
|
||||
options: StringOrNumberOption[];
|
||||
onSelect: (selected: StringOrNumberOption) => void;
|
||||
itemComponent?: (props: { option: StringOrNumberOption }) => JSX.Element;
|
||||
onCreate?: (name: string) => void;
|
||||
onDelete?: (name: string) => void;
|
||||
onSearchTermChange?: (term: string) => void;
|
||||
initialSearchTerm?: string;
|
||||
allowCustomValues?: boolean;
|
||||
}) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [searchTerm, setSearchTerm] = useState(initialSearchTerm);
|
||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleSelect = (option: StringOrNumberOption) => {
|
||||
onSelect(option);
|
||||
setIsOpen(false);
|
||||
setSearchTerm(""); // Clear search term after selection
|
||||
};
|
||||
|
||||
const filteredOptions = options.filter((option) =>
|
||||
option.name.toLowerCase().includes(searchTerm.toLowerCase())
|
||||
);
|
||||
|
||||
// Handle selecting a custom value not in the options list
|
||||
const handleCustomValueSelect = () => {
|
||||
if (allowCustomValues && searchTerm.trim() !== "") {
|
||||
const customOption: StringOrNumberOption = {
|
||||
name: searchTerm,
|
||||
value: searchTerm,
|
||||
};
|
||||
onSelect(customOption);
|
||||
setIsOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
dropdownRef.current &&
|
||||
!dropdownRef.current.contains(event.target as Node)
|
||||
) {
|
||||
// If allowCustomValues is enabled and there's text in the search field,
|
||||
// treat clicking outside as selecting the custom value
|
||||
if (allowCustomValues && searchTerm.trim() !== "") {
|
||||
handleCustomValueSelect();
|
||||
}
|
||||
setIsOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("mousedown", handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handleClickOutside);
|
||||
};
|
||||
}, [allowCustomValues, searchTerm]);
|
||||
|
||||
useEffect(() => {
|
||||
setSearchTerm(initialSearchTerm);
|
||||
}, [initialSearchTerm]);
|
||||
|
||||
return (
|
||||
<div className="relative text-left w-full" ref={dropdownRef}>
|
||||
<div>
|
||||
<input
|
||||
type="text"
|
||||
placeholder={
|
||||
allowCustomValues ? "Search or enter custom value..." : "Search..."
|
||||
}
|
||||
value={searchTerm}
|
||||
onChange={(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const newValue = e.target.value;
|
||||
setSearchTerm(newValue);
|
||||
if (onSearchTermChange) {
|
||||
onSearchTermChange(newValue);
|
||||
}
|
||||
if (newValue) {
|
||||
setIsOpen(true);
|
||||
} else {
|
||||
setIsOpen(false);
|
||||
}
|
||||
}}
|
||||
onFocus={() => setIsOpen(true)}
|
||||
onKeyDown={(e) => {
|
||||
if (
|
||||
e.key === "Enter" &&
|
||||
allowCustomValues &&
|
||||
searchTerm.trim() !== ""
|
||||
) {
|
||||
e.preventDefault();
|
||||
handleCustomValueSelect();
|
||||
}
|
||||
}}
|
||||
className="inline-flex justify-between w-full px-4 py-2 text-sm bg-white dark:bg-transparent text-neutral-800 dark:text-neutral-200 border border-neutral-200 dark:border-neutral-700 rounded-md shadow-sm"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
className="absolute top-0 right-0 text-sm h-full px-2 border-l border-neutral-200 dark:border-neutral-700"
|
||||
aria-expanded={isOpen}
|
||||
aria-haspopup="true"
|
||||
onClick={() => setIsOpen(!isOpen)}
|
||||
>
|
||||
<ChevronDownIcon className="my-auto w-4 h-4 text-neutral-600 dark:text-neutral-400" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{isOpen && (
|
||||
<div className="absolute z-10 mt-1 w-full rounded-md shadow-lg bg-white dark:bg-neutral-900 border border-neutral-200 dark:border-neutral-700 max-h-60 overflow-y-auto">
|
||||
<div
|
||||
role="menu"
|
||||
aria-orientation="vertical"
|
||||
aria-labelledby="options-menu"
|
||||
>
|
||||
{filteredOptions.map((option, index) =>
|
||||
itemComponent ? (
|
||||
<div
|
||||
key={option.name}
|
||||
onClick={() => {
|
||||
handleSelect(option);
|
||||
}}
|
||||
>
|
||||
{itemComponent({ option })}
|
||||
</div>
|
||||
) : (
|
||||
<StandardDropdownOption
|
||||
key={index}
|
||||
option={option}
|
||||
index={index}
|
||||
handleSelect={handleSelect}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
|
||||
{allowCustomValues &&
|
||||
searchTerm.trim() !== "" &&
|
||||
!filteredOptions.some(
|
||||
(option) =>
|
||||
option.name.toLowerCase() === searchTerm.toLowerCase()
|
||||
) && (
|
||||
<Button
|
||||
className="w-full"
|
||||
role="menuitem"
|
||||
onClick={handleCustomValueSelect}
|
||||
leftIcon={SvgPlus}
|
||||
>
|
||||
Use "{searchTerm}" as custom value
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{onCreate &&
|
||||
searchTerm.trim() !== "" &&
|
||||
!filteredOptions.some(
|
||||
(option) =>
|
||||
option.name.toLowerCase() === searchTerm.toLowerCase()
|
||||
) && (
|
||||
<>
|
||||
<div className="border-t border-background-300"></div>
|
||||
<Button
|
||||
className="w-full"
|
||||
role="menuitem"
|
||||
onClick={() => {
|
||||
onCreate(searchTerm);
|
||||
setIsOpen(false);
|
||||
setSearchTerm("");
|
||||
}}
|
||||
leftIcon={SvgPlus}
|
||||
>
|
||||
Create label "{searchTerm}"
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
|
||||
{filteredOptions.length === 0 &&
|
||||
((!onCreate && !allowCustomValues) ||
|
||||
searchTerm.trim() === "") && (
|
||||
<div className="px-4 py-2.5 text-sm text-text-500">
|
||||
No matches found
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export const CustomDropdown = ({
|
||||
children,
|
||||
dropdown,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { FormikProps, ErrorMessage } from "formik";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgX } from "@opal/icons";
|
||||
export type GenericMultiSelectFormType<T extends string> = {
|
||||
@@ -84,15 +84,11 @@ export function GenericMultiSelect<
|
||||
const selectedIds = (formikProps.values[fieldName] as number[]) || [];
|
||||
const selectedItems = items.filter((item) => selectedIds.includes(item.id));
|
||||
|
||||
const handleSelect = (option: { name: string; value: string | number }) => {
|
||||
const handleSelect = (itemId: number) => {
|
||||
if (disabled) return;
|
||||
const currentIds = (formikProps.values[fieldName] as number[]) || [];
|
||||
const numValue =
|
||||
typeof option.value === "string"
|
||||
? parseInt(option.value, 10)
|
||||
: option.value;
|
||||
if (!currentIds.includes(numValue)) {
|
||||
formikProps.setFieldValue(fieldName, [...currentIds, numValue]);
|
||||
if (!currentIds.includes(itemId)) {
|
||||
formikProps.setFieldValue(fieldName, [...currentIds, itemId]);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -118,14 +114,24 @@ export function GenericMultiSelect<
|
||||
)}
|
||||
|
||||
<div className={cn(disabled && "opacity-50 pointer-events-none")}>
|
||||
<SearchMultiSelectDropdown
|
||||
<InputComboBox
|
||||
placeholder="Search..."
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={(selectedValue) => {
|
||||
const numValue = parseInt(selectedValue, 10);
|
||||
if (!isNaN(numValue)) {
|
||||
handleSelect(numValue);
|
||||
}
|
||||
}}
|
||||
options={items
|
||||
.filter((item) => !selectedIds.includes(item.id))
|
||||
.map((item) => ({
|
||||
name: item.name,
|
||||
value: item.id,
|
||||
label: item.name,
|
||||
value: String(item.id),
|
||||
}))}
|
||||
onSelect={handleSelect}
|
||||
strict
|
||||
leftSearchIcon
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -619,6 +619,9 @@ export function FederatedConnectorForm({
|
||||
);
|
||||
}
|
||||
|
||||
const channelInputPlaceholder =
|
||||
"Type channel name or regex pattern and press Enter";
|
||||
|
||||
return (
|
||||
<>
|
||||
{Object.entries(formState.configurationSchema).map(
|
||||
@@ -688,9 +691,9 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
}}
|
||||
placeholder={
|
||||
fieldSpec.example &&
|
||||
Array.isArray(fieldSpec.example)
|
||||
? `e.g., ${fieldSpec.example.join(", ")}`
|
||||
fieldKey === "channels" ||
|
||||
fieldKey === "exclude_channels"
|
||||
? channelInputPlaceholder
|
||||
: "Type and press Enter to add an item"
|
||||
}
|
||||
disabled={disableSlackChannelInput(fieldKey)}
|
||||
|
||||
@@ -651,6 +651,7 @@ export function useLlmManager(
|
||||
|
||||
if (modelName) {
|
||||
const model = parseLlmDescriptor(modelName);
|
||||
// If we have no parsed modelName, try to find the provider by the raw modelName string
|
||||
if (!(model.modelName && model.modelName.length > 0)) {
|
||||
const provider = llmProviders.find((p) =>
|
||||
p.model_configurations
|
||||
@@ -666,14 +667,36 @@ export function useLlmManager(
|
||||
}
|
||||
}
|
||||
|
||||
const provider = llmProviders.find((p) =>
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
// If we have parsed provider info, try to find that specific provider.
|
||||
// This ensures we don't incorrectly match a model to the wrong provider
|
||||
// when the same model name exists across multiple providers (e.g., gpt-5 in Azure and OpenAI)
|
||||
if (model.provider && model.provider.length > 0) {
|
||||
const matchingProvider = llmProviders.find(
|
||||
(p) =>
|
||||
p.provider === model.provider &&
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
if (matchingProvider) {
|
||||
return {
|
||||
...model,
|
||||
name: matchingProvider.name,
|
||||
provider: matchingProvider.provider,
|
||||
};
|
||||
}
|
||||
// Provider info was present but not found - fall through to default
|
||||
} else {
|
||||
// Only search by model name when no provider info was parsed
|
||||
const provider = llmProviders.find((p) =>
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
|
||||
if (provider) {
|
||||
return { ...model, provider: provider.provider, name: provider.name };
|
||||
if (provider) {
|
||||
return { ...model, provider: provider.provider, name: provider.name };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -711,12 +734,17 @@ export function useLlmManager(
|
||||
};
|
||||
|
||||
const [temperature, setTemperature] = useState<number>(() => {
|
||||
llmUpdate();
|
||||
|
||||
if (currentChatSession?.current_temperature_override != null) {
|
||||
// Derive Anthropic check from chat session since currentLlm isn't populated yet
|
||||
const sessionModel = currentChatSession.current_alternate_model
|
||||
? parseLlmDescriptor(currentChatSession.current_alternate_model)
|
||||
: null;
|
||||
const isAnthropicModel = sessionModel
|
||||
? isAnthropic(sessionModel.provider, sessionModel.modelName)
|
||||
: false;
|
||||
return Math.min(
|
||||
currentChatSession.current_temperature_override,
|
||||
isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0
|
||||
isAnthropicModel ? 1.0 : 2.0
|
||||
);
|
||||
} else if (
|
||||
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
|
||||
@@ -727,8 +755,20 @@ export function useLlmManager(
|
||||
});
|
||||
|
||||
const maxTemperature = useMemo(() => {
|
||||
return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
|
||||
}, [currentLlm]);
|
||||
// Check currentLlm first, fall back to chat session model if currentLlm isn't populated
|
||||
if (currentLlm.provider) {
|
||||
return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
|
||||
}
|
||||
const sessionModel = currentChatSession?.current_alternate_model
|
||||
? parseLlmDescriptor(currentChatSession.current_alternate_model)
|
||||
: null;
|
||||
if (sessionModel?.provider) {
|
||||
return isAnthropic(sessionModel.provider, sessionModel.modelName)
|
||||
? 1.0
|
||||
: 2.0;
|
||||
}
|
||||
return 2.0; // Default max when no model info available
|
||||
}, [currentLlm, currentChatSession]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
|
||||
@@ -770,13 +810,12 @@ export function useLlmManager(
|
||||
]);
|
||||
|
||||
const updateTemperature = (temperature: number) => {
|
||||
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
|
||||
setTemperature(Math.min(temperature, 1.0));
|
||||
} else {
|
||||
setTemperature(temperature);
|
||||
}
|
||||
const clampedTemp = isAnthropic(currentLlm.provider, currentLlm.modelName)
|
||||
? Math.min(temperature, 1.0)
|
||||
: temperature;
|
||||
setTemperature(clampedTemp);
|
||||
if (chatSession) {
|
||||
updateTemperatureOverrideForChatSession(chatSession.id, temperature);
|
||||
updateTemperatureOverrideForChatSession(chatSession.id, clampedTemp);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1029,12 +1068,16 @@ export function useSourcePreferences({
|
||||
}
|
||||
setSourcesInitialized(true);
|
||||
}
|
||||
}, [availableSources, sourcesInitialized, setSelectedSources]);
|
||||
}, [
|
||||
availableSources,
|
||||
configuredSources,
|
||||
sourcesInitialized,
|
||||
setSelectedSources,
|
||||
]);
|
||||
|
||||
const enableSources = (sources: SourceMetadata[]) => {
|
||||
const allSourceMetadata = getConfiguredSources(availableSources);
|
||||
setSelectedSources([...sources]);
|
||||
persistSourcePreferencesState(sources, allSourceMetadata);
|
||||
persistSourcePreferencesState(sources, configuredSources);
|
||||
};
|
||||
|
||||
const enableAllSources = () => {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
export function useProjects() {
|
||||
const { data, error, mutate } = useSWR<Project[]>(
|
||||
"/api/user/projects/",
|
||||
"/api/user/projects",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
|
||||
@@ -13,7 +13,7 @@ const ConnectionProviderIcon = memo(({ icon }: ConnectionProviderIconProps) => {
|
||||
<SvgArrowExchange className="w-3 h-3 stroke-text-04" />
|
||||
</div>
|
||||
<div className="w-7 h-7 flex items-center justify-center">
|
||||
<SvgOnyxLogo width={24} height={24} className="fill-text-04" />
|
||||
<SvgOnyxLogo size={24} className="fill-text-04" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -353,7 +353,15 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
>
|
||||
<Icon className="w-[1.5rem] h-[1.5rem] stroke-text-04" />
|
||||
{/*
|
||||
The `h-[1.5rem]` and `w-[1.5rem]` were added as backups here.
|
||||
However, prop-resolution technically resolves to choosing classNames over size props, so technically the `size={24}` is the backup.
|
||||
We specify both to be safe.
|
||||
|
||||
# Note
|
||||
1.5rem === 24px
|
||||
*/}
|
||||
<Icon className="stroke-text-04 h-[1.5rem] w-[1.5rem]" size={24} />
|
||||
{onClose && (
|
||||
<div
|
||||
tabIndex={-1}
|
||||
|
||||
@@ -982,9 +982,10 @@ export default function AgentEditorPage({
|
||||
const hasUploadingFiles = values.user_file_ids.some(
|
||||
(fileId: string) => {
|
||||
const status = fileStatusMap.get(fileId);
|
||||
return (
|
||||
status === undefined || status === UserFileStatus.UPLOADING
|
||||
);
|
||||
if (status === undefined) {
|
||||
return fileId.startsWith("temp_");
|
||||
}
|
||||
return status === UserFileStatus.UPLOADING;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user