mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 11:45:47 +00:00
Compare commits
26 Commits
ci_python_
...
v2.11.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be38d3500f | ||
|
|
753a3bc093 | ||
|
|
2ba8fafe78 | ||
|
|
b77b580ebd | ||
|
|
3eee98b932 | ||
|
|
a97eb02fef | ||
|
|
c5061495a2 | ||
|
|
c20b0789ae | ||
|
|
d99848717b | ||
|
|
aaca55c415 | ||
|
|
9d7ffd1e4a | ||
|
|
a249161827 | ||
|
|
e126346a91 | ||
|
|
a96682fa73 | ||
|
|
3920371d56 | ||
|
|
e5a257345c | ||
|
|
a49df511e2 | ||
|
|
d5d2a8a1a6 | ||
|
|
b2f46b264c | ||
|
|
c6ad363fbd | ||
|
|
e313119f9a | ||
|
|
3a2a542a03 | ||
|
|
413aeba4a1 | ||
|
|
46028aa2bb | ||
|
|
454943c4a6 | ||
|
|
87946266de |
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
DOCUMENT_SYNC_PREFIX = "documentsync"
|
||||
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
|
||||
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
|
||||
r.delete(DOCUMENT_SYNC_FENCE_KEY)
|
||||
return
|
||||
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
|
||||
|
||||
# Add to the tracking taskset in Redis BEFORE creating the celery task
|
||||
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
|
||||
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
|
||||
|
||||
# Create the Celery task
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -85,10 +85,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -361,21 +357,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -341,6 +341,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_to_metadata_value(value: Any) -> str | list[str]:
|
||||
"""Convert a Salesforce field value to a valid metadata value.
|
||||
|
||||
Document metadata expects str | list[str], but Salesforce returns
|
||||
various types (bool, float, int, etc.). This function ensures all
|
||||
values are properly converted to strings.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
return str(value)
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
# # gc.collect()
|
||||
# return all_types
|
||||
|
||||
def _yield_doc_batches(
|
||||
self,
|
||||
sf_db: OnyxSalesforceSQLite,
|
||||
type_to_processed: dict[str, int],
|
||||
changed_ids_to_type: dict[str, str],
|
||||
parent_types: set[str],
|
||||
increment_parents_changed: Callable[[], None],
|
||||
) -> GenerateDocumentsOutput:
|
||||
""" """
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
parent_object.data[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
increment_parents_changed()
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def _full_sync(
|
||||
self,
|
||||
temp_dir: str,
|
||||
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
examined_ids = 0
|
||||
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
f"records={num_records}"
|
||||
)
|
||||
|
||||
# yield an empty list to keep the connector alive
|
||||
yield docs_to_yield
|
||||
|
||||
new_ids = sf_db.update_from_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=ctx.parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = (
|
||||
type_to_processed.get(parent_type, 0) + 1
|
||||
)
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
def increment_parents_changed() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
yield from self._yield_doc_batches(
|
||||
sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
ctx.parent_types,
|
||||
increment_parents_changed,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
|
||||
if sf_attribute in record:
|
||||
doc.metadata[canonical_attribute] = record[sf_attribute]
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
record[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
return return_context
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if MULTI_TENANT:
|
||||
# if multi tenant, we cannot expect the sqlite db to be cached/present
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._full_sync(temp_dir)
|
||||
|
||||
# nuke the db since we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
|
||||
os.remove(sqlite_db_path)
|
||||
return self._full_sync(BASE_DATA_PATH)
|
||||
# Always use a temp directory for SQLite - the database is rebuilt
|
||||
# from scratch each time via CSV downloads, so there's no caching benefit
|
||||
# from persisting it. Using temp dirs also avoids collisions between
|
||||
# multiple CC pairs and eliminates stale WAL/SHM file issues.
|
||||
# TODO(evan): make this thing checkpointed and persist/load db from filestore
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield from self._full_sync(temp_dir)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source will synchronize updated parent objects one by one."""
|
||||
|
||||
if start == 0:
|
||||
# nuke the db if we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(
|
||||
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
|
||||
)
|
||||
os.remove(sqlite_db_path)
|
||||
|
||||
return self._delta_sync(BASE_DATA_PATH, start, end)
|
||||
|
||||
# Always use a temp directory - see comment in load_from_state()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
yield from self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
SQLITE_DISK_IO_ERROR = "disk I/O error"
|
||||
|
||||
|
||||
class OnyxSalesforceSQLite:
|
||||
"""Notes on context management using 'with self.conn':
|
||||
|
||||
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
|
||||
def apply_schema(self) -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist.
|
||||
|
||||
Non-destructive operation.
|
||||
Non-destructive operation. If a disk I/O error is encountered (often due
|
||||
to stale WAL/SHM files from a previous crash), this method will attempt
|
||||
to recover by removing the corrupted files and recreating the database.
|
||||
"""
|
||||
try:
|
||||
self._apply_schema_impl()
|
||||
except sqlite3.OperationalError as e:
|
||||
if SQLITE_DISK_IO_ERROR not in str(e):
|
||||
raise
|
||||
|
||||
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
|
||||
self._recover_from_corruption()
|
||||
self._apply_schema_impl()
|
||||
|
||||
def _recover_from_corruption(self) -> None:
|
||||
"""Recover from SQLite corruption by removing all database files and reconnecting."""
|
||||
logger.info(f"Removing corrupted SQLite files: {self.filename}")
|
||||
|
||||
# Close existing connection
|
||||
self.close()
|
||||
|
||||
# Remove all SQLite files (main db, WAL, SHM)
|
||||
remove_sqlite_db_files(self.filename)
|
||||
|
||||
# Reconnect - this will create a fresh database
|
||||
self.connect()
|
||||
|
||||
logger.info("SQLite recovery complete, fresh database created")
|
||||
|
||||
def _apply_schema_impl(self) -> None:
|
||||
"""Internal implementation of apply_schema."""
|
||||
if self._conn is None:
|
||||
raise RuntimeError("Database connection is closed")
|
||||
|
||||
|
||||
@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
|
||||
return os.path.join(directory, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def remove_sqlite_db_files(db_path: str) -> None:
|
||||
"""Remove SQLite database and all associated files (WAL, SHM).
|
||||
|
||||
SQLite in WAL mode creates additional files:
|
||||
- .sqlite-wal: Write-ahead log
|
||||
- .sqlite-shm: Shared memory file
|
||||
|
||||
If these files become stale (e.g., after a crash), they can cause
|
||||
'disk I/O error' when trying to open the database. This function
|
||||
ensures all related files are removed.
|
||||
"""
|
||||
files_to_remove = [
|
||||
db_path,
|
||||
f"{db_path}-wal",
|
||||
f"{db_path}-shm",
|
||||
]
|
||||
for file_path in files_to_remove:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
# NOTE: only used with shelves, deprecated at this point
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
|
||||
@@ -2933,8 +2933,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
@@ -327,6 +329,15 @@ def delete_user_from_db(
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
# Null out ownership on document sets and personas so they're
|
||||
# preserved for other users instead of being cascade-deleted
|
||||
db_session.query(DocumentSet).filter(
|
||||
DocumentSet.user_id == user_to_delete.id
|
||||
).update({DocumentSet.user_id: None})
|
||||
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
|
||||
@@ -32,6 +32,7 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
@@ -136,6 +137,7 @@ class RedisConnectorDelete:
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -45,6 +45,7 @@ class RedisConnectorPrune:
|
||||
) # connectorpruning_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
@@ -184,6 +185,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
self.redis.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
|
||||
@@ -23,6 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -83,6 +84,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -24,6 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
@@ -97,6 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
redis_client.expire(self.taskset_key, self.TASKSET_TTL)
|
||||
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -84,7 +84,8 @@ def patch_document_set(
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
object_is_public=document_set_update_request.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
try:
|
||||
update_document_set(
|
||||
@@ -125,7 +126,8 @@ def delete_document_set(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
object_is_public=document_set.is_public,
|
||||
object_is_owned_by_user=user and document_set.user_id == user.id,
|
||||
object_is_owned_by_user=user
|
||||
and (document_set.user_id is None or document_set.user_id == user.id),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -47,7 +47,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
@router.get("/", tags=PUBLIC_API_TAGS)
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import user_needs_to_be_verified
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DEV_VERSION_PATTERN
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import STABLE_VERSION_PATTERN
|
||||
@@ -30,13 +31,20 @@ def healthcheck() -> StatusResponse:
|
||||
|
||||
@router.get("/auth/type", tags=PUBLIC_API_TAGS)
|
||||
async def get_auth_type() -> AuthTypeResponse:
|
||||
user_count = await get_user_count()
|
||||
# NOTE: This endpoint is critical for the multi-tenant flow and is hit before there is a tenant context
|
||||
# The reason is this is used during the login flow, but we don't know which tenant the user is supposed to be
|
||||
# associated with until they auth.
|
||||
has_users = True
|
||||
if AUTH_TYPE != AuthType.CLOUD:
|
||||
user_count = await get_user_count()
|
||||
has_users = user_count > 0
|
||||
|
||||
return AuthTypeResponse(
|
||||
auth_type=AUTH_TYPE,
|
||||
requires_verification=user_needs_to_be_verified(),
|
||||
anonymous_user_enabled=anonymous_user_enabled(),
|
||||
password_min_length=PASSWORD_MIN_LENGTH,
|
||||
has_users=user_count > 0,
|
||||
has_users=has_users,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import remove_chat_message_feedback
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -266,7 +267,35 @@ def get_chat_session(
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
try:
|
||||
# If we failed to get a chat session, try to retrieve the session with
|
||||
# less restrictive filters in order to identify what exactly mismatched
|
||||
# so we can bubble up an accurate error code andmessage.
|
||||
existing_chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
is_shared=False,
|
||||
include_deleted=True,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
if not include_deleted and existing_chat_session.deleted:
|
||||
raise HTTPException(status_code=404, detail="Chat session has been deleted")
|
||||
|
||||
if is_shared:
|
||||
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Chat session is not shared"
|
||||
)
|
||||
elif user_id is not None and existing_chat_session.user_id not in (
|
||||
user_id,
|
||||
None,
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
# for chat-seeding: if the session is unassigned, assign it now. This is done here
|
||||
# to avoid another back and forth between FE -> BE before starting the first
|
||||
|
||||
@@ -476,8 +476,8 @@ class ChatSessionManager:
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
|
||||
@@ -31,7 +31,7 @@ class ProjectManager:
|
||||
) -> List[UserProjectSnapshot]:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -56,7 +56,7 @@ class ProjectManager:
|
||||
) -> bool:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/",
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def second_user(admin_user: DATestUser) -> DATestUser:
|
||||
# Ensure admin exists so this new user is created with BASIC role.
|
||||
try:
|
||||
return UserManager.create(name="second_basic_user")
|
||||
except HTTPError as e:
|
||||
response = e.response
|
||||
if response is None:
|
||||
raise
|
||||
if response.status_code not in (400, 409):
|
||||
raise
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
raise
|
||||
detail = payload.get("detail")
|
||||
if not _is_user_already_exists_detail(detail):
|
||||
raise
|
||||
print("Second basic user already exists; logging in instead.")
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("second_basic_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_user_already_exists_detail(detail: object) -> bool:
|
||||
if isinstance(detail, str):
|
||||
normalized = detail.lower()
|
||||
return (
|
||||
"already exists" in normalized
|
||||
or "register_user_already_exists" in normalized
|
||||
)
|
||||
if isinstance(detail, dict):
|
||||
code = detail.get("code")
|
||||
if isinstance(code, str) and code.lower() == "register_user_already_exists":
|
||||
return True
|
||||
message = detail.get("message")
|
||||
if isinstance(message, str) and "already exists" in message.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_chat_session(
|
||||
chat_session_id: str,
|
||||
user: DATestUser,
|
||||
is_shared: bool | None = None,
|
||||
include_deleted: bool | None = None,
|
||||
) -> requests.Response:
|
||||
params: dict[str, str] = {}
|
||||
if is_shared is not None:
|
||||
params["is_shared"] = str(is_shared).lower()
|
||||
if include_deleted is not None:
|
||||
params["include_deleted"] = str(include_deleted).lower()
|
||||
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
|
||||
params=params,
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def _set_sharing_status(
|
||||
chat_session_id: str, sharing_status: str, user: DATestUser
|
||||
) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}",
|
||||
json={"sharing_status": sharing_status},
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_private_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify private sessions are only accessible by the owner and never via share link."""
|
||||
# Create a private chat session owned by basic_user.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can access the private session normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Share link should be forbidden when the session is private.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users cannot access private sessions directly.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users also cannot access private sessions via share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_public_shared_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify shared sessions are accessible only via share link for non-owners."""
|
||||
# Create a private session, then mark it public.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
response = _set_sharing_status(str(chat_session.id), "public", basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can access normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can also access via share link.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Non-owner cannot access without share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Non-owner can access with share link for public sessions.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_deleted_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify deleted sessions return 404, with include_deleted gated by access checks."""
|
||||
# Create and soft-delete a session.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=chat_session, user_performing_action=basic_user
|
||||
)
|
||||
assert deletion_success is True
|
||||
|
||||
# Deleted sessions are not accessible normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Owner can fetch deleted session only with include_deleted.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("deleted") is True
|
||||
|
||||
# Non-owner should be blocked even with include_deleted.
|
||||
response = _get_chat_session(
|
||||
str(chat_session.id), second_user, include_deleted=True
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaLabelManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
reset: None, admin_user: DATestUser
|
||||
) -> None:
|
||||
persona_label = PersonaLabelManager.create(
|
||||
label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert persona_label.id is not None
|
||||
persona = PersonaManager.create(
|
||||
label_ids=[persona_label.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
updated_description = f"{persona.description}-updated"
|
||||
update_request = PersonaUpsertRequest(
|
||||
name=persona.name,
|
||||
description=updated_description,
|
||||
system_prompt=persona.system_prompt or "",
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
users=[],
|
||||
groups=[],
|
||||
label_ids=None,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=update_request.model_dump(mode="json", exclude_none=False),
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
fetched = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
fetched.raise_for_status()
|
||||
fetched_persona = fetched.json()
|
||||
|
||||
assert fetched_persona["description"] == updated_description
|
||||
fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
|
||||
assert persona_label.id in fetched_label_ids
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for Asana connector configuration parsing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_ids,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 123 ", ["123"]),
|
||||
(" 123 , , 456 , ", ["123", "456"]),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_project_ids_normalization(
|
||||
project_ids: str | None, expected: list[str] | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id=" 1153293530468850 ",
|
||||
asana_project_ids=project_ids,
|
||||
asana_team_id=" 1210918501948021 ",
|
||||
)
|
||||
|
||||
assert connector.workspace_id == "1153293530468850"
|
||||
assert connector.project_ids_to_index == expected
|
||||
assert connector.asana_team_id == "1210918501948021"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 1210918501948021 ", "1210918501948021"),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_team_id_normalization(
|
||||
team_id: str | None, expected: str | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id="1153293530468850",
|
||||
asana_project_ids=None,
|
||||
asana_team_id=team_id,
|
||||
)
|
||||
|
||||
assert connector.asana_team_id == expected
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Unit tests for _yield_doc_batches and metadata type conversion in SalesforceConnector."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import _convert_to_metadata_value
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
|
||||
class TestConvertToMetadataValue:
|
||||
"""Tests for the _convert_to_metadata_value helper function."""
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""String values should be returned as-is."""
|
||||
assert _convert_to_metadata_value("hello") == "hello"
|
||||
assert _convert_to_metadata_value("") == ""
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
"""Boolean True should be converted to string 'True'."""
|
||||
assert _convert_to_metadata_value(True) == "True"
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
"""Boolean False should be converted to string 'False'."""
|
||||
assert _convert_to_metadata_value(False) == "False"
|
||||
|
||||
def test_integer_value(self) -> None:
|
||||
"""Integer values should be converted to string."""
|
||||
assert _convert_to_metadata_value(42) == "42"
|
||||
assert _convert_to_metadata_value(0) == "0"
|
||||
assert _convert_to_metadata_value(-100) == "-100"
|
||||
|
||||
def test_float_value(self) -> None:
|
||||
"""Float values should be converted to string."""
|
||||
assert _convert_to_metadata_value(3.14) == "3.14"
|
||||
assert _convert_to_metadata_value(0.0) == "0.0"
|
||||
assert _convert_to_metadata_value(-2.5) == "-2.5"
|
||||
|
||||
def test_list_of_strings(self) -> None:
|
||||
"""List of strings should remain as list of strings."""
|
||||
result = _convert_to_metadata_value(["a", "b", "c"])
|
||||
assert result == ["a", "b", "c"]
|
||||
|
||||
def test_list_of_mixed_types(self) -> None:
|
||||
"""List with mixed types should have all items converted to strings."""
|
||||
result = _convert_to_metadata_value([1, True, 3.14, "text"])
|
||||
assert result == ["1", "True", "3.14", "text"]
|
||||
|
||||
def test_empty_list(self) -> None:
|
||||
"""Empty list should return empty list."""
|
||||
assert _convert_to_metadata_value([]) == []
|
||||
|
||||
|
||||
class TestYieldDocBatches:
|
||||
"""Tests for the _yield_doc_batches method of SalesforceConnector."""
|
||||
|
||||
@pytest.fixture
|
||||
def connector(self) -> SalesforceConnector:
|
||||
"""Create a SalesforceConnector instance with mocked sf_client."""
|
||||
connector = SalesforceConnector(
|
||||
batch_size=10,
|
||||
requested_objects=["Opportunity"],
|
||||
)
|
||||
# Mock the sf_client property
|
||||
mock_sf_client = MagicMock()
|
||||
mock_sf_client.sf_instance = "test.salesforce.com"
|
||||
connector._sf_client = mock_sf_client
|
||||
return connector
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sf_db(self) -> MagicMock:
|
||||
"""Create a mock OnyxSalesforceSQLite object."""
|
||||
return MagicMock()
|
||||
|
||||
def _create_salesforce_object(
|
||||
self,
|
||||
object_id: str,
|
||||
object_type: str,
|
||||
data: dict[str, Any],
|
||||
) -> SalesforceObject:
|
||||
"""Helper to create a SalesforceObject with required fields."""
|
||||
# Ensure required fields are present
|
||||
data.setdefault(ID_FIELD, object_id)
|
||||
data.setdefault(MODIFIED_FIELD, "2024-01-15T10:30:00.000Z")
|
||||
data.setdefault(NAME_FIELD, f"Test {object_type}")
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_metadata_type_conversion_for_opportunity(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that Opportunity metadata fields are properly type-converted."""
|
||||
parent_id = "006bm000006kyDpAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create a parent object with various data types in the fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Test Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"Account": "Acme Corp", # string - should become "account" metadata
|
||||
"FiscalQuarter": 2, # int - should be converted to "2"
|
||||
"FiscalYear": 2024, # int - should be converted to "2024"
|
||||
"IsClosed": False, # bool - should be converted to "False"
|
||||
"StageName": "Prospecting", # string
|
||||
"Type": "New Business", # string
|
||||
"Amount": 50000.50, # float - should be converted to "50000.50"
|
||||
"CloseDate": "2024-06-30", # string
|
||||
"Probability": 75, # int - should be converted to "75"
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z", # string
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
# Setup mock sf_db
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create a mock document that convert_sf_object_to_doc will return
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Test Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
# Track parent changes
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# Call _yield_doc_batches
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify we got one batch with one document
|
||||
assert len(batches) == 1
|
||||
docs = batches[0]
|
||||
assert len(docs) == 1
|
||||
|
||||
doc = docs[0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify metadata type conversions
|
||||
# All values should be strings (or list of strings)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["account"] == "Acme Corp" # string stays string
|
||||
assert doc.metadata["fiscal_quarter"] == "2" # int -> str
|
||||
assert doc.metadata["fiscal_year"] == "2024" # int -> str
|
||||
assert doc.metadata["is_closed"] == "False" # bool -> str
|
||||
assert doc.metadata["stage_name"] == "Prospecting" # string stays string
|
||||
assert doc.metadata["type"] == "New Business" # string stays string
|
||||
assert (
|
||||
doc.metadata["amount"] == "50000.5"
|
||||
) # float -> str (Python drops trailing zeros)
|
||||
assert doc.metadata["close_date"] == "2024-06-30" # string stays string
|
||||
assert doc.metadata["probability"] == "75" # int -> str
|
||||
assert doc.metadata["name"] == "Test Opportunity" # NAME_FIELD
|
||||
|
||||
# Verify parent was counted
|
||||
assert parents_changed == 1
|
||||
assert type_to_processed[parent_type] == 1
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_missing_optional_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing optional metadata fields are not added."""
|
||||
parent_id = "006bm000006kyDqAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# Create parent object with only some fields
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Minimal Opportunity",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"StageName": "Closed Won",
|
||||
# Notably missing: Amount, Probability, FiscalQuarter, etc.
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Minimal Opportunity",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only present fields should be in metadata
|
||||
assert "stage_name" in doc.metadata
|
||||
assert doc.metadata["stage_name"] == "Closed Won"
|
||||
assert "name" in doc.metadata
|
||||
assert doc.metadata["name"] == "Minimal Opportunity"
|
||||
|
||||
# Missing fields should not be in metadata
|
||||
assert "amount" not in doc.metadata
|
||||
assert "probability" not in doc.metadata
|
||||
assert "fiscal_quarter" not in doc.metadata
|
||||
assert "fiscal_year" not in doc.metadata
|
||||
assert "is_closed" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_contact_metadata_fields(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test metadata conversion for Contact object type."""
|
||||
parent_id = "003bm00000EjHCjAAN"
|
||||
parent_type = "Contact"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "John Doe",
|
||||
MODIFIED_FIELD: "2024-02-20T14:00:00.000Z",
|
||||
"Account": "Globex Corp",
|
||||
"CreatedDate": "2024-01-01T00:00:00.000Z",
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="John Doe",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Verify Contact-specific metadata
|
||||
assert doc.metadata["object_type"] == "Contact"
|
||||
assert doc.metadata["account"] == "Globex Corp"
|
||||
assert doc.metadata["created_date"] == "2024-01-01T00:00:00.000Z"
|
||||
assert doc.metadata["last_modified_date"] == "2024-02-20T14:00:00.000Z"
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_no_default_attributes_for_unknown_type(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that unknown object types only get object_type metadata."""
|
||||
parent_id = "001bm00000fd9Z3AAI"
|
||||
parent_type = "CustomObject__c"
|
||||
|
||||
parent_data = {
|
||||
ID_FIELD: parent_id,
|
||||
NAME_FIELD: "Custom Record",
|
||||
MODIFIED_FIELD: "2024-03-01T08:00:00.000Z",
|
||||
"CustomField__c": "custom value",
|
||||
"NumberField__c": 123,
|
||||
}
|
||||
parent_object = self._create_salesforce_object(
|
||||
parent_id, parent_type, parent_data
|
||||
)
|
||||
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = parent_object
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
mock_doc = Document(
|
||||
id=f"SALESFORCE_{parent_id}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier="Custom Record",
|
||||
metadata={},
|
||||
)
|
||||
mock_convert.return_value = mock_doc
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
doc = batches[0][0]
|
||||
assert isinstance(doc, Document)
|
||||
|
||||
# Only object_type should be set for unknown types
|
||||
assert doc.metadata["object_type"] == "CustomObject__c"
|
||||
# Custom fields should NOT be in metadata (not in _DEFAULT_ATTRIBUTES_TO_KEEP)
|
||||
assert "CustomField__c" not in doc.metadata
|
||||
assert "NumberField__c" not in doc.metadata
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_skips_missing_parent_objects(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that missing parent objects are skipped gracefully."""
|
||||
parent_id = "006bm000006kyDrAAI"
|
||||
parent_type = "Opportunity"
|
||||
|
||||
# get_record returns None for missing object
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, parent_id, 1)]
|
||||
)
|
||||
mock_sf_db.get_record.return_value = None
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {parent_id: parent_type}
|
||||
parent_types = {parent_type}
|
||||
|
||||
parents_changed = 0
|
||||
|
||||
def increment() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
increment,
|
||||
)
|
||||
)
|
||||
|
||||
# Should yield one empty batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 0
|
||||
|
||||
# convert_sf_object_to_doc should not have been called
|
||||
mock_convert.assert_not_called()
|
||||
|
||||
# Parents changed should still be 0
|
||||
assert parents_changed == 0
|
||||
|
||||
@patch("onyx.connectors.salesforce.connector.convert_sf_object_to_doc")
|
||||
def test_multiple_documents_batching(
|
||||
self,
|
||||
mock_convert: MagicMock,
|
||||
connector: SalesforceConnector,
|
||||
mock_sf_db: MagicMock,
|
||||
) -> None:
|
||||
"""Test that multiple documents are correctly batched."""
|
||||
# Create 3 parent objects
|
||||
parent_ids = [
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
]
|
||||
parent_type = "Opportunity"
|
||||
|
||||
parent_objects = [
|
||||
self._create_salesforce_object(
|
||||
pid,
|
||||
parent_type,
|
||||
{
|
||||
ID_FIELD: pid,
|
||||
NAME_FIELD: f"Opportunity {i}",
|
||||
MODIFIED_FIELD: "2024-01-15T10:30:00.000Z",
|
||||
"IsClosed": i % 2 == 0, # alternating bool values
|
||||
"Amount": 1000.0 * (i + 1),
|
||||
},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
|
||||
# Setup mock to return all three
|
||||
mock_sf_db.get_changed_parent_ids_by_type.return_value = iter(
|
||||
[(parent_type, pid, i + 1) for i, pid in enumerate(parent_ids)]
|
||||
)
|
||||
mock_sf_db.get_record.side_effect = parent_objects
|
||||
mock_sf_db.file_size = 1024
|
||||
|
||||
# Create mock documents
|
||||
mock_docs = [
|
||||
Document(
|
||||
id=f"SALESFORCE_{pid}",
|
||||
sections=[],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=f"Opportunity {i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, pid in enumerate(parent_ids)
|
||||
]
|
||||
mock_convert.side_effect = mock_docs
|
||||
|
||||
type_to_processed: dict[str, int] = {}
|
||||
changed_ids_to_type = {pid: parent_type for pid in parent_ids}
|
||||
parent_types = {parent_type}
|
||||
|
||||
batches = list(
|
||||
connector._yield_doc_batches(
|
||||
mock_sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
parent_types,
|
||||
lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
# With batch_size=10, all 3 docs should be in one batch
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 3
|
||||
|
||||
# Verify each document has correct metadata
|
||||
for i, doc in enumerate(batches[0]):
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.metadata["object_type"] == "Opportunity"
|
||||
assert doc.metadata["is_closed"] == str(i % 2 == 0)
|
||||
assert doc.metadata["amount"] == str(1000.0 * (i + 1))
|
||||
|
||||
assert type_to_processed[parent_type] == 3
|
||||
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
135
backend/tests/unit/onyx/db/test_delete_user.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.users import delete_user_from_db
|
||||
|
||||
|
||||
def _mock_user(
|
||||
user_id: UUID | None = None, email: str = "test@example.com"
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.oauth_accounts = []
|
||||
return user
|
||||
|
||||
|
||||
def _make_query_chain() -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).delete() and .filter(...).update(...)"""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
return chain
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_nulls_out_document_set_ownership(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Verify DocumentSet.user_id is nulled out (update, not delete)
|
||||
doc_set_chain = query_chains[DocumentSet]
|
||||
doc_set_chain.filter.assert_called()
|
||||
doc_set_chain.filter.return_value.update.assert_called_once_with(
|
||||
{DocumentSet.user_id: None}
|
||||
)
|
||||
|
||||
# Verify Persona.user_id is nulled out (update, not delete)
|
||||
persona_chain = query_chains[Persona]
|
||||
persona_chain.filter.assert_called()
|
||||
persona_chain.filter.return_value.update.assert_called_once_with(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_cleans_up_join_tables(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
db_session = MagicMock()
|
||||
|
||||
query_chains: dict[type, MagicMock] = {}
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model not in query_chains:
|
||||
query_chains[model] = _make_query_chain()
|
||||
return query_chains[model]
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
# Join tables should be deleted (not updated)
|
||||
for model in [DocumentSet__User, Persona__User, User__UserGroup, SamlAccount]:
|
||||
chain = query_chains[model]
|
||||
chain.filter.return_value.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_commits_and_removes_invited(
|
||||
_mock_ee: Any, mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user(email="deleted@example.com")
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_called_once_with(user)
|
||||
db_session.commit.assert_called_once()
|
||||
mock_remove_invited.assert_called_once_with("deleted@example.com")
|
||||
|
||||
|
||||
@patch("onyx.db.users.remove_user_from_invited_users")
|
||||
@patch(
|
||||
"onyx.db.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda **_kwargs: None,
|
||||
)
|
||||
def test_delete_user_deletes_oauth_accounts(
|
||||
_mock_ee: Any, _mock_remove_invited: Any
|
||||
) -> None:
|
||||
user = _mock_user()
|
||||
oauth1 = MagicMock()
|
||||
oauth2 = MagicMock()
|
||||
user.oauth_accounts = [oauth1, oauth2]
|
||||
db_session = MagicMock()
|
||||
db_session.query.return_value = _make_query_chain()
|
||||
|
||||
delete_user_from_db(user, db_session)
|
||||
|
||||
db_session.delete.assert_any_call(oauth1)
|
||||
db_session.delete.assert_any_call(oauth2)
|
||||
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.utils import telemetry as telemetry_utils
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
|
||||
fetch_impl = Mock()
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_not_called()
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
event_telemetry = Mock()
|
||||
fetch_impl = Mock(return_value=event_telemetry)
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_called_once_with(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=telemetry_utils.noop_fallback,
|
||||
)
|
||||
event_telemetry.assert_called_once_with(
|
||||
"user@example.com",
|
||||
MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
{"origin": "web", "tenant_id": "tenant-1"},
|
||||
)
|
||||
3
desktop/.gitignore
vendored
3
desktop/.gitignore
vendored
@@ -22,3 +22,6 @@ npm-debug.log*
|
||||
# Local env files
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Generated files
|
||||
src-tauri/gen/schemas/acl-manifests.json
|
||||
|
||||
96
desktop/src-tauri/Cargo.lock
generated
96
desktop/src-tauri/Cargo.lock
generated
@@ -706,16 +706,6 @@ dependencies = [
|
||||
"typeid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
@@ -993,16 +983,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gethostname"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
@@ -1122,24 +1102,6 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "global-hotkey"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9247516746aa8e53411a0db9b62b0e24efbcf6a76e0ba73e5a91b512ddabed7"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"keyboard-types",
|
||||
"objc2 0.6.3",
|
||||
"objc2-app-kit 0.3.2",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"thiserror 2.0.17",
|
||||
"windows-sys 0.59.0",
|
||||
"x11rb",
|
||||
"xkeysym",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gobject-sys"
|
||||
version = "0.18.0"
|
||||
@@ -1713,12 +1675,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.1"
|
||||
@@ -2248,7 +2204,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
"tauri-plugin-global-shortcut",
|
||||
"tauri-plugin-shell",
|
||||
"tauri-plugin-window-state",
|
||||
"tokio",
|
||||
@@ -2878,19 +2833,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
@@ -3605,21 +3547,6 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-global-shortcut"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "424af23c7e88d05e4a1a6fc2c7be077912f8c76bd7900fd50aa2b7cbf5a2c405"
|
||||
dependencies = [
|
||||
"global-hotkey",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tauri",
|
||||
"tauri-plugin",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-shell"
|
||||
version = "2.3.3"
|
||||
@@ -5021,29 +4948,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
|
||||
dependencies = [
|
||||
"gethostname",
|
||||
"rustix",
|
||||
"x11rb-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb-protocol"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
|
||||
|
||||
[[package]]
|
||||
name = "xkeysym"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
|
||||
@@ -11,7 +11,6 @@ tauri-build = { version = "2.0", features = [] }
|
||||
[dependencies]
|
||||
tauri = { version = "2.0", features = ["macos-private-api", "tray-icon", "image-png"] }
|
||||
tauri-plugin-shell = "2.0"
|
||||
tauri-plugin-global-shortcut = "2.0"
|
||||
tauri-plugin-window-state = "2.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -2354,72 +2354,6 @@
|
||||
"const": "core:window:deny-unminimize",
|
||||
"markdownDescription": "Denies the unminimize command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:default",
|
||||
"markdownDescription": "No features are enabled by default, as we believe\nthe shortcuts can be inherently dangerous and it is\napplication specific if specific shortcuts should be\nregistered or unregistered.\n"
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-is-registered",
|
||||
"markdownDescription": "Enables the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register",
|
||||
"markdownDescription": "Enables the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-register-all",
|
||||
"markdownDescription": "Enables the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister",
|
||||
"markdownDescription": "Enables the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:allow-unregister-all",
|
||||
"markdownDescription": "Enables the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_registered command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-is-registered",
|
||||
"markdownDescription": "Denies the is_registered command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register",
|
||||
"markdownDescription": "Denies the register command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the register_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-register-all",
|
||||
"markdownDescription": "Denies the register_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister",
|
||||
"markdownDescription": "Denies the unregister command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unregister_all command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "global-shortcut:deny-unregister-all",
|
||||
"markdownDescription": "Denies the unregister_all command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`",
|
||||
"type": "string",
|
||||
|
||||
@@ -20,7 +20,6 @@ use tauri::Wry;
|
||||
use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
@@ -448,73 +447,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Shortcuts Setup
|
||||
// ============================================================================
|
||||
|
||||
fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let new_chat = Shortcut::new(Some(Modifiers::SUPER), Code::KeyN);
|
||||
let reload = Shortcut::new(Some(Modifiers::SUPER), Code::KeyR);
|
||||
let back = Shortcut::new(Some(Modifiers::SUPER), Code::BracketLeft);
|
||||
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
|
||||
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
|
||||
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
|
||||
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
|
||||
let app_handle = app.clone();
|
||||
|
||||
// Avoid hijacking the system-wide Cmd+R on macOS.
|
||||
#[cfg(target_os = "macos")]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let shortcuts = [
|
||||
new_chat,
|
||||
reload,
|
||||
back,
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
app.global_shortcut().on_shortcuts(
|
||||
shortcuts,
|
||||
move |_app, shortcut, _event| {
|
||||
if shortcut == &new_chat {
|
||||
trigger_new_chat(&app_handle);
|
||||
}
|
||||
|
||||
if let Some(window) = app_handle.get_webview_window("main") {
|
||||
if shortcut == &reload {
|
||||
let _ = window.eval("window.location.reload()");
|
||||
} else if shortcut == &back {
|
||||
let _ = window.eval("window.history.back()");
|
||||
} else if shortcut == &forward {
|
||||
let _ = window.eval("window.history.forward()");
|
||||
} else if shortcut == &open_settings_shortcut {
|
||||
open_settings(&app_handle);
|
||||
}
|
||||
}
|
||||
|
||||
if shortcut == &new_window_shortcut {
|
||||
trigger_new_window(&app_handle);
|
||||
} else if shortcut == &show_app {
|
||||
focus_main_window(&app_handle);
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -574,7 +506,7 @@ fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
|
||||
TRAY_MENU_OPEN_APP_ID,
|
||||
"Open Onyx",
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+Space"),
|
||||
None::<&str>,
|
||||
)?;
|
||||
let open_chat = MenuItem::with_id(
|
||||
app,
|
||||
@@ -666,7 +598,6 @@ fn main() {
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
@@ -698,11 +629,6 @@ fn main() {
|
||||
.setup(move |app| {
|
||||
let app_handle = app.handle();
|
||||
|
||||
// Setup global shortcuts
|
||||
if let Err(e) = setup_shortcuts(&app_handle) {
|
||||
eprintln!("Failed to setup shortcuts: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = setup_app_menu(&app_handle) {
|
||||
eprintln!("Failed to setup menu: {}", e);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@
|
||||
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background-900: #1a1a1a;
|
||||
--background-800: #262626;
|
||||
--text-light-05: rgba(255, 255, 255, 0.95);
|
||||
--text-light-03: rgba(255, 255, 255, 0.6);
|
||||
--white-10: rgba(255, 255, 255, 0.08);
|
||||
--white-15: rgba(255, 255, 255, 0.12);
|
||||
--white-20: rgba(255, 255, 255, 0.15);
|
||||
--white-30: rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
@@ -30,7 +41,11 @@
|
||||
|
||||
body {
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--background-900) 0%,
|
||||
var(--background-800) 100%
|
||||
);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
display: flex;
|
||||
@@ -39,6 +54,9 @@
|
||||
padding: 20px;
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
color 0.3s ease;
|
||||
}
|
||||
|
||||
.titlebar {
|
||||
@@ -69,16 +87,19 @@
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
background: var(--background-800);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition:
|
||||
background 0.3s ease,
|
||||
border 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .settings-panel {
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
@@ -93,17 +114,19 @@
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
background: var(--background-900);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
overflow: hidden;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-icon svg {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
color: #000;
|
||||
color: var(--text-light-05);
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
@@ -134,9 +157,10 @@
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
background: var(--background-900);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
@@ -176,7 +200,7 @@
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-800);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
@@ -186,8 +210,8 @@
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
background: var(--background-900);
|
||||
box-shadow: 0 0 0 2px var(--white-10);
|
||||
}
|
||||
|
||||
.input-field::placeholder {
|
||||
@@ -231,7 +255,7 @@
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
background-color: var(--white-15);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
@@ -243,14 +267,18 @@
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
background-color: var(--background-800);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.dark .toggle-slider:before {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
background-color: var(--white-30);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
@@ -288,14 +316,15 @@
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
background: var(--white-10);
|
||||
border: 1px solid var(--white-15);
|
||||
border-radius: 4px;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
@@ -372,10 +401,34 @@
|
||||
const errorMessage = document.getElementById("errorMessage");
|
||||
const saveBtn = document.getElementById("saveBtn");
|
||||
|
||||
// Theme detection based on system preferences
|
||||
function applySystemTheme() {
|
||||
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
|
||||
|
||||
function updateTheme(e) {
|
||||
if (e.matches) {
|
||||
document.documentElement.classList.add("dark");
|
||||
document.body.classList.add("dark");
|
||||
} else {
|
||||
document.documentElement.classList.remove("dark");
|
||||
document.body.classList.remove("dark");
|
||||
}
|
||||
}
|
||||
|
||||
// Apply initial theme
|
||||
updateTheme(darkModeQuery);
|
||||
|
||||
// Listen for changes
|
||||
darkModeQuery.addEventListener("change", updateTheme);
|
||||
}
|
||||
|
||||
function showSettings() {
|
||||
document.body.classList.add("show-settings");
|
||||
}
|
||||
|
||||
// Apply system theme immediately
|
||||
applySystemTheme();
|
||||
|
||||
// Initialize the app
|
||||
async function init() {
|
||||
try {
|
||||
|
||||
@@ -113,6 +113,23 @@
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
function updateTitleBarTheme(isDark) {
|
||||
const titleBar = document.getElementById(TITLEBAR_ID);
|
||||
if (!titleBar) return;
|
||||
|
||||
if (isDark) {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)";
|
||||
} else {
|
||||
titleBar.style.background =
|
||||
"linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)";
|
||||
titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)";
|
||||
titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)";
|
||||
}
|
||||
}
|
||||
|
||||
function buildTitleBar() {
|
||||
const titleBar = document.createElement("div");
|
||||
titleBar.id = TITLEBAR_ID;
|
||||
@@ -134,6 +151,11 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Apply initial styles matching current theme
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
// Apply styles matching Onyx design system with translucent glass effect
|
||||
titleBar.style.cssText = `
|
||||
position: fixed;
|
||||
@@ -156,8 +178,12 @@
|
||||
-webkit-backdrop-filter: blur(18px) saturate(180%);
|
||||
-webkit-app-region: drag;
|
||||
padding: 0 12px;
|
||||
transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease;
|
||||
`;
|
||||
|
||||
// Apply correct theme
|
||||
updateTitleBarTheme(isDark);
|
||||
|
||||
return titleBar;
|
||||
}
|
||||
|
||||
@@ -168,6 +194,11 @@
|
||||
|
||||
const existing = document.getElementById(TITLEBAR_ID);
|
||||
if (existing?.parentElement === document.body) {
|
||||
// Update theme on existing titlebar
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -178,6 +209,14 @@
|
||||
const titleBar = buildTitleBar();
|
||||
document.body.insertBefore(titleBar, document.body.firstChild);
|
||||
injectStyles();
|
||||
|
||||
// Ensure theme is applied immediately after mount
|
||||
setTimeout(() => {
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function syncViewportHeight() {
|
||||
@@ -194,9 +233,66 @@
|
||||
}
|
||||
}
|
||||
|
||||
function observeThemeChanges() {
|
||||
let lastKnownTheme = null;
|
||||
|
||||
function checkAndUpdateTheme() {
|
||||
// Check both html and body for dark class (some apps use body)
|
||||
const htmlHasDark = document.documentElement.classList.contains("dark");
|
||||
const bodyHasDark = document.body?.classList.contains("dark");
|
||||
const isDark = htmlHasDark || bodyHasDark;
|
||||
|
||||
if (lastKnownTheme !== isDark) {
|
||||
lastKnownTheme = isDark;
|
||||
updateTitleBarTheme(isDark);
|
||||
}
|
||||
}
|
||||
|
||||
// Immediate check on setup
|
||||
checkAndUpdateTheme();
|
||||
|
||||
// Watch for theme changes on the HTML element
|
||||
const themeObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
|
||||
themeObserver.observe(document.documentElement, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
|
||||
// Also observe body if it exists
|
||||
if (document.body) {
|
||||
const bodyObserver = new MutationObserver(() => {
|
||||
checkAndUpdateTheme();
|
||||
});
|
||||
bodyObserver.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ["class"],
|
||||
});
|
||||
}
|
||||
|
||||
// Also check periodically in case classList is manipulated directly
|
||||
// or the theme loads asynchronously after page load
|
||||
const intervalId = setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 300);
|
||||
|
||||
// Clean up after 30 seconds once theme should be stable
|
||||
setTimeout(() => {
|
||||
clearInterval(intervalId);
|
||||
// But keep checking every 2 seconds for manual theme changes
|
||||
setInterval(() => {
|
||||
checkAndUpdateTheme();
|
||||
}, 2000);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
function init() {
|
||||
mountTitleBar();
|
||||
syncViewportHeight();
|
||||
observeThemeChanges();
|
||||
|
||||
window.addEventListener("resize", syncViewportHeight, { passive: true });
|
||||
window.visualViewport?.addEventListener("resize", syncViewportHeight, {
|
||||
passive: true,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowDownDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowDownDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowLeftDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowLeftDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowRightDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowRightDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 14 9"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgArrowUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgArrowUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import type { SVGProps } from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBracketCurly = (props: SVGProps<SVGSVGElement>) => (
|
||||
const SvgBracketCurly = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 15 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgClaude = (props: IconProps) => {
|
||||
const SvgClaude = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgClipboard = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgClipboard = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgCornerRightUpDot = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgCornerRightUpDot = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 9 14"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const OnyxLogo = ({
|
||||
width = 24,
|
||||
height = 24,
|
||||
className,
|
||||
...props
|
||||
}: IconProps) => (
|
||||
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={width}
|
||||
height={height}
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 56 56"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={className}
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
@@ -23,4 +17,4 @@ const OnyxLogo = ({
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default OnyxLogo;
|
||||
export default SvgOnyxLogo;
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import React from "react";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgOpenAI = (props: IconProps) => {
|
||||
const SvgOpenAI = ({ size, ...props }: IconProps) => {
|
||||
const clipId = React.useId();
|
||||
return (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgTwoLineSmall = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgTwoLineSmall = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgUserPlus = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { SVGProps } from "react";
|
||||
const SvgWallet = (props: SVGProps<SVGSVGElement>) => (
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgWallet = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -300,13 +300,7 @@ export default function Page() {
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Default Assistant"
|
||||
icon={
|
||||
<SvgOnyxLogo
|
||||
width={32}
|
||||
height={32}
|
||||
className="my-auto stroke-text-04"
|
||||
/>
|
||||
}
|
||||
icon={<SvgOnyxLogo size={32} className="my-auto stroke-text-04" />}
|
||||
/>
|
||||
<DefaultAssistantConfig />
|
||||
</>
|
||||
|
||||
@@ -131,10 +131,15 @@ export function CustomForm({
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(
|
||||
values.custom_config_list
|
||||
),
|
||||
|
||||
@@ -39,6 +39,8 @@ interface OllamaFormValues extends BaseLLMFormValues {
|
||||
interface OllamaFormContentProps {
|
||||
formikProps: FormikProps<OllamaFormValues>;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
testError: string;
|
||||
mutate: () => void;
|
||||
@@ -49,15 +51,14 @@ interface OllamaFormContentProps {
|
||||
function OllamaFormContent({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
testError,
|
||||
mutate,
|
||||
onClose,
|
||||
isFormValid,
|
||||
}: OllamaFormContentProps) {
|
||||
const [availableModels, setAvailableModels] = useState<ModelConfiguration[]>(
|
||||
existingLlmProvider?.model_configurations || []
|
||||
);
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -70,16 +71,25 @@ function OllamaFormContent({
|
||||
.then((data) => {
|
||||
if (data.error) {
|
||||
console.error("Error fetching models:", data.error);
|
||||
setAvailableModels([]);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setAvailableModels(data.models);
|
||||
setFetchedModels(data.models);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingModels(false);
|
||||
});
|
||||
}
|
||||
}, [formikProps.values.api_base]);
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
existingLlmProvider?.name,
|
||||
setFetchedModels,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<Form className={LLM_FORM_CLASS_NAME}>
|
||||
@@ -99,7 +109,7 @@ function OllamaFormContent({
|
||||
/>
|
||||
|
||||
<DisplayModels
|
||||
modelConfigurations={availableModels}
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
|
||||
isLoading={isLoadingModels}
|
||||
@@ -125,6 +135,8 @@ export function OllamaForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Ollama"
|
||||
@@ -189,7 +201,10 @@ export function OllamaForm({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
@@ -205,6 +220,8 @@ export function OllamaForm({
|
||||
<OllamaFormContent
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
testError={testError}
|
||||
mutate={mutate}
|
||||
|
||||
@@ -68,11 +68,7 @@ export const WebProviderSetupModal = memo(
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo
|
||||
width={24}
|
||||
height={24}
|
||||
className="text-text-04 shrink-0"
|
||||
/>
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1168,7 +1168,7 @@ export default function Page() {
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={16} height={16} />
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
@@ -1381,7 +1381,7 @@ export default function Page() {
|
||||
} logo`,
|
||||
fallback:
|
||||
selectedContentProviderType === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo width={24} height={24} className="text-text-05" />
|
||||
<SvgOnyxLogo size={24} className="text-text-05" />
|
||||
) : undefined,
|
||||
size: 24,
|
||||
containerSize: 28,
|
||||
|
||||
@@ -455,9 +455,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
/>
|
||||
)}
|
||||
|
||||
<BackButton
|
||||
behaviorOverride={() => router.push("/admin/indexing/status")}
|
||||
/>
|
||||
<BackButton />
|
||||
<div
|
||||
className="flex
|
||||
items-center
|
||||
|
||||
@@ -25,7 +25,6 @@ import { useDocumentSets } from "@/lib/hooks/useDocumentSets";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { ChatPopup } from "@/app/chat/components/ChatPopup";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_ID } from "@/app/chat/components/tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import NoAssistantModal from "@/components/modals/NoAssistantModal";
|
||||
import TextView from "@/components/chat/TextView";
|
||||
@@ -382,9 +381,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
if (liveAssistant) {
|
||||
return liveAssistant.tools.some(
|
||||
(tool) => tool.in_code_tool_id === SEARCH_TOOL_ID
|
||||
);
|
||||
return personaIncludesRetrieval(liveAssistant);
|
||||
}
|
||||
return false;
|
||||
}, [liveAssistant]);
|
||||
|
||||
@@ -63,7 +63,7 @@ export type ProjectDetails = {
|
||||
};
|
||||
|
||||
export async function fetchProjects(): Promise<Project[]> {
|
||||
const response = await fetch("/api/user/projects/");
|
||||
const response = await fetch("/api/user/projects");
|
||||
if (!response.ok) {
|
||||
handleRequestError("Fetch projects", response);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
.card {
|
||||
@apply rounded-16 w-full;
|
||||
@apply rounded-16 w-full overflow-clip;
|
||||
}
|
||||
|
||||
.card[data-variant="primary"] {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { User } from "@/lib/types";
|
||||
import { FiPlus, FiX } from "react-icons/fi";
|
||||
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
|
||||
import { UsersIcon } from "@/components/icons/icons";
|
||||
import { FiX } from "react-icons/fi";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
|
||||
interface UserEditorProps {
|
||||
@@ -51,35 +50,27 @@ export const UserEditor = ({
|
||||
</div>
|
||||
|
||||
<div className="flex">
|
||||
<SearchMultiSelectDropdown
|
||||
<InputComboBox
|
||||
placeholder="Search..."
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={(selectedValue) => {
|
||||
setSelectedUserIds([
|
||||
...Array.from(new Set([...selectedUserIds, selectedValue])),
|
||||
]);
|
||||
}}
|
||||
options={allUsers
|
||||
.filter(
|
||||
(user) =>
|
||||
!selectedUserIds.includes(user.id) &&
|
||||
!existingUsers.map((user) => user.id).includes(user.id)
|
||||
)
|
||||
.map((user) => {
|
||||
return {
|
||||
name: user.email,
|
||||
value: user.id,
|
||||
};
|
||||
})}
|
||||
onSelect={(option) => {
|
||||
setSelectedUserIds([
|
||||
...Array.from(
|
||||
new Set([...selectedUserIds, option.value as string])
|
||||
),
|
||||
]);
|
||||
}}
|
||||
itemComponent={({ option }) => (
|
||||
<div className="flex px-4 py-2.5 cursor-pointer hover:bg-accent-background-hovered">
|
||||
<UsersIcon className="mr-2 my-auto" />
|
||||
{option.name}
|
||||
<div className="ml-auto my-auto">
|
||||
<FiPlus />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
.map((user) => ({
|
||||
label: user.email,
|
||||
value: user.id,
|
||||
}))}
|
||||
strict
|
||||
leftSearchIcon
|
||||
/>
|
||||
{onSubmit && (
|
||||
<Button
|
||||
|
||||
@@ -1,21 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
ChangeEvent,
|
||||
FC,
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
JSX,
|
||||
} from "react";
|
||||
import { ChevronDownIcon } from "./icons/icons";
|
||||
import { forwardRef, useEffect, useRef, useState, JSX } from "react";
|
||||
import { FiCheck, FiChevronDown, FiInfo } from "react-icons/fi";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgPlus } from "@opal/icons";
|
||||
export interface Option<T> {
|
||||
name: string;
|
||||
value: T;
|
||||
@@ -28,230 +16,6 @@ export interface Option<T> {
|
||||
|
||||
export type StringOrNumberOption = Option<string | number>;
|
||||
|
||||
function StandardDropdownOption<T>({
|
||||
index,
|
||||
option,
|
||||
handleSelect,
|
||||
}: {
|
||||
index: number;
|
||||
option: Option<T>;
|
||||
handleSelect: (option: Option<T>) => void;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={() => handleSelect(option)}
|
||||
className={`w-full text-left block px-4 py-2.5 text-sm bg-white dark:bg-neutral-800 hover:bg-neutral-50 dark:hover:bg-neutral-700 ${
|
||||
index !== 0 ? "border-t border-neutral-200 dark:border-neutral-700" : ""
|
||||
}`}
|
||||
role="menuitem"
|
||||
>
|
||||
<p className="font-medium text-xs text-neutral-900 dark:text-neutral-100">
|
||||
{option.name}
|
||||
</p>
|
||||
{option.description && (
|
||||
<p className="text-xs text-neutral-500 dark:text-neutral-400">
|
||||
{option.description}
|
||||
</p>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
export function SearchMultiSelectDropdown({
|
||||
options,
|
||||
onSelect,
|
||||
itemComponent,
|
||||
onCreate,
|
||||
onDelete,
|
||||
onSearchTermChange,
|
||||
initialSearchTerm = "",
|
||||
allowCustomValues = false,
|
||||
}: {
|
||||
options: StringOrNumberOption[];
|
||||
onSelect: (selected: StringOrNumberOption) => void;
|
||||
itemComponent?: (props: { option: StringOrNumberOption }) => JSX.Element;
|
||||
onCreate?: (name: string) => void;
|
||||
onDelete?: (name: string) => void;
|
||||
onSearchTermChange?: (term: string) => void;
|
||||
initialSearchTerm?: string;
|
||||
allowCustomValues?: boolean;
|
||||
}) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [searchTerm, setSearchTerm] = useState(initialSearchTerm);
|
||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleSelect = (option: StringOrNumberOption) => {
|
||||
onSelect(option);
|
||||
setIsOpen(false);
|
||||
setSearchTerm(""); // Clear search term after selection
|
||||
};
|
||||
|
||||
const filteredOptions = options.filter((option) =>
|
||||
option.name.toLowerCase().includes(searchTerm.toLowerCase())
|
||||
);
|
||||
|
||||
// Handle selecting a custom value not in the options list
|
||||
const handleCustomValueSelect = () => {
|
||||
if (allowCustomValues && searchTerm.trim() !== "") {
|
||||
const customOption: StringOrNumberOption = {
|
||||
name: searchTerm,
|
||||
value: searchTerm,
|
||||
};
|
||||
onSelect(customOption);
|
||||
setIsOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
dropdownRef.current &&
|
||||
!dropdownRef.current.contains(event.target as Node)
|
||||
) {
|
||||
// If allowCustomValues is enabled and there's text in the search field,
|
||||
// treat clicking outside as selecting the custom value
|
||||
if (allowCustomValues && searchTerm.trim() !== "") {
|
||||
handleCustomValueSelect();
|
||||
}
|
||||
setIsOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("mousedown", handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handleClickOutside);
|
||||
};
|
||||
}, [allowCustomValues, searchTerm]);
|
||||
|
||||
useEffect(() => {
|
||||
setSearchTerm(initialSearchTerm);
|
||||
}, [initialSearchTerm]);
|
||||
|
||||
return (
|
||||
<div className="relative text-left w-full" ref={dropdownRef}>
|
||||
<div>
|
||||
<input
|
||||
type="text"
|
||||
placeholder={
|
||||
allowCustomValues ? "Search or enter custom value..." : "Search..."
|
||||
}
|
||||
value={searchTerm}
|
||||
onChange={(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const newValue = e.target.value;
|
||||
setSearchTerm(newValue);
|
||||
if (onSearchTermChange) {
|
||||
onSearchTermChange(newValue);
|
||||
}
|
||||
if (newValue) {
|
||||
setIsOpen(true);
|
||||
} else {
|
||||
setIsOpen(false);
|
||||
}
|
||||
}}
|
||||
onFocus={() => setIsOpen(true)}
|
||||
onKeyDown={(e) => {
|
||||
if (
|
||||
e.key === "Enter" &&
|
||||
allowCustomValues &&
|
||||
searchTerm.trim() !== ""
|
||||
) {
|
||||
e.preventDefault();
|
||||
handleCustomValueSelect();
|
||||
}
|
||||
}}
|
||||
className="inline-flex justify-between w-full px-4 py-2 text-sm bg-white dark:bg-transparent text-neutral-800 dark:text-neutral-200 border border-neutral-200 dark:border-neutral-700 rounded-md shadow-sm"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
className="absolute top-0 right-0 text-sm h-full px-2 border-l border-neutral-200 dark:border-neutral-700"
|
||||
aria-expanded={isOpen}
|
||||
aria-haspopup="true"
|
||||
onClick={() => setIsOpen(!isOpen)}
|
||||
>
|
||||
<ChevronDownIcon className="my-auto w-4 h-4 text-neutral-600 dark:text-neutral-400" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{isOpen && (
|
||||
<div className="absolute z-10 mt-1 w-full rounded-md shadow-lg bg-white dark:bg-neutral-900 border border-neutral-200 dark:border-neutral-700 max-h-60 overflow-y-auto">
|
||||
<div
|
||||
role="menu"
|
||||
aria-orientation="vertical"
|
||||
aria-labelledby="options-menu"
|
||||
>
|
||||
{filteredOptions.map((option, index) =>
|
||||
itemComponent ? (
|
||||
<div
|
||||
key={option.name}
|
||||
onClick={() => {
|
||||
handleSelect(option);
|
||||
}}
|
||||
>
|
||||
{itemComponent({ option })}
|
||||
</div>
|
||||
) : (
|
||||
<StandardDropdownOption
|
||||
key={index}
|
||||
option={option}
|
||||
index={index}
|
||||
handleSelect={handleSelect}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
|
||||
{allowCustomValues &&
|
||||
searchTerm.trim() !== "" &&
|
||||
!filteredOptions.some(
|
||||
(option) =>
|
||||
option.name.toLowerCase() === searchTerm.toLowerCase()
|
||||
) && (
|
||||
<Button
|
||||
className="w-full"
|
||||
role="menuitem"
|
||||
onClick={handleCustomValueSelect}
|
||||
leftIcon={SvgPlus}
|
||||
>
|
||||
Use "{searchTerm}" as custom value
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{onCreate &&
|
||||
searchTerm.trim() !== "" &&
|
||||
!filteredOptions.some(
|
||||
(option) =>
|
||||
option.name.toLowerCase() === searchTerm.toLowerCase()
|
||||
) && (
|
||||
<>
|
||||
<div className="border-t border-background-300"></div>
|
||||
<Button
|
||||
className="w-full"
|
||||
role="menuitem"
|
||||
onClick={() => {
|
||||
onCreate(searchTerm);
|
||||
setIsOpen(false);
|
||||
setSearchTerm("");
|
||||
}}
|
||||
leftIcon={SvgPlus}
|
||||
>
|
||||
Create label "{searchTerm}"
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
|
||||
{filteredOptions.length === 0 &&
|
||||
((!onCreate && !allowCustomValues) ||
|
||||
searchTerm.trim() === "") && (
|
||||
<div className="px-4 py-2.5 text-sm text-text-500">
|
||||
No matches found
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export const CustomDropdown = ({
|
||||
children,
|
||||
dropdown,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { FormikProps, ErrorMessage } from "formik";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgX } from "@opal/icons";
|
||||
export type GenericMultiSelectFormType<T extends string> = {
|
||||
@@ -84,15 +84,11 @@ export function GenericMultiSelect<
|
||||
const selectedIds = (formikProps.values[fieldName] as number[]) || [];
|
||||
const selectedItems = items.filter((item) => selectedIds.includes(item.id));
|
||||
|
||||
const handleSelect = (option: { name: string; value: string | number }) => {
|
||||
const handleSelect = (itemId: number) => {
|
||||
if (disabled) return;
|
||||
const currentIds = (formikProps.values[fieldName] as number[]) || [];
|
||||
const numValue =
|
||||
typeof option.value === "string"
|
||||
? parseInt(option.value, 10)
|
||||
: option.value;
|
||||
if (!currentIds.includes(numValue)) {
|
||||
formikProps.setFieldValue(fieldName, [...currentIds, numValue]);
|
||||
if (!currentIds.includes(itemId)) {
|
||||
formikProps.setFieldValue(fieldName, [...currentIds, itemId]);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -118,14 +114,24 @@ export function GenericMultiSelect<
|
||||
)}
|
||||
|
||||
<div className={cn(disabled && "opacity-50 pointer-events-none")}>
|
||||
<SearchMultiSelectDropdown
|
||||
<InputComboBox
|
||||
placeholder="Search..."
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={(selectedValue) => {
|
||||
const numValue = parseInt(selectedValue, 10);
|
||||
if (!isNaN(numValue)) {
|
||||
handleSelect(numValue);
|
||||
}
|
||||
}}
|
||||
options={items
|
||||
.filter((item) => !selectedIds.includes(item.id))
|
||||
.map((item) => ({
|
||||
name: item.name,
|
||||
value: item.id,
|
||||
label: item.name,
|
||||
value: String(item.id),
|
||||
}))}
|
||||
onSelect={handleSelect}
|
||||
strict
|
||||
leftSearchIcon
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -619,6 +619,9 @@ export function FederatedConnectorForm({
|
||||
);
|
||||
}
|
||||
|
||||
const channelInputPlaceholder =
|
||||
"Type channel name or regex pattern and press Enter";
|
||||
|
||||
return (
|
||||
<>
|
||||
{Object.entries(formState.configurationSchema).map(
|
||||
@@ -688,9 +691,9 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
}}
|
||||
placeholder={
|
||||
fieldSpec.example &&
|
||||
Array.isArray(fieldSpec.example)
|
||||
? `e.g., ${fieldSpec.example.join(", ")}`
|
||||
fieldKey === "channels" ||
|
||||
fieldKey === "exclude_channels"
|
||||
? channelInputPlaceholder
|
||||
: "Type and press Enter to add an item"
|
||||
}
|
||||
disabled={disableSlackChannelInput(fieldKey)}
|
||||
|
||||
@@ -651,6 +651,7 @@ export function useLlmManager(
|
||||
|
||||
if (modelName) {
|
||||
const model = parseLlmDescriptor(modelName);
|
||||
// If we have no parsed modelName, try to find the provider by the raw modelName string
|
||||
if (!(model.modelName && model.modelName.length > 0)) {
|
||||
const provider = llmProviders.find((p) =>
|
||||
p.model_configurations
|
||||
@@ -666,14 +667,36 @@ export function useLlmManager(
|
||||
}
|
||||
}
|
||||
|
||||
const provider = llmProviders.find((p) =>
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
// If we have parsed provider info, try to find that specific provider.
|
||||
// This ensures we don't incorrectly match a model to the wrong provider
|
||||
// when the same model name exists across multiple providers (e.g., gpt-5 in Azure and OpenAI)
|
||||
if (model.provider && model.provider.length > 0) {
|
||||
const matchingProvider = llmProviders.find(
|
||||
(p) =>
|
||||
p.provider === model.provider &&
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
if (matchingProvider) {
|
||||
return {
|
||||
...model,
|
||||
name: matchingProvider.name,
|
||||
provider: matchingProvider.provider,
|
||||
};
|
||||
}
|
||||
// Provider info was present but not found - fall through to default
|
||||
} else {
|
||||
// Only search by model name when no provider info was parsed
|
||||
const provider = llmProviders.find((p) =>
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
|
||||
if (provider) {
|
||||
return { ...model, provider: provider.provider, name: provider.name };
|
||||
if (provider) {
|
||||
return { ...model, provider: provider.provider, name: provider.name };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -711,12 +734,17 @@ export function useLlmManager(
|
||||
};
|
||||
|
||||
const [temperature, setTemperature] = useState<number>(() => {
|
||||
llmUpdate();
|
||||
|
||||
if (currentChatSession?.current_temperature_override != null) {
|
||||
// Derive Anthropic check from chat session since currentLlm isn't populated yet
|
||||
const sessionModel = currentChatSession.current_alternate_model
|
||||
? parseLlmDescriptor(currentChatSession.current_alternate_model)
|
||||
: null;
|
||||
const isAnthropicModel = sessionModel
|
||||
? isAnthropic(sessionModel.provider, sessionModel.modelName)
|
||||
: false;
|
||||
return Math.min(
|
||||
currentChatSession.current_temperature_override,
|
||||
isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0
|
||||
isAnthropicModel ? 1.0 : 2.0
|
||||
);
|
||||
} else if (
|
||||
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
|
||||
@@ -727,8 +755,20 @@ export function useLlmManager(
|
||||
});
|
||||
|
||||
const maxTemperature = useMemo(() => {
|
||||
return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
|
||||
}, [currentLlm]);
|
||||
// Check currentLlm first, fall back to chat session model if currentLlm isn't populated
|
||||
if (currentLlm.provider) {
|
||||
return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
|
||||
}
|
||||
const sessionModel = currentChatSession?.current_alternate_model
|
||||
? parseLlmDescriptor(currentChatSession.current_alternate_model)
|
||||
: null;
|
||||
if (sessionModel?.provider) {
|
||||
return isAnthropic(sessionModel.provider, sessionModel.modelName)
|
||||
? 1.0
|
||||
: 2.0;
|
||||
}
|
||||
return 2.0; // Default max when no model info available
|
||||
}, [currentLlm, currentChatSession]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
|
||||
@@ -770,13 +810,12 @@ export function useLlmManager(
|
||||
]);
|
||||
|
||||
const updateTemperature = (temperature: number) => {
|
||||
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
|
||||
setTemperature(Math.min(temperature, 1.0));
|
||||
} else {
|
||||
setTemperature(temperature);
|
||||
}
|
||||
const clampedTemp = isAnthropic(currentLlm.provider, currentLlm.modelName)
|
||||
? Math.min(temperature, 1.0)
|
||||
: temperature;
|
||||
setTemperature(clampedTemp);
|
||||
if (chatSession) {
|
||||
updateTemperatureOverrideForChatSession(chatSession.id, temperature);
|
||||
updateTemperatureOverrideForChatSession(chatSession.id, clampedTemp);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1029,12 +1068,16 @@ export function useSourcePreferences({
|
||||
}
|
||||
setSourcesInitialized(true);
|
||||
}
|
||||
}, [availableSources, sourcesInitialized, setSelectedSources]);
|
||||
}, [
|
||||
availableSources,
|
||||
configuredSources,
|
||||
sourcesInitialized,
|
||||
setSelectedSources,
|
||||
]);
|
||||
|
||||
const enableSources = (sources: SourceMetadata[]) => {
|
||||
const allSourceMetadata = getConfiguredSources(availableSources);
|
||||
setSelectedSources([...sources]);
|
||||
persistSourcePreferencesState(sources, allSourceMetadata);
|
||||
persistSourcePreferencesState(sources, configuredSources);
|
||||
};
|
||||
|
||||
const enableAllSources = () => {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
export function useProjects() {
|
||||
const { data, error, mutate } = useSWR<Project[]>(
|
||||
"/api/user/projects/",
|
||||
"/api/user/projects",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
|
||||
@@ -13,7 +13,7 @@ const ConnectionProviderIcon = memo(({ icon }: ConnectionProviderIconProps) => {
|
||||
<SvgArrowExchange className="w-3 h-3 stroke-text-04" />
|
||||
</div>
|
||||
<div className="w-7 h-7 flex items-center justify-center">
|
||||
<SvgOnyxLogo width={24} height={24} className="fill-text-04" />
|
||||
<SvgOnyxLogo size={24} className="fill-text-04" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -353,7 +353,15 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
>
|
||||
<Icon className="w-[1.5rem] h-[1.5rem] stroke-text-04" />
|
||||
{/*
|
||||
The `h-[1.5rem]` and `w-[1.5rem]` were added as backups here.
|
||||
However, prop-resolution technically resolves to choosing classNames over size props, so technically the `size={24}` is the backup.
|
||||
We specify both to be safe.
|
||||
|
||||
# Note
|
||||
1.5rem === 24px
|
||||
*/}
|
||||
<Icon className="stroke-text-04 h-[1.5rem] w-[1.5rem]" size={24} />
|
||||
{onClose && (
|
||||
<div
|
||||
tabIndex={-1}
|
||||
|
||||
@@ -982,9 +982,10 @@ export default function AgentEditorPage({
|
||||
const hasUploadingFiles = values.user_file_ids.some(
|
||||
(fileId: string) => {
|
||||
const status = fileStatusMap.get(fileId);
|
||||
return (
|
||||
status === undefined || status === UserFileStatus.UPLOADING
|
||||
);
|
||||
if (status === undefined) {
|
||||
return fileId.startsWith("temp_");
|
||||
}
|
||||
return status === UserFileStatus.UPLOADING;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user