mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-12 10:22:42 +00:00
Compare commits
12 Commits
codex/agen
...
v3.2.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7332adb1e6 | ||
|
|
0ab1b76765 | ||
|
|
40cd0a78a3 | ||
|
|
28d8c5de46 | ||
|
|
004092767f | ||
|
|
eb4689a669 | ||
|
|
47dd8973c1 | ||
|
|
a1403ef78c | ||
|
|
f96b9d6804 | ||
|
|
711651276c | ||
|
|
3731110cf9 | ||
|
|
8fb7a8718e |
@@ -13,6 +13,7 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -107,12 +108,13 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
For self-hosted: counts all active users.
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -129,6 +131,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,6 +24,7 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -65,12 +68,25 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -209,12 +225,37 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -60,8 +60,10 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -255,15 +257,13 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
try:
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,12 +277,25 @@ def bulk_fetch_issues(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -6,6 +7,14 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -432,7 +519,6 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -662,7 +748,6 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -695,62 +780,37 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
for msg in replies:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -976,7 +1036,16 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -993,8 +1062,16 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1010,7 +1087,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
]
|
||||
)
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -668,6 +695,15 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -684,7 +720,9 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -702,7 +740,8 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
return [
|
||||
search_queries = [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -96,6 +96,32 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
|
||||
return description[: max_length - 3] + "..."
|
||||
|
||||
|
||||
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
|
||||
# frontend indicating whether each credential field was actually modified. The current
|
||||
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
|
||||
# collide) and mutates request values, which is surprising. The frontend should signal
|
||||
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
|
||||
def _restore_masked_oauth_credentials(
|
||||
request_client_id: str | None,
|
||||
request_client_secret: str | None,
|
||||
existing_client: OAuthClientInformationFull,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""If the frontend sent back masked credentials, restore the real stored values."""
|
||||
if (
|
||||
request_client_id
|
||||
and existing_client.client_id
|
||||
and request_client_id == mask_string(existing_client.client_id)
|
||||
):
|
||||
request_client_id = existing_client.client_id
|
||||
if (
|
||||
request_client_secret
|
||||
and existing_client.client_secret
|
||||
and request_client_secret == mask_string(existing_client.client_secret)
|
||||
):
|
||||
request_client_secret = existing_client.client_secret
|
||||
return request_client_id, request_client_secret
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mcp")
|
||||
admin_router = APIRouter(prefix="/admin/mcp")
|
||||
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
|
||||
@@ -392,6 +418,26 @@ async def _connect_oauth(
|
||||
detail=f"Server was configured with authentication type {auth_type_str}",
|
||||
)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so we don't overwrite them with masks.
|
||||
if mcp_server.admin_connection_config:
|
||||
existing_data = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if existing_client_raw:
|
||||
existing_client = OAuthClientInformationFull.model_validate(
|
||||
existing_client_raw
|
||||
)
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
existing_client,
|
||||
)
|
||||
|
||||
# Create admin config with client info if provided
|
||||
config_data = MCPConnectionData(headers={})
|
||||
if request.oauth_client_id and request.oauth_client_secret:
|
||||
@@ -1356,6 +1402,19 @@ def _upsert_mcp_server(
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so the comparison below sees no change
|
||||
# and the credentials aren't overwritten with masked strings.
|
||||
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
client_info,
|
||||
)
|
||||
|
||||
changing_connection_config = (
|
||||
not mcp_server.admin_connection_config
|
||||
or (
|
||||
|
||||
@@ -111,6 +111,43 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
|
||||
api_key: str | None,
|
||||
provider_name: str | None,
|
||||
api_base: str | None,
|
||||
db_session: Session,
|
||||
) -> str | None:
|
||||
"""Return the real API key for model-fetch endpoints.
|
||||
|
||||
When editing an existing provider the form value is masked (e.g.
|
||||
``sk-a****b1c2``). If *provider_name* is supplied we can look up
|
||||
the unmasked key from the database so the external request succeeds.
|
||||
|
||||
The stored key is only returned when the request's *api_base*
|
||||
matches the value stored in the database.
|
||||
"""
|
||||
if not provider_name:
|
||||
return api_key
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.api_key:
|
||||
# Normalise both URLs before comparing so trailing-slash
|
||||
# differences don't cause a false mismatch.
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
request_base = (api_base or "").strip().rstrip("/")
|
||||
if stored_base != request_base:
|
||||
return api_key
|
||||
|
||||
stored_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
# Only resolve when the incoming value is the masked form of the
|
||||
# stored key — i.e. the user hasn't typed a new key.
|
||||
if api_key and api_key == _mask_string(stored_key):
|
||||
return stored_key
|
||||
return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
@@ -1174,16 +1211,17 @@ def get_ollama_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
headers: dict[str, str] = {
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1206,8 +1244,12 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1300,13 +1342,18 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1390,8 +1437,12 @@ def get_litellm_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
api_key=api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1448,7 +1499,7 @@ def get_litellm_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
@@ -1523,8 +1574,12 @@ def get_bifrost_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1613,8 +1668,12 @@ def get_openai_compatible_server_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
|
||||
@@ -9,6 +9,7 @@ from unittest.mock import patch
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
@@ -214,3 +215,43 @@ class TestCheckSeatAvailabilityMultiTenant:
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
|
||||
class TestGetUsedSeatsAccountTypeFiltering:
|
||||
"""Verify get_used_seats query excludes SERVICE_ACCOUNT but includes BOT."""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_excludes_service_accounts(self, mock_get_session: MagicMock) -> None:
|
||||
"""SERVICE_ACCOUNT users should not count toward seats."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 5
|
||||
|
||||
result = get_used_seats()
|
||||
|
||||
assert result == 5
|
||||
# Inspect the compiled query to verify account_type filter
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "SERVICE_ACCOUNT" in compiled
|
||||
# BOT should NOT be excluded
|
||||
assert "BOT" not in compiled
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_still_excludes_ext_perm_user(self, mock_get_session: MagicMock) -> None:
|
||||
"""EXT_PERM_USER exclusion should still be present."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 3
|
||||
|
||||
get_used_seats()
|
||||
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "EXT_PERM_USER" in compiled
|
||||
|
||||
@@ -6,6 +6,7 @@ import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import _JIRA_BULK_FETCH_LIMIT
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
@@ -145,3 +146,29 @@ def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
|
||||
|
||||
def test_bulk_fetch_respects_api_batch_limit() -> None:
|
||||
"""Requests to the bulkfetch endpoint never exceed _JIRA_BULK_FETCH_LIMIT IDs."""
|
||||
client = _mock_jira_client()
|
||||
total_issues = _JIRA_BULK_FETCH_LIMIT * 3 + 7
|
||||
all_ids = [str(i) for i in range(total_issues)]
|
||||
|
||||
batch_sizes: list[int] = []
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
batch_sizes.append(len(ids))
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, all_ids)
|
||||
|
||||
assert len(result) == total_issues
|
||||
# keeping this hardcoded because it's the documented limit
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
assert all(size <= 100 for size in batch_sizes)
|
||||
assert len(batch_sizes) == 4
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -505,6 +505,7 @@ class TestGetLMStudioAvailableModels:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.api_base = "http://localhost:1234"
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -9,7 +10,9 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -741,3 +744,80 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
16
web/package-lock.json
generated
16
web/package-lock.json
generated
@@ -47,6 +47,7 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
@@ -8843,6 +8844,15 @@
|
||||
"react": ">= 16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/copy-to-clipboard": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
|
||||
"integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"toggle-selection": "^1.0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/core-js": {
|
||||
"version": "3.46.0",
|
||||
"hasInstallScript": true,
|
||||
@@ -17426,6 +17436,12 @@
|
||||
"node": ">=8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/toggle-selection": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz",
|
||||
"integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/toposort": {
|
||||
"version": "2.0.2",
|
||||
"license": "MIT"
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
|
||||
@@ -17,6 +17,7 @@ import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import CollapsibleSection from "@/app/admin/agents/CollapsibleSection";
|
||||
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { RadioGroup } from "@/components/ui/radio-group";
|
||||
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
@@ -126,6 +127,24 @@ export function SlackChannelConfigFormFields({
|
||||
return documentSets.filter((ds) => !documentSetContainsSync(ds));
|
||||
}, [documentSets]);
|
||||
|
||||
const searchAgentOptions = useMemo(
|
||||
() =>
|
||||
availableAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[availableAgents]
|
||||
);
|
||||
|
||||
const nonSearchAgentOptions = useMemo(
|
||||
() =>
|
||||
nonSearchAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[nonSearchAgents]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const invalidSelected = values.document_sets.filter((dsId: number) =>
|
||||
unselectableSets.some((us) => us.id === dsId)
|
||||
@@ -355,12 +374,14 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={availableAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={searchAgentOptions}
|
||||
strict
|
||||
/>
|
||||
{viewSyncEnabledAgents && syncEnabledAgents.length > 0 && (
|
||||
<div className="mt-4">
|
||||
@@ -419,12 +440,14 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={nonSearchAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={nonSearchAgentOptions}
|
||||
strict
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -210,8 +210,10 @@ export default function MultiModelResponseView({
|
||||
const response = responses.find((r) => r.modelIndex === modelIndex);
|
||||
if (!response) return;
|
||||
|
||||
// Persist preferred response to backend + update local tree so the
|
||||
// input bar unblocks (awaitingPreferredSelection clears).
|
||||
// Persist preferred response + sync `latestChildNodeId`. Backend's
|
||||
// `set_preferred_response` updates `latest_child_message_id`; if the
|
||||
// frontend chain walk disagrees, the next follow-up fails with
|
||||
// "not on the latest mainline".
|
||||
if (parentMessage?.messageId && response.messageId && currentSessionId) {
|
||||
setPreferredResponse(parentMessage.messageId, response.messageId).catch(
|
||||
(err) => console.error("Failed to persist preferred response:", err)
|
||||
@@ -227,6 +229,7 @@ export default function MultiModelResponseView({
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: response.messageId,
|
||||
latestChildNodeId: response.nodeId,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
|
||||
@@ -694,6 +694,25 @@ export function useLlmManager(
|
||||
prevAgentIdRef.current = liveAgent?.id;
|
||||
}, [liveAgent?.id]);
|
||||
|
||||
// Clear manual override when arriving at a *different* existing session
|
||||
// from any previously-seen defined session. Tracks only the last
|
||||
// *defined* session id so a round-trip through new-chat (A → undefined
|
||||
// → B) still resets, while A → undefined (new-chat) preserves it.
|
||||
const prevDefinedSessionIdRef = useRef<string | undefined>(undefined);
|
||||
useEffect(() => {
|
||||
const nextId = currentChatSession?.id;
|
||||
if (
|
||||
nextId !== undefined &&
|
||||
prevDefinedSessionIdRef.current !== undefined &&
|
||||
nextId !== prevDefinedSessionIdRef.current
|
||||
) {
|
||||
setUserHasManuallyOverriddenLLM(false);
|
||||
}
|
||||
if (nextId !== undefined) {
|
||||
prevDefinedSessionIdRef.current = nextId;
|
||||
}
|
||||
}, [currentChatSession?.id]);
|
||||
|
||||
function getValidLlmDescriptor(
|
||||
modelName: string | null | undefined
|
||||
): LlmDescriptor {
|
||||
@@ -715,8 +734,9 @@ export function useLlmManager(
|
||||
|
||||
if (llmProviders === undefined || llmProviders === null) {
|
||||
resolved = manualLlm;
|
||||
} else if (userHasManuallyOverriddenLLM && !currentChatSession) {
|
||||
// User has overridden in this session and switched to a new session
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
// Manual override wins over session's `current_alternate_model`.
|
||||
// Cleared on cross-session navigation by the effect above.
|
||||
resolved = manualLlm;
|
||||
} else if (currentChatSession?.current_alternate_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
@@ -728,8 +748,6 @@ export function useLlmManager(
|
||||
liveAgent.llm_model_version_override,
|
||||
llmProviders
|
||||
);
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
resolved = manualLlm;
|
||||
} else if (user?.preferences?.default_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
user.preferences.default_model,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import copy from "copy-to-clipboard";
|
||||
import { Button, ButtonProps } from "@opal/components";
|
||||
import { SvgAlertTriangle, SvgCheck, SvgCopy } from "@opal/icons";
|
||||
|
||||
@@ -40,26 +41,19 @@ export default function CopyIconButton({
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if Clipboard API is available
|
||||
if (!navigator.clipboard) {
|
||||
throw new Error("Clipboard API not available");
|
||||
}
|
||||
|
||||
// If HTML content getter is provided, copy both HTML and plain text
|
||||
if (getHtmlContent) {
|
||||
if (navigator.clipboard && getHtmlContent) {
|
||||
const htmlContent = getHtmlContent();
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob([htmlContent], { type: "text/html" }),
|
||||
"text/plain": new Blob([text], { type: "text/plain" }),
|
||||
});
|
||||
await navigator.clipboard.write([clipboardItem]);
|
||||
}
|
||||
// Default: plain text only
|
||||
else {
|
||||
} else if (navigator.clipboard) {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} else if (!copy(text)) {
|
||||
throw new Error("copy-to-clipboard returned false");
|
||||
}
|
||||
|
||||
// Show "copied" state
|
||||
setCopyState("copied");
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
|
||||
@@ -159,9 +159,12 @@ export default function ModelSelector({
|
||||
);
|
||||
|
||||
if (!isMultiModel) {
|
||||
// Stable key — keying on model would unmount the pill
|
||||
// on change and leave Radix's anchorRef detached,
|
||||
// flashing the closing popover at (0,0).
|
||||
return (
|
||||
<OpenButton
|
||||
key={modelKey(model.provider, model.modelName)}
|
||||
key="single-model-pill"
|
||||
icon={ProviderIcon}
|
||||
onClick={(e: React.MouseEvent) =>
|
||||
handlePillClick(index, e.currentTarget as HTMLElement)
|
||||
|
||||
@@ -425,16 +425,27 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [multiModel.isMultiModelActive]);
|
||||
|
||||
// Sync single-model selection to llmManager so the submission path
|
||||
// uses the correct provider/version (replaces the old LLMPopover sync).
|
||||
// Sync single-model selection to llmManager so the submission path uses
|
||||
// the correct provider/version. Guard against echoing derived state back
|
||||
// — only call updateCurrentLlm when the selection actually differs from
|
||||
// currentLlm, otherwise the initial [] → [currentLlmModel] sync would
|
||||
// pin `userHasManuallyOverriddenLLM=true` with whatever was resolved
|
||||
// first (often the default model before the session's alt_model loads).
|
||||
useEffect(() => {
|
||||
if (multiModel.selectedModels.length === 1) {
|
||||
const model = multiModel.selectedModels[0]!;
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
const current = llmManager.currentLlm;
|
||||
if (
|
||||
model.provider !== current.provider ||
|
||||
model.modelName !== current.modelName ||
|
||||
model.name !== current.name
|
||||
) {
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
}
|
||||
}, [multiModel.selectedModels]);
|
||||
|
||||
@@ -871,15 +882,20 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
agent={liveAgent}
|
||||
isDefaultAgent={isDefaultAgent}
|
||||
/>
|
||||
{liveAgent && !llmManager.isLoadingProviders && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
{!isSearch &&
|
||||
!(
|
||||
state.phase === "idle" && state.appMode === "search"
|
||||
) &&
|
||||
liveAgent &&
|
||||
!llmManager.isLoadingProviders && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
<Spacer rem={1.5} />
|
||||
</Fade>
|
||||
|
||||
@@ -50,7 +50,7 @@ function BifrostModalInternals({
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: LLMProviderName.BIFROST,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
@@ -52,7 +52,7 @@ function LiteLLMProxyModalInternals({
|
||||
const { models, error } = await fetchLiteLLMProxyModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: LLMProviderName.LITELLM_PROXY,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
@@ -52,7 +52,7 @@ function OpenRouterModalInternals({
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: LLMProviderName.OPENROUTER,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
Reference in New Issue
Block a user