Compare commits

..

47 Commits

Author SHA1 Message Date
Raunak Bhagat
c78fe275af refactor: Popover cleanup (#7356) 2026-01-12 12:08:30 +00:00
Raunak Bhagat
c935c4808f fix: More actions cards fixes (#7358) 2026-01-12 03:27:42 -08:00
Raunak Bhagat
4ebcfef541 fix: Fix actions cards (#7357) 2026-01-12 10:57:22 +00:00
SubashMohan
e320ef9d9c Fix/agent creation files (#7346) 2026-01-12 07:00:47 +00:00
Nikolas Garza
9e02438af5 chore: standardize password/secret inputs and update per design docs (#7316) 2026-01-12 06:26:09 +00:00
Danelegend
177e097ddb fix(chat): newly created chats being marked as failed (#7310)
Co-authored-by: Dane Urban <durban@Danes-MacBook-Pro.local>
2026-01-12 02:02:49 +00:00
Wenxi
9ecd47ec31 feat: in app notifications for changelog (#7253) 2026-01-12 01:09:04 +00:00
Nikolas Garza
83f3d29b10 fix: stop federated OAuth modal from appearing permanently after skips (#7351) 2026-01-11 22:20:13 +00:00
Yuhong Sun
12e668cc0f feat: Deep Research Replay (#7340) 2026-01-11 22:17:09 +00:00
SubashMohan
afe8376d5e feat: Exclude image generation providers from LLM fetch in API calls (#7348) 2026-01-11 21:13:25 +00:00
Wenxi
082ef3e096 fix: always start onboarding at first step and track by user (#7315) 2026-01-11 21:03:17 +00:00
Nikolas Garza
cb2951a1c0 perf: switch BeautifulSoup parser from html.parser to lxml for web crawler (#7350) 2026-01-11 20:46:35 +00:00
Corey Auger
eda5598af5 fix: update docs link (#7349)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-11 12:44:48 -08:00
Justin Tahara
0bbb4b6988 fix(ui): Action Strikethrough when not configured (#7273) 2026-01-11 11:21:17 +00:00
Jamison Lahman
4768aadb20 refactor(fe): WelcomeMessage nits (#7344) 2026-01-10 22:01:48 -08:00
Jamison Lahman
e05e85e782 fix(fe): "Pick a date range" button wrapping (#7343) 2026-01-10 21:22:20 -08:00
Jamison Lahman
6408f61307 fix(fe): avoid internal table scroll on query history page (#7342) 2026-01-10 20:39:17 -08:00
Jamison Lahman
5a5cd51e4f fix(fe): SidebarTabs are Links (#7341) 2026-01-10 20:01:31 -08:00
Danelegend
7c047c47a0 fix(chat): Chat in-progress messages (#7318)
Co-authored-by: Dane Urban <durban@Danes-MacBook-Pro.local>
2026-01-11 00:29:39 +00:00
Evan Lohn
22138bbb33 fix: vertex prompt caching (#7339)
Co-authored-by: Weves <chrisweaver101@gmail.com>
2026-01-11 00:23:39 +00:00
Chris Weaver
7cff1064a8 chore: reenable auto update test (#7146) 2026-01-10 16:00:48 -08:00
Wenxi
deeb6fdcd2 fix: anonymous users cookie and admin panel config (#7321) 2026-01-10 15:12:27 -08:00
Chris Weaver
3e7f4e0aa5 fix: auto-sync (#7337) 2026-01-10 13:43:40 -08:00
Raunak Bhagat
ac73671e35 refactor: Components updates (#7308) 2026-01-10 06:30:39 +00:00
Raunak Bhagat
3c20d132e0 feat: Modal updates (#7306) 2026-01-10 05:13:09 +00:00
Yuhong Sun
0e3e7eb4a2 feat: Create new chat session button after msg send (#7332)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-01-10 04:56:54 +00:00
Yuhong Sun
c85aebe8ab Tables (#7333) 2026-01-09 20:40:15 -08:00
Yuhong Sun
a47e6a3146 feat: Enable triple click on content in the chat (#7331)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-01-09 20:37:36 -08:00
Jamison Lahman
1e61737e03 fix(fe): Tags have consistent height on hover (#7328) 2026-01-09 20:20:36 -08:00
Wenxi
c7fc1cd5ae chore: allow tenant cleanup script to skip control plane if tenant not found (#7290) 2026-01-10 00:17:26 +00:00
roshan
e2b60bf67c feat(posthog): track message origin analytics in posthog (#7313) 2026-01-10 00:11:17 +00:00
Danelegend
f4d4d14286 fix(chat): post llm loop callback (#7309)
Co-authored-by: Dane Urban <durban@Danes-MacBook-Pro.local>
2026-01-09 23:53:22 +00:00
Yuhong Sun
1c24bc6ea2 Opensearch README (#7327) 2026-01-09 15:53:22 -08:00
Yuhong Sun
cacbd18dcd feat: Opensearch README (#7325) 2026-01-09 15:28:08 -08:00
Nikolas Garza
8527b83b15 fix(sidebar): Allow unpinning all agents and fix icon flicker (#7241) 2026-01-09 14:20:46 -08:00
Nikolas Garza
33e37a1846 fix: make autocomplete opt in (#7317) 2026-01-09 20:04:22 +00:00
Jamison Lahman
d454d8a878 fix(chat): wide tables can be scrolled (#7311) 2026-01-09 19:07:40 +00:00
roshan
00ad65a6a8 feat: chrome extension (#6704) 2026-01-09 18:45:23 +00:00
Nikolas Garza
dac60d403c fix(chat): show "User has stopped generation" indicator when user cancels (#7312) 2026-01-09 18:14:35 +00:00
Evan Lohn
6256b2854d chore: bump indexing usage (#7307) 2026-01-09 17:46:27 +00:00
Danelegend
8acb8e191d fix(chat): use url when name unknown (#7278)
Co-authored-by: Dane Urban <durban@Danes-MacBook-Pro.local>
2026-01-09 17:16:20 +00:00
Evan Lohn
8c4cbddc43 fix: minor perm sync improvements (#7296) 2026-01-09 05:46:23 +00:00
Yuhong Sun
f6cd006bd6 chore: Refactor tool exceptions (#7280) 2026-01-09 04:01:12 +00:00
Jamison Lahman
0033934319 chore(perf): remove isEqual memoization check (#7304)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-01-09 03:20:37 +00:00
Raunak Bhagat
ff87b79d14 fix: Section layout component fix (#7305) 2026-01-08 19:25:33 -08:00
Raunak Bhagat
ebf18af7c9 refactor: UI components cleanup (#7301)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-01-09 03:09:20 +00:00
Raunak Bhagat
cf67ae962c feat: Add a new GeneralLayouts file and update layout components (#7297)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-01-09 02:50:21 +00:00
220 changed files with 6760 additions and 4655 deletions

View File

@@ -310,8 +310,9 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
- name: Start Docker containers

View File

@@ -301,7 +301,7 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
- name: Start Docker containers

View File

@@ -0,0 +1,49 @@
"""notifications constraint, sort index, and cleanup old notifications
Revision ID: 8405ca81cc83
Revises: a3c1a7904cd0
Create Date: 2026-01-07 16:43:44.855156
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8405ca81cc83"
down_revision = "a3c1a7904cd0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create unique index for notification deduplication.
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
#
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
# in unique constraints, but we want NULL == NULL for deduplication).
# The '{}' represents an empty JSONB object as the NULL replacement.
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
"""
)
# Create index for efficient notification sorting by user
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
ON notification (user_id, dismissed, first_shown DESC)
"""
)
# Clean up legacy 'reindex' notifications that are no longer needed
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")

View File

@@ -23,6 +23,7 @@ from onyx.db.models import User
from onyx.llm.factory import get_llm_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -100,6 +101,7 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
origin=MessageOrigin.API,
)
packets = stream_chat_message_objects(
@@ -203,6 +205,7 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
origin=MessageOrigin.API,
)
packets = stream_chat_message_objects(

View File

@@ -1,8 +1,5 @@
"""EE Usage limits - trial detection via billing information."""
from datetime import datetime
from datetime import timezone
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
@@ -31,13 +28,7 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
return True
if isinstance(billing_info, BillingInformation):
# Check if trial is active
if billing_info.trial_end is not None:
now = datetime.now(timezone.utc)
# Trial active if trial_end is in the future
# and subscription status indicates trialing
if billing_info.trial_end > now and billing_info.status == "trialing":
return True
return billing_info.status == "trialing"
return False

View File

@@ -124,6 +124,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",

View File

@@ -174,7 +174,7 @@ if AUTO_LLM_CONFIG_URL:
"schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)

View File

@@ -5,6 +5,9 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.well_known_providers.auto_update_service import (
sync_llm_models_from_github,
)
@shared_task(
@@ -26,24 +29,9 @@ def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
return None
try:
# Import here to avoid circular imports
from onyx.llm.well_known_providers.auto_update_service import (
fetch_llm_recommendations_from_github,
)
from onyx.llm.well_known_providers.auto_update_service import (
sync_llm_models_from_github,
)
# Fetch config from GitHub
config = fetch_llm_recommendations_from_github()
if not config:
task_logger.warning("Failed to fetch GitHub config")
return None
# Sync to database
with get_session_with_current_tenant() as db_session:
results = sync_llm_models_from_github(db_session, config)
results = sync_llm_models_from_github(db_session)
if results:
task_logger.info(f"Auto mode sync results: {results}")

View File

@@ -0,0 +1,57 @@
from uuid import UUID
from redis.client import Redis
# Redis key prefixes for chat message processing
PREFIX = "chatprocessing"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 30 * 60 # 30 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""
Generate the Redis key for a chat session processing a message.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string (tenant_id is automatically added by the Redis client)
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_processing_status(
chat_session_id: UUID, redis_client: Redis, value: bool
) -> None:
"""
Set or clear the fence for a chat session processing a message.
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
Args:
chat_session_id: The UUID of the chat session
redis_client: The Redis client to use
value: True to set the fence, False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if value:
redis_client.set(fence_key, 0, ex=FENCE_TTL)
else:
redis_client.delete(fence_key)
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session is processing a message.
Args:
chat_session_id: The UUID of the chat session
redis_client: The Redis client to use
Returns:
True if the chat session is processing a message, False otherwise
"""
fence_key = _get_fence_key(chat_session_id)
return bool(redis_client.exists(fence_key))

View File

@@ -94,6 +94,7 @@ class ChatStateContainer:
def run_chat_loop_with_state_containers(
func: Callable[..., None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
@@ -196,3 +197,12 @@ def run_chat_loop_with_state_containers(
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -55,6 +55,7 @@ from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
@@ -117,6 +118,7 @@ def prepare_chat_message_request(
llm_override: LLMOverride | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
origin: MessageOrigin | None = None,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -144,6 +146,7 @@ def prepare_chat_message_request(
llm_override=llm_override,
allowed_tool_ids=allowed_tool_ids,
forced_tool_ids=forced_tool_ids,
origin=origin or MessageOrigin.UNKNOWN,
)

View File

@@ -505,7 +505,7 @@ def run_llm_loop(
# in-flight citations
# It can be cleaned up but not super trivial or worthwhile right now
just_ran_web_search = False
tool_responses, citation_mapping = run_tool_calls(
parallel_tool_call_results = run_tool_calls(
tool_calls=tool_calls,
tools=final_tools,
message_history=truncated_message_history,
@@ -516,6 +516,8 @@ def run_llm_loop(
max_concurrent_tools=None,
skip_search_query_expansion=has_called_search_tool,
)
tool_responses = parallel_tool_call_results.tool_responses
citation_mapping = parallel_tool_call_results.updated_citation_mapping
# Failure case, give something reasonable to the LLM to try again
if tool_calls and not tool_responses:

View File

@@ -5,10 +5,13 @@ An overview can be found in the README.md file in this directory.
import re
import traceback
from collections.abc import Callable
from uuid import UUID
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_loop_with_state_containers
from onyx.chat.chat_utils import convert_chat_history
@@ -45,6 +48,8 @@ from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project
@@ -78,20 +83,16 @@ from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
class ToolCallException(Exception):
"""Exception raised for errors during tool calls."""
def __init__(self, message: str, tool_name: str | None = None):
super().__init__(message)
self.tool_name = tool_name
def _extract_project_file_texts_and_images(
project_id: int | None,
user_id: UUID | None,
@@ -294,6 +295,8 @@ def handle_stream_message_objects(
tenant_id = get_current_tenant_id()
llm: LLM | None = None
chat_session: ChatSession | None = None
redis_client: Redis | None = None
user_id = user.id if user is not None else None
llm_user_identifier = (
@@ -339,6 +342,24 @@ def handle_stream_message_objects(
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
# Track user message in PostHog for analytics
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(
distinct_id=user.email if user else tenant_id,
event="user_message_sent",
properties={
"origin": new_msg_req.origin.value,
"has_files": len(new_msg_req.file_descriptors) > 0,
"has_project": chat_session.project_id is not None,
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
"deep_research": new_msg_req.deep_research,
"tenant_id": tenant_id,
},
)
llm = get_llm_for_persona(
persona=persona,
user=user,
@@ -536,10 +557,27 @@ def handle_stream_message_objects(
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, redis_client)
set_processing_status(
chat_session_id=chat_session.id,
redis_client=redis_client,
value=True,
)
# Use external state container if provided, otherwise create internal one
# External container allows non-streaming callers to access accumulated state
state_container = external_state_container or ChatStateContainer()
def llm_loop_completion_callback(
state_container: ChatStateContainer,
) -> None:
llm_loop_completion_handle(
state_container=state_container,
db_session=db_session,
chat_session_id=str(chat_session.id),
is_connected=check_is_connected,
assistant_message=assistant_response,
)
# Run the LLM loop with explicit wrapper for stop signal handling
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
# for stop signals. run_llm_loop itself doesn't know about stopping.
@@ -555,6 +593,7 @@ def handle_stream_message_objects(
yield from run_chat_loop_with_state_containers(
run_deep_research_llm_loop,
llm_loop_completion_callback,
is_connected=check_is_connected,
emitter=emitter,
state_container=state_container,
@@ -571,6 +610,7 @@ def handle_stream_message_objects(
else:
yield from run_chat_loop_with_state_containers(
run_llm_loop,
llm_loop_completion_callback,
is_connected=check_is_connected, # Not passed through to run_llm_loop
emitter=emitter,
state_container=state_container,
@@ -588,51 +628,6 @@ def handle_stream_message_objects(
chat_session_id=str(chat_session.id),
)
# Determine if stopped by user
completed_normally = check_is_connected()
if not completed_normally:
logger.debug(f"Chat session {chat_session.id} stopped by user")
# Build final answer based on completion status
if completed_normally:
if state_container.answer_tokens is None:
raise RuntimeError(
"LLM run completed normally but did not return an answer."
)
final_answer = state_container.answer_tokens
else:
# Stopped by user - append stop message
if state_container.answer_tokens:
final_answer = (
state_container.answer_tokens
+ " ... The generation was stopped by the user here."
)
else:
final_answer = "The generation was stopped by the user."
# Build citation_docs_info from accumulated citations in state container
citation_docs_info: list[CitationDocInfo] = []
seen_citation_nums: set[int] = set()
for citation_num, search_doc in state_container.citation_to_doc.items():
if citation_num not in seen_citation_nums:
seen_citation_nums.add(citation_num)
citation_docs_info.append(
CitationDocInfo(
search_doc=search_doc,
citation_number=citation_num,
)
)
save_chat_turn(
message_text=final_answer,
reasoning_tokens=state_container.reasoning_tokens,
citation_docs_info=citation_docs_info,
tool_calls=state_container.tool_calls,
db_session=db_session,
assistant_message=assistant_response,
is_clarification=state_container.is_clarification,
)
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -650,15 +645,7 @@ def handle_stream_message_objects(
error_msg = str(e)
stack_trace = traceback.format_exc()
if isinstance(e, ToolCallException):
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
error_code="TOOL_CALL_FAILED",
is_retryable=True,
details={"tool_name": e.tool_name} if e.tool_name else None,
)
elif llm:
if llm:
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
e, llm
)
@@ -690,7 +677,67 @@ def handle_stream_message_objects(
)
db_session.rollback()
return
finally:
try:
if redis_client is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
redis_client=redis_client,
value=False,
)
except Exception:
logger.exception("Error in setting processing status")
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],
db_session: Session,
chat_session_id: str,
assistant_message: ChatMessage,
) -> None:
# Determine if stopped by user
completed_normally = is_connected()
# Build final answer based on completion status
if completed_normally:
if state_container.answer_tokens is None:
raise RuntimeError(
"LLM run completed normally but did not return an answer."
)
final_answer = state_container.answer_tokens
else:
# Stopped by user - append stop message
logger.debug(f"Chat session {chat_session_id} stopped by user")
if state_container.answer_tokens:
final_answer = (
state_container.answer_tokens
+ " ... \n\nGeneration was stopped by the user."
)
else:
final_answer = "The generation was stopped by the user."
# Build citation_docs_info from accumulated citations in state container
citation_docs_info: list[CitationDocInfo] = []
seen_citation_nums: set[int] = set()
for citation_num, search_doc in state_container.citation_to_doc.items():
if citation_num not in seen_citation_nums:
seen_citation_nums.add(citation_num)
citation_docs_info.append(
CitationDocInfo(
search_doc=search_doc,
citation_number=citation_num,
)
)
save_chat_turn(
message_text=final_answer,
reasoning_tokens=state_container.reasoning_tokens,
citation_docs_info=citation_docs_info,
tool_calls=state_container.tool_calls,
db_session=db_session,
assistant_message=assistant_message,
is_clarification=state_container.is_clarification,
)
def stream_chat_message_objects(
@@ -739,6 +786,7 @@ def stream_chat_message_objects(
deep_research=new_msg_req.deep_research,
parent_message_id=new_msg_req.parent_message_id,
chat_session_id=new_msg_req.chat_session_id,
origin=new_msg_req.origin,
)
return handle_stream_message_objects(
new_msg_req=translated_new_msg_req,

View File

@@ -568,6 +568,7 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
JIRA_SLIM_PAGE_SIZE = int(os.environ.get("JIRA_SLIM_PAGE_SIZE", 500))
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
@@ -995,3 +996,9 @@ COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")
INSTANCE_TYPE = (
"managed"
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
)

View File

@@ -7,6 +7,7 @@ from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
ONYX_UTM_SOURCE = "onyx_app"
SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
@@ -235,6 +236,7 @@ class NotificationType(str, Enum):
PERSONA_SHARED = "persona_shared"
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
class BlobType(str, Enum):
@@ -422,6 +424,9 @@ class OnyxRedisLocks:
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
# Release notes
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
class OnyxRedisSignals:
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"

View File

@@ -93,7 +93,7 @@ if __name__ == "__main__":
#### Docs Changes
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
connector in Onyx. Then create a Pull Request in https://github.com/onyx-dot-app/onyx-docs.
connector in Onyx. Then create a Pull Request in [https://github.com/onyx-dot-app/documentation](https://github.com/onyx-dot-app/documentation).
### Before opening PR

View File

@@ -901,13 +901,16 @@ class OnyxConfluence:
space_key: str,
) -> list[dict[str, Any]]:
"""
This is a confluence server specific method that can be used to
This is a confluence server/data center specific method that can be used to
fetch the permissions of a space.
This is better logging than calling the get_space_permissions method
because it returns a jsonrpc response.
TODO: Make this call these endpoints for newer confluence versions:
- /rest/api/space/{spaceKey}/permissions
- /rest/api/space/{spaceKey}/permissions/anonymous
NOTE: This uses the JSON-RPC API which is the ONLY way to get space permissions
on Confluence Server/Data Center. The REST API equivalent (expand=permissions)
is Cloud-only and not available on Data Center as of version 8.9.x.
If this fails with 401 Unauthorized, the customer needs to enable JSON-RPC:
Confluence Admin -> General Configuration -> Further Configuration
-> Enable "Remote API (XML-RPC & SOAP)"
"""
url = "rpc/json-rpc/confluenceservice-v2"
data = {
@@ -916,7 +919,18 @@ class OnyxConfluence:
"id": 7,
"params": [space_key],
}
response = self.post(url, data=data)
try:
response = self.post(url, data=data)
except HTTPError as e:
if e.response is not None and e.response.status_code == 401:
raise HTTPError(
"Unauthorized (401) when calling JSON-RPC API for space permissions. "
"This is likely because the Remote API is disabled. "
"To fix: Confluence Admin -> General Configuration -> Further Configuration "
"-> Enable 'Remote API (XML-RPC & SOAP)'",
response=e.response,
) from e
raise
logger.debug(f"jsonrpc response: {response}")
if not response.get("result"):
logger.warning(

View File

@@ -18,6 +18,7 @@ from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.app_configs import JIRA_SLIM_PAGE_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
is_atlassian_date_error,
@@ -57,7 +58,6 @@ logger = setup_logger()
ONE_HOUR = 3600
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
# Constants for Jira field names
@@ -683,7 +683,7 @@ class JiraConnector(
jira_client=self.jira_client,
jql=jql,
start=current_offset,
max_results=_JIRA_SLIM_PAGE_SIZE,
max_results=JIRA_SLIM_PAGE_SIZE,
all_issue_ids=checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
nextPageToken=checkpoint.cursor,
@@ -703,11 +703,11 @@ class JiraConnector(
)
)
current_offset += 1
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
slim_doc_batch = []
self.update_checkpoint_for_next_run(
checkpoint, current_offset, prev_offset, _JIRA_SLIM_PAGE_SIZE
checkpoint, current_offset, prev_offset, JIRA_SLIM_PAGE_SIZE
)
prev_offset = current_offset

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Tuple
from uuid import UUID
@@ -181,7 +182,11 @@ def get_chat_sessions_by_user(
.correlate(ChatSession)
)
stmt = stmt.where(non_system_message_exists_subq)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
result = db_session.execute(stmt)
chat_sessions = result.scalars().all()

View File

@@ -374,7 +374,7 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
def fetch_existing_llm_providers(
db_session: Session,
only_public: bool = False,
exclude_image_generation_providers: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
"""Fetch all LLM providers with optional filtering.
@@ -585,13 +585,12 @@ def update_default_vision_provider(
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
"""Fetch all LLM providers that are in Auto mode."""
return list(
db_session.scalars(
select(LLMProviderModel)
.where(LLMProviderModel.is_auto_mode == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
).all()
query = (
select(LLMProviderModel)
.where(LLMProviderModel.is_auto_mode.is_(True))
.options(selectinload(LLMProviderModel.model_configurations))
)
return list(db_session.scalars(query).all())
def sync_auto_mode_models(
@@ -620,7 +619,9 @@ def sync_auto_mode_models(
# Build the list of all visible models from the config
# All models in the config are visible (default + additional_visible_models)
recommended_visible_models = llm_recommendations.get_visible_models(provider.name)
recommended_visible_models = llm_recommendations.get_visible_models(
provider.provider
)
recommended_visible_model_names = [
model.name for model in recommended_visible_models
]
@@ -635,11 +636,12 @@ def sync_auto_mode_models(
).all()
}
# Remove models that are no longer in GitHub config
# Mark models that are no longer in GitHub config as not visible
for model_name, model in existing_models.items():
if model_name not in recommended_visible_model_names:
db_session.delete(model)
changes += 1
if model.is_visible:
model.is_visible = False
changes += 1
# Add or update models from GitHub config
for model_config in recommended_visible_models:
@@ -669,7 +671,7 @@ def sync_auto_mode_models(
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.name)
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1

View File

@@ -377,6 +377,17 @@ class Notification(Base):
postgresql.JSONB(), nullable=True
)
# Unique constraint ix_notification_user_type_data on (user_id, notif_type, additional_data)
# ensures notification deduplication for batch inserts. Defined in migration 8405ca81cc83.
__table_args__ = (
Index(
"ix_notification_user_sort",
"user_id",
"dismissed",
desc("first_shown"),
),
)
"""
Association Tables

View File

@@ -1,6 +1,11 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy import cast
from sqlalchemy import select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
@@ -17,23 +22,33 @@ def create_notification(
title: str,
description: str | None = None,
additional_data: dict | None = None,
autocommit: bool = True,
) -> Notification:
# Check if an undismissed notification of the same type and data exists
# Previously, we only matched the first identical, undismissed notification
# Now, we assume some uniqueness to notifications
# If we previously issued a notification that was dismissed, we no longer issue a new one
# Normalize additional_data to match the unique index behavior
# The index uses COALESCE(additional_data, '{}'::jsonb)
# We need to match this logic in our query
additional_data_normalized = additional_data if additional_data is not None else {}
existing_notification = (
db_session.query(Notification)
.filter_by(
user_id=user_id,
notif_type=notif_type,
dismissed=False,
.filter_by(user_id=user_id, notif_type=notif_type)
.filter(
func.coalesce(Notification.additional_data, cast({}, postgresql.JSONB))
== additional_data_normalized
)
.filter(Notification.additional_data == additional_data)
.first()
)
if existing_notification:
# Update the last_shown timestamp
existing_notification.last_shown = func.now()
db_session.commit()
# Update the last_shown timestamp if the notification is not dismissed
if not existing_notification.dismissed:
existing_notification.last_shown = func.now()
if autocommit:
db_session.commit()
return existing_notification
# Create a new notification if none exists
@@ -48,7 +63,8 @@ def create_notification(
additional_data=additional_data,
)
db_session.add(notification)
db_session.commit()
if autocommit:
db_session.commit()
return notification
@@ -81,6 +97,11 @@ def get_notifications(
query = query.where(Notification.dismissed.is_(False))
if notif_type:
query = query.where(Notification.notif_type == notif_type)
# Sort: undismissed first, then by date (newest first)
query = query.order_by(
Notification.dismissed.asc(),
Notification.first_shown.desc(),
)
return list(db_session.execute(query).scalars().all())
@@ -99,6 +120,63 @@ def dismiss_notification(notification: Notification, db_session: Session) -> Non
db_session.commit()
def batch_dismiss_notifications(
    notifications: list[Notification],
    db_session: Session,
) -> None:
    """Mark every notification in the given batch as dismissed.

    All flag updates are persisted with a single commit at the end.
    """
    for notif in notifications:
        notif.dismissed = True
    db_session.commit()
def batch_create_notifications(
    user_ids: list[UUID],
    notif_type: NotificationType,
    db_session: Session,
    title: str,
    description: str | None = None,
    additional_data: dict | None = None,
) -> int:
    """
    Create notifications for multiple users in a single batch operation.

    Uses ON CONFLICT DO NOTHING for atomic idempotent inserts - if a user already
    has a notification with the same (user_id, notif_type, additional_data), the
    insert is silently skipped.

    Returns the number of notifications created.

    Relies on unique index on (user_id, notif_type, COALESCE(additional_data, '{}'))
    """
    if not user_ids:
        return 0

    timestamp = datetime.now(timezone.utc)
    # Store {} rather than NULL so equality matches the COALESCE used by the
    # unique index.
    payload = additional_data if additional_data is not None else {}

    rows = [
        {
            "user_id": user_id,
            "notif_type": notif_type.value,
            "title": title,
            "description": description,
            "dismissed": False,
            "last_shown": timestamp,
            "first_shown": timestamp,
            "additional_data": payload,
        }
        for user_id in user_ids
    ]

    result = db_session.execute(
        insert(Notification).values(rows).on_conflict_do_nothing()
    )
    db_session.commit()

    # rowcount counts only the rows actually inserted (conflicts excluded);
    # some drivers report -1 when the count is unknown, which we clamp to 0.
    # CursorResult has rowcount but session.execute type hints are too broad.
    return max(result.rowcount, 0)  # type: ignore[attr-defined]
def update_notification_last_shown(
notification: Notification, db_session: Session
) -> None:

View File

@@ -0,0 +1,94 @@
"""Database functions for release notes functionality."""
from urllib.parse import urlencode
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NotificationType
from onyx.configs.constants import ONYX_UTM_SOURCE
from onyx.db.models import User
from onyx.db.notification import batch_create_notifications
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
from onyx.server.features.release_notes.models import ReleaseNoteEntry
from onyx.utils.logger import setup_logger
logger = setup_logger()
def create_release_notifications_for_versions(
    db_session: Session,
    release_note_entries: list[ReleaseNoteEntry],
) -> int:
    """
    Create release notes notifications for each release note entry.

    Uses batch_create_notifications for efficient bulk insertion.
    If a user already has a notification for a specific version (dismissed or not),
    no new one is created (handled by unique constraint on additional_data).

    Note: Entries should already be filtered by app_version before calling this
    function. The filtering happens in _parse_mdx_to_release_note_entries().

    Args:
        db_session: Database session
        release_note_entries: List of release note entries to notify about (pre-filtered)

    Returns:
        Total number of notifications created across all versions.
    """
    if not release_note_entries:
        logger.debug("No release note entries to notify about")
        return 0

    # Get active users; exclude Slack/external-permission users and API key
    # pseudo-users (identified by their dummy email domain)
    user_ids = list(
        db_session.scalars(
            select(User.id).where(  # type: ignore
                User.is_active == True,  # noqa: E712
                User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
                User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False),  # type: ignore[attr-defined]
            )
        ).all()
    )

    total_created = 0
    for entry in release_note_entries:
        # Convert version to anchor format for external docs links
        # v2.7.0 -> v2-7-0
        version_anchor = entry.version.replace(".", "-")

        # Build UTM parameters for tracking
        utm_params = {
            "utm_source": ONYX_UTM_SOURCE,
            "utm_medium": "notification",
            "utm_campaign": INSTANCE_TYPE,
            "utm_content": f"release_notes-{entry.version}",
        }
        # Per RFC 3986 the query component must precede the fragment; putting
        # "?..." after "#..." would make the UTM params part of the fragment,
        # so they would never reach server logs or analytics.
        link = f"{DOCS_CHANGELOG_BASE_URL}?{urlencode(utm_params)}#{version_anchor}"

        additional_data: dict[str, str] = {
            "version": entry.version,
            "link": link,
        }

        created_count = batch_create_notifications(
            user_ids,
            NotificationType.RELEASE_NOTES,
            db_session,
            title=entry.title,
            description=f"Check out what's new in {entry.version}",
            additional_data=additional_data,
        )
        total_created += created_count
        logger.debug(
            f"Created {created_count} release notes notifications "
            f"(version {entry.version}, {len(user_ids)} eligible users)"
        )

    return total_created

View File

@@ -150,6 +150,9 @@ def generate_final_report(
is_deep_research=True,
)
# Save citation mapping to state_container so citations are persisted
state_container.set_citation_mapping(citation_processor.citation_to_doc)
final_report = llm_step_result.answer
if final_report is None:
raise ValueError("LLM failed to generate the final deep research report")

View File

@@ -0,0 +1,62 @@
# Opensearch Idiosyncrasies
## How it works at a high level
Opensearch has 2 phases, a `Search` phase and a `Fetch` phase. The `Search` phase works by getting the document scores on each
shard separately, then typically a fetch phase grabs all of the relevant fields/data for returning to the user. There is also
an intermediate phase (seemingly built specifically to handle hybrid search queries) which can run in between as a processor.
References:
https://docs.opensearch.org/latest/search-plugins/search-pipelines/search-processors/
https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
## How Hybrid queries work
Hybrid queries are basically parallel queries that each run through their own `Search` phase and do not interact in any way.
They also run across all the shards. It is not entirely clear what happens if a combination pipeline is not specified for them,
perhaps the scores are just summed.
When the normalization processor is applied to keyword/vector hybrid searches, documents that show up due to keyword match may
not also have shown up in the vector search and vice versa. In these situations, it just receives a 0 score for the missing
query component. Opensearch does not run another phase to recapture those missing values. The impact of this is that after
normalizing, the missing scores are 0 but this is a higher score than if it actually received a non-zero score.
This may not be immediately obvious so an explanation is included here. If it got a non-zero score instead, it must be lower
than all of the other scores of the list (otherwise it would have shown up). Therefore it would impact the normalization and
push the other scores higher so that it's not only the lowest score still, but now it's a differentiated lowest score. This is
not strictly the case in a multi-node setup but the high level concept approximately holds. So basically the 0 score is a form
of "minimum value clipping".
## On time decay and boosting
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
vary between models and even between queries. It is not a safe assumption to pre-normalize the scores, so we also cannot apply any
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to the 50th percentile; it brings it under 0.6, making it the worst match.
Same logic applies to additive boosting.
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
and only applies to the results of the completely independent `Search` phase queries. So if a time based boost (a separate
query which filters on recently updated documents) is added, it would not be able to introduce any new documents
to the set (since the new documents would have no keyword/vector score or already be present) since the 0 scores on keyword
and vector would make the docs which only came because of time filter very low scoring. This can however make some of the lower
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
being fetched and returned to the user. But there are other issues of including these:
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
contents. If there are lots of updates, the filter may miss some of the recently updated documents.
- There is not a good way to normalize this field, the best is to clip it on the bottom.
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
"unusual-ness" across distributions.
So while it is possible to apply time based boosting at the normalization stage (or specifically to the keyword score), we have
decided it is better to not apply it during the OpenSearch query.
Because of these limitations, Onyx in code applies further refinements, boostings, etc. based on OpenSearch providing an initial
filtering. The impact of time decay and boost should not be so big that we would need orders of magnitude more results back
from OpenSearch.
## Other concepts to be aware of
Within the `Search` phase, there are optional steps like Rescore but these are not useful for the combination/normalization
work that is relevant for the hybrid search. Since the Rescore happens prior to normalization, it's not able to provide any
meaningful operations to the query for our usage.
Because the Title is included in the Contents for both embedding and keyword searches, the Title scores are very low relative to
the actual full contents scoring. It is seen as a boost rather than a core scoring component. Time decay works similarly.

View File

@@ -164,7 +164,7 @@ def format_document_soup(
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
soup = bs4.BeautifulSoup(text, "html.parser")
soup = bs4.BeautifulSoup(text, "lxml")
return format_document_soup(soup)
@@ -174,7 +174,7 @@ def web_html_cleanup(
additional_element_types_to_discard: list[str] | None = None,
) -> ParsedHTML:
if isinstance(page_content, str):
soup = bs4.BeautifulSoup(page_content, "html.parser")
soup = bs4.BeautifulSoup(page_content, "lxml")
else:
soup = page_content

View File

@@ -6,15 +6,19 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import TransactionalContext
from onyx.access.access import get_access_for_user_files
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import NotificationType
from onyx.connectors.models import Document
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.notification import create_notification
from onyx.db.user_file import fetch_chunk_counts_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
@@ -194,6 +198,42 @@ class UserFileIndexingAdapter:
user_file_id_to_token_count=user_file_id_to_token_count,
)
def _notify_assistant_owners_if_files_ready(
    self, user_files: list[UserFile]
) -> None:
    """
    Check if all files for associated assistants are processed and notify owners.
    Only sends notification when all files for an assistant are COMPLETED.
    """
    for completed_file in user_files:
        if completed_file.status != UserFileStatus.COMPLETED:
            continue
        for assistant in completed_file.assistants:
            # Assistants without an owner have nobody to notify
            if assistant.user_id is None:
                continue
            # The current file is already known to be COMPLETED, so only the
            # remaining files of this assistant need checking
            others_done = all(
                sibling.status == UserFileStatus.COMPLETED
                for sibling in assistant.user_files
                if sibling.id != completed_file.id
            )
            if not others_done:
                continue
            create_notification(
                user_id=assistant.user_id,
                notif_type=NotificationType.ASSISTANT_FILES_READY,
                db_session=self.db_session,
                title="Your files are ready!",
                description=f"All files for agent {assistant.name} have been processed and are now available.",
                additional_data={
                    "persona_id": assistant.id,
                    "link": f"/assistants/{assistant.id}",
                },
                autocommit=False,
            )
def post_index(
self,
context: DocumentBatchPrepareContext,
@@ -204,7 +244,10 @@ class UserFileIndexingAdapter:
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
self.db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
self.db_session.query(UserFile)
.options(selectinload(UserFile.assistants).selectinload(Persona.user_files))
.filter(UserFile.id.in_(user_file_ids))
.all()
)
for user_file in user_files:
# don't update the status if the user file is being deleted
@@ -217,6 +260,10 @@ class UserFileIndexingAdapter:
user_file.token_count = result.user_file_id_to_token_count[
str(user_file.id)
]
# Notify assistant owners if all their files are now processed
self._notify_assistant_owners_if_files_ready(user_files)
self.db_session.commit()
# Store the plaintext in the file store for faster retrieval

View File

@@ -48,7 +48,7 @@ class VertexAIPromptCacheProvider(PromptCacheProvider):
cacheable_prefix=cacheable_prefix,
suffix=suffix,
continuation=continuation,
transform_cacheable=_add_vertex_cache_control,
transform_cacheable=None, # TODO: support explicit caching
)
def extract_cache_metadata(
@@ -89,6 +89,10 @@ def _add_vertex_cache_control(
not at the message level. This function converts string content to the array format
and adds cache_control to the last content block in each cacheable message.
"""
# NOTE: unfortunately we need a much more sophisticated mechnism to support
# explict caching with vertex in the presence of tools and system messages
# (since they're supposed to be stripped out when setting cache_control)
# so we're deferring this to a future PR.
updated: list[ChatCompletionMessage] = []
for message in messages:
mutated = dict(message)

View File

@@ -82,7 +82,6 @@ def fetch_llm_recommendations_from_github(
def sync_llm_models_from_github(
db_session: Session,
config: LLMRecommendations,
force: bool = False,
) -> dict[str, int]:
"""Sync models from GitHub config to database for all Auto mode providers.
@@ -101,19 +100,24 @@ def sync_llm_models_from_github(
Returns:
Dict of provider_name -> number of changes made.
"""
# Skip if we've already processed this version (unless forced)
last_updated_at = _get_cached_last_updated_at()
if not force and last_updated_at and config.updated_at <= last_updated_at:
logger.debug("GitHub config unchanged, skipping sync")
return {}
results: dict[str, int] = {}
# Get all providers in Auto mode
auto_providers = fetch_auto_mode_providers(db_session)
if not auto_providers:
logger.debug("No providers in Auto mode found")
return {}
# Fetch config from GitHub
config = fetch_llm_recommendations_from_github()
if not config:
logger.warning("Failed to fetch GitHub config")
return {}
# Skip if we've already processed this version (unless forced)
last_updated_at = _get_cached_last_updated_at()
if not force and last_updated_at and config.updated_at <= last_updated_at:
logger.debug("GitHub config unchanged, skipping sync")
_set_cached_last_updated_at(config.updated_at)
return {}

View File

@@ -35,6 +35,7 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.utils.logger import OnyxLoggingAdapter
srl = SlackRateLimiter()
@@ -236,6 +237,7 @@ def handle_regular_answer(
retrieval_details=retrieval_details,
rerank_settings=None, # Rerank customization supported in Slack flow
db_session=db_session,
origin=MessageOrigin.SLACKBOT,
)
# if it's a DM or ephemeral message, answer based on private documents.

View File

@@ -9,11 +9,13 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_notification
from onyx.db.notification import get_notification_by_id
from onyx.db.notification import get_notifications
from onyx.server.features.release_notes.utils import (
ensure_release_notes_fresh_and_notify,
)
from onyx.server.settings.models import Notification as NotificationModel
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/notifications")
@@ -22,9 +24,27 @@ def get_notifications_api(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[NotificationModel]:
"""
Get all undismissed notifications for the current user.
Note: also executes background checks that should create notifications.
Examples of checks that create new notifications:
- Checking for new release notes the user hasn't seen
- Checking for misconfigurations due to version changes
- Explicitly announcing breaking changes
"""
# If more background checks are added, this should be moved to a helper function
try:
ensure_release_notes_fresh_and_notify(db_session)
except Exception:
# Log exception but don't fail the entire endpoint
# Users can still see their existing notifications
logger.exception("Failed to check for release notes in notifications endpoint")
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=False)
for notif in get_notifications(user, db_session, include_dismissed=True)
]
return notifications

View File

@@ -0,0 +1,23 @@
"""Constants for release notes functionality."""
# GitHub source
GITHUB_RAW_BASE_URL = (
    "https://raw.githubusercontent.com/onyx-dot-app/documentation/main"
)
GITHUB_CHANGELOG_RAW_URL = f"{GITHUB_RAW_BASE_URL}/changelog.mdx"
# Base URL for changelog documentation (used for notification links)
DOCS_CHANGELOG_BASE_URL = "https://docs.onyx.app/changelog"
# HTTP timeout (seconds) when fetching the changelog from GitHub
FETCH_TIMEOUT = 60.0
# Redis keys (in shared namespace)
REDIS_KEY_PREFIX = "release_notes:"
# Timestamp of the last successful changelog fetch
REDIS_KEY_FETCHED_AT = f"{REDIS_KEY_PREFIX}fetched_at"
# ETag returned by GitHub, used for conditional (If-None-Match) requests
REDIS_KEY_ETAG = f"{REDIS_KEY_PREFIX}etag"
# Cache TTL: 24 hours
REDIS_CACHE_TTL = 60 * 60 * 24
# Auto-refresh threshold: 1 hour
AUTO_REFRESH_THRESHOLD_SECONDS = 60 * 60

View File

@@ -0,0 +1,11 @@
"""Pydantic models for release notes."""
from pydantic import BaseModel
class ReleaseNoteEntry(BaseModel):
    """A single version's release note entry, parsed from the changelog MDX."""

    version: str  # e.g., "v2.7.0"
    date: str  # e.g., "January 7th, 2026"
    title: str  # Display title for notifications: "Onyx v2.7.0 is available!"

View File

@@ -0,0 +1,242 @@
"""Utility functions for release notes parsing and caching."""
import re
from datetime import datetime
from datetime import timezone
import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
from onyx.server.features.release_notes.constants import REDIS_CACHE_TTL
from onyx.server.features.release_notes.constants import REDIS_KEY_ETAG
from onyx.server.features.release_notes.constants import REDIS_KEY_FETCHED_AT
from onyx.server.features.release_notes.models import ReleaseNoteEntry
from onyx.utils.logger import setup_logger
logger = setup_logger()
# ============================================================================
# Version Utilities
# ============================================================================
def is_valid_version(version: str) -> bool:
    """Check if version matches vX.Y.Z or vX.Y.Z-suffix.N pattern exactly."""
    version_pattern = r"^v\d+\.\d+\.\d+(-[a-zA-Z]+\.\d+)?$"
    return re.match(version_pattern, version) is not None
def parse_version_tuple(version: str) -> tuple[int, int, int]:
    """Parse a version string into (major, minor, patch) for semantic sorting.

    A leading "v" and any "-suffix" (e.g. "-cloud.3") are stripped; missing
    components default to 0.
    """
    core = re.sub(r"-.*$", "", re.sub(r"^v", "", version))
    numbers = [int(piece) for piece in core.split(".")[:3]]
    while len(numbers) < 3:
        numbers.append(0)
    return (numbers[0], numbers[1], numbers[2])
def is_version_gte(v1: str, v2: str) -> bool:
    """Return True when v1 >= v2 semantically.

    Suffixes like -cloud.X or -beta.X are stripped by parse_version_tuple.
    """
    left = parse_version_tuple(v1)
    right = parse_version_tuple(v2)
    return left >= right
# ============================================================================
# MDX Parsing
# ============================================================================
def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry]:
    """Parse MDX content into ReleaseNoteEntry objects for versions >= __version__."""
    # Matches <Update label="..." description="..." [tags={...}] ...>...</Update>
    update_block_re = re.compile(
        r'<Update\s+label="([^"]+)"\s+description="([^"]+)"'
        r"(?:\s+tags=\{([^}]+)\})?[^>]*>"
        r".*?"
        r"</Update>",
        re.DOTALL,
    )

    all_entries: list[ReleaseNoteEntry] = []
    for block in update_block_re.finditer(mdx_content):
        label = block.group(1)
        description = block.group(2)
        # Skip <Update> blocks whose label is not a version string
        if not is_valid_version(label):
            continue
        all_entries.append(
            ReleaseNoteEntry(
                version=label,
                date=description,
                title=f"Onyx {label} is available!",
            )
        )

    if not all_entries:
        raise ValueError("Could not parse any release note entries from MDX.")

    # Keep only entries at or above the running version. When __version__ is
    # not a recognized version string (likely `development`), show everything.
    if __version__ and is_valid_version(__version__):
        return [
            entry for entry in all_entries if is_version_gte(entry.version, __version__)
        ]
    return all_entries
# ============================================================================
# Cache Helpers (ETag + timestamp only)
# ============================================================================
def get_cached_etag() -> str | None:
    """Return the GitHub ETag cached in Redis, or None if absent or on error."""
    client = get_shared_redis_client()
    try:
        cached = client.get(REDIS_KEY_ETAG)
        if not cached:
            return None
        return cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
    except Exception as e:
        logger.error(f"Failed to get cached etag from Redis: {e}")
        return None
def get_last_fetch_time() -> datetime | None:
    """Return the last changelog fetch time from Redis as a UTC datetime.

    Returns None when no timestamp is stored or Redis access fails.
    """
    client = get_shared_redis_client()
    try:
        stored = client.get(REDIS_KEY_FETCHED_AT)
        if not stored:
            return None
        text = (
            stored.decode("utf-8") if isinstance(stored, bytes) else str(stored)
        )
        parsed = datetime.fromisoformat(text)
        # fromisoformat() yields a naive datetime when the input carries no
        # timezone; defensively treat such values as UTC, and normalize
        # timezone-aware values to UTC as well.
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)
    except Exception as e:
        logger.error(f"Failed to get last fetch time from Redis: {e}")
        return None
def save_fetch_metadata(etag: str | None) -> None:
    """Persist the fetch timestamp (and the ETag, when given) to Redis."""
    client = get_shared_redis_client()
    timestamp = datetime.now(timezone.utc)
    try:
        client.set(REDIS_KEY_FETCHED_AT, timestamp.isoformat(), ex=REDIS_CACHE_TTL)
        if etag:
            client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
    except Exception as e:
        logger.error(f"Failed to save fetch metadata to Redis: {e}")
def is_cache_stale() -> bool:
    """Decide whether the changelog should be re-fetched from GitHub."""
    last_fetch = get_last_fetch_time()
    # No recorded fetch means the cache is definitely stale
    if last_fetch is None:
        return True
    elapsed = (datetime.now(timezone.utc) - last_fetch).total_seconds()
    return elapsed > AUTO_REFRESH_THRESHOLD_SECONDS
# ============================================================================
# Main Function
# ============================================================================
def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
    """
    Check for new release notes and create notifications if needed.

    Called from /api/notifications endpoint. Uses ETag for efficient
    GitHub requests. Database handles notification deduplication.

    Since all users will trigger this via notification fetch,
    uses Redis lock to prevent concurrent GitHub requests when cache is stale.

    Network and parsing failures are logged, never raised; the fetch
    timestamp is still refreshed on failure to avoid retry storms.
    """
    # Fast path: nothing to do while the cached timestamp is fresh
    if not is_cache_stale():
        return
    # Acquire lock to prevent concurrent fetches
    redis_client = get_shared_redis_client()
    lock = redis_client.lock(
        OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
        timeout=90,  # 90 second timeout for the lock
    )
    # Non-blocking acquire - if we can't get the lock, another request is handling it
    acquired = lock.acquire(blocking=False)
    if not acquired:
        logger.debug("Another request is already fetching release notes, skipping.")
        return
    try:
        logger.debug("Checking GitHub for release notes updates.")
        # Use ETag for conditional request
        headers: dict[str, str] = {}
        etag = get_cached_etag()
        if etag:
            headers["If-None-Match"] = etag
        try:
            response = httpx.get(
                GITHUB_CHANGELOG_RAW_URL,
                headers=headers,
                timeout=FETCH_TIMEOUT,
                follow_redirects=True,
            )
            if response.status_code == 304:
                # Content unchanged, just update timestamp (keep the same ETag)
                logger.debug("Release notes unchanged (304).")
                save_fetch_metadata(etag)
                return
            response.raise_for_status()
            # Parse and create notifications
            entries = parse_mdx_to_release_note_entries(response.text)
            # Persist the new ETag before notifying so a later failure below
            # still benefits from the conditional request on the next run
            new_etag = response.headers.get("ETag")
            save_fetch_metadata(new_etag)
            # Create notifications, sorted semantically to create them in chronological order
            entries = sorted(entries, key=lambda x: parse_version_tuple(x.version))
            create_release_notifications_for_versions(db_session, entries)
        except Exception as e:
            logger.error(f"Failed to check release notes: {e}")
            # Update timestamp even on failure to prevent retry storms
            # We don't save etag on failure to allow retry with conditional request
            save_fetch_metadata(None)
    finally:
        # Always release the lock
        if lock.owned():
            lock.release()

View File

@@ -22,6 +22,9 @@ from onyx.tools.tool_implementations.open_url.models import WebContentProvider
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
OnyxWebCrawler,
)
from onyx.tools.tool_implementations.open_url.utils import (
filter_web_contents_with_no_title_or_content,
)
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
from onyx.tools.tool_implementations.web_search.models import WebSearchProvider
from onyx.tools.tool_implementations.web_search.providers import (
@@ -30,6 +33,9 @@ from onyx.tools.tool_implementations.web_search.providers import (
from onyx.tools.tool_implementations.web_search.providers import (
build_search_provider_from_config,
)
from onyx.tools.tool_implementations.web_search.utils import (
filter_web_search_results_with_no_title_or_snippet,
)
from onyx.tools.tool_implementations.web_search.utils import (
truncate_search_result_content,
)
@@ -156,7 +162,10 @@ def _run_web_search(
status_code=502, detail="Web search provider failed to execute query."
) from exc
trimmed_results = list(search_results)[: request.max_results]
filtered_results = filter_web_search_results_with_no_title_or_snippet(
list(search_results)
)
trimmed_results = list(filtered_results)[: request.max_results]
for search_result in trimmed_results:
results.append(
LlmWebSearchResult(
@@ -180,7 +189,9 @@ def _open_urls(
provider_view, provider = _get_active_content_provider(db_session)
try:
docs = provider.contents(urls)
docs = filter_web_contents_with_no_title_or_content(
list(provider.contents(urls))
)
except HTTPException:
raise
except Exception as exc:

View File

@@ -29,6 +29,9 @@ from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderTestRequest
from onyx.server.manage.web_search.models import WebSearchProviderUpsertRequest
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.tools.tool_implementations.open_url.utils import (
filter_web_contents_with_no_title_or_content,
)
from onyx.tools.tool_implementations.web_search.providers import (
build_content_provider_from_config,
)
@@ -353,7 +356,9 @@ def test_content_provider(
# Actually test the API key by making a real content fetch call
try:
test_url = "https://example.com"
test_results = provider.contents([test_url])
test_results = filter_web_contents_with_no_title_or_content(
list(provider.contents([test_url]))
)
if not test_results or not any(
result.scrape_successful for result in test_results
):

View File

@@ -1,8 +1,6 @@
import asyncio
import datetime
import json
import os
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
@@ -18,8 +16,11 @@ from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import create_chat_session_from_request
@@ -87,6 +88,7 @@ from onyx.server.query_and_chat.models import ChatSessionSummary
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
@@ -105,7 +107,6 @@ from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.threadpool_concurrency import run_in_background
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -292,6 +293,18 @@ def get_chat_session(
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
]
try:
is_processing = is_chat_session_processing(session_id, get_redis_client())
# Edit the last message to indicate loading (Overriding default message value)
if is_processing and chat_message_details:
last_msg = chat_message_details[-1]
if last_msg.message_type == MessageType.ASSISTANT:
last_msg.message = "Message is loading... Please refresh the page soon."
except Exception:
logger.exception(
"An error occurred while checking if the chat session is processing"
)
# Every assistant message might have a set of tool calls associated with it, these need to be replayed back for the frontend
# Each list is the set of tool calls for the given assistant message.
replay_packet_lists: list[list[Packet]] = []
@@ -510,7 +523,7 @@ def handle_new_chat_message(
@router.post("/send-chat-message", response_model=None, tags=PUBLIC_API_TAGS)
async def handle_send_chat_message(
def handle_send_chat_message(
chat_message_req: SendMessageRequest,
request: Request,
user: User | None = Depends(current_chat_accessible_user),
@@ -540,6 +553,11 @@ async def handle_send_chat_message(
event=MilestoneRecordType.RAN_QUERY,
)
# Override origin to API when authenticated via API key or PAT
# to prevent clients from polluting telemetry data
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -575,63 +593,34 @@ async def handle_send_chat_message(
# Note: LLM cost tracking is now handled in multi_llm.py
return result
# Use prod-cons pattern to continue processing even if request stops yielding
buffer: asyncio.Queue[str | None] = asyncio.Queue()
loop = asyncio.get_running_loop()
# Capture headers before spawning thread
litellm_headers = extract_headers(request.headers, LITELLM_PASS_THROUGH_HEADERS)
custom_tool_headers = get_custom_tool_additional_request_headers(request.headers)
def producer() -> None:
"""
Producer function that runs handle_stream_message_objects in a loop
and writes results to the buffer.
"""
# Streaming path, normal Onyx UI behavior
def stream_generator() -> Generator[str, None, None]:
state_container = ChatStateContainer()
try:
logger.debug("Producer started")
with get_session_with_current_tenant() as db_session:
for obj in handle_stream_message_objects(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
litellm_additional_headers=litellm_headers,
custom_tool_additional_headers=custom_tool_headers,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
external_state_container=state_container,
):
# Thread-safe put into the asyncio queue
loop.call_soon_threadsafe(
buffer.put_nowait, get_json_line(obj.model_dump())
)
yield get_json_line(obj.model_dump())
# Note: LLM cost tracking is now handled in multi_llm.py
except Exception as e:
logger.exception("Error in chat message streaming")
loop.call_soon_threadsafe(buffer.put_nowait, json.dumps({"error": str(e)}))
yield json.dumps({"error": str(e)})
finally:
# Signal end of stream
loop.call_soon_threadsafe(buffer.put_nowait, None)
logger.debug("Producer finished")
logger.debug("Stream generator finished")
async def stream_from_buffer() -> AsyncGenerator[str, None]:
"""
Async generator that reads from the buffer and yields to the client.
"""
try:
while True:
item = await buffer.get()
if item is None:
# End of stream signal
break
yield item
except asyncio.CancelledError:
logger.warning("Stream cancelled (Consumer disconnected)")
finally:
logger.debug("Stream consumer finished")
run_in_background(producer)
return StreamingResponse(stream_from_buffer(), media_type="text/event-stream")
return StreamingResponse(stream_generator(), media_type="text/event-stream")
@router.put("/set-message-as-latest")

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from uuid import UUID
@@ -36,6 +37,17 @@ from onyx.server.query_and_chat.streaming_models import Packet
AUTO_PLACE_AFTER_LATEST_MESSAGE = -1
class MessageOrigin(str, Enum):
    """Origin of a chat message for telemetry tracking.

    Mixes in ``str`` so members serialize directly as their string values
    in Pydantic models and JSON payloads.
    """

    WEBAPP = "webapp"
    CHROME_EXTENSION = "chrome_extension"
    # Direct API access; requests authenticated via API key or PAT are
    # forced to this value server-side to keep telemetry clean.
    API = "api"
    SLACKBOT = "slackbot"
    # Origin could not be determined.
    UNKNOWN = "unknown"
    # Caller did not specify an origin (default for incoming requests).
    UNSET = "unset"
if TYPE_CHECKING:
pass
@@ -93,6 +105,9 @@ class SendMessageRequest(BaseModel):
deep_research: bool = False
# Origin of the message for telemetry tracking
origin: MessageOrigin = MessageOrigin.UNSET
# Placement information for the message in the conversation tree:
# - -1: auto-place after latest message in chain
# - null: regeneration from root (first message)
@@ -184,6 +199,9 @@ class CreateChatMessageRequest(ChunkContext):
deep_research: bool = False
# Origin of the message for telemetry tracking
origin: MessageOrigin = MessageOrigin.UNKNOWN
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:

View File

@@ -60,6 +60,7 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import DocumentSearchPagination
from onyx.server.query_and_chat.models import DocumentSearchRequest
from onyx.server.query_and_chat.models import DocumentSearchResponse
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import OneShotQARequest
from onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.server.query_and_chat.models import SearchSessionDetailResponse
@@ -251,6 +252,7 @@ def get_answer_stream(
)
# Also creates a new chat session
# Origin is hardcoded to API since this endpoint is only accessible via API calls
request = prepare_chat_message_request(
message_text=combined_message,
user=user,
@@ -261,6 +263,7 @@ def get_answer_stream(
rerank_settings=query_request.rerank_settings,
db_session=db_session,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
origin=MessageOrigin.API,
)
packets = stream_chat_message_objects(

View File

@@ -23,6 +23,7 @@ from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import IntermediateReportDelta
from onyx.server.query_and_chat.streaming_models import IntermediateReportStart
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
@@ -35,6 +36,7 @@ from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
@@ -207,6 +209,7 @@ def create_research_agent_packets(
"""Create packets for research agent tool calls.
This recreates the packet structure that ResearchAgentRenderer expects:
- ResearchAgentStart with the research task
- IntermediateReportStart to signal report begins
- IntermediateReportDelta with the report content (if available)
- SectionEnd to mark completion
"""
@@ -222,6 +225,14 @@ def create_research_agent_packets(
# Emit report content if available
if report_content:
# Emit IntermediateReportStart before delta
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=IntermediateReportStart(),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
@@ -381,10 +392,17 @@ def translate_assistant_message_to_packets(
)
)
# Process each tool call in this turn
# Process each tool call in this turn (single pass).
# We buffer packets for the turn so we can conditionally prepend a TopLevelBranching
# packet (which must appear before any tool output in the turn).
research_agent_count = 0
turn_tool_packets: list[Packet] = []
for tool_call in tool_calls_in_turn:
# Here we do a try because some tools may get deleted before the session is reloaded.
try:
tool = get_tool_by_id(tool_call.tool_id, db_session)
if tool.in_code_tool_id == RESEARCH_AGENT_DB_NAME:
research_agent_count += 1
# Handle different tool types
if tool.in_code_tool_id in [
@@ -398,7 +416,7 @@ def translate_assistant_message_to_packets(
translate_db_search_doc_to_saved_search_doc(doc)
for doc in tool_call.search_docs
]
packet_list.extend(
turn_tool_packets.extend(
create_search_packets(
search_queries=queries,
search_docs=search_docs,
@@ -418,7 +436,7 @@ def translate_assistant_message_to_packets(
urls = cast(
list[str], tool_call.tool_call_arguments.get("urls", [])
)
packet_list.extend(
turn_tool_packets.extend(
create_fetch_packets(
fetch_docs,
urls,
@@ -433,7 +451,7 @@ def translate_assistant_message_to_packets(
GeneratedImage(**img)
for img in tool_call.generated_images
]
packet_list.extend(
turn_tool_packets.extend(
create_image_generation_packets(
images, turn_num, tab_index=tool_call.tab_index
)
@@ -446,7 +464,7 @@ def translate_assistant_message_to_packets(
tool_call.tool_call_arguments.get(RESEARCH_AGENT_TASK_KEY)
or "Could not fetch saved research task.",
)
packet_list.extend(
turn_tool_packets.extend(
create_research_agent_packets(
research_task=research_task,
report_content=tool_call.tool_call_response,
@@ -457,7 +475,7 @@ def translate_assistant_message_to_packets(
else:
# Custom tool or unknown tool
packet_list.extend(
turn_tool_packets.extend(
create_custom_tool_packets(
tool_name=tool.display_name or tool.name,
response_type="text",
@@ -471,6 +489,18 @@ def translate_assistant_message_to_packets(
logger.warning(f"Error processing tool call {tool_call.id}: {e}")
continue
if research_agent_count > 1:
# Emit TopLevelBranching before processing any tool output in the turn.
packet_list.append(
Packet(
placement=Placement(turn_index=turn_num),
obj=TopLevelBranching(
num_parallel_branches=research_agent_count
),
)
)
packet_list.extend(turn_tool_packets)
# Determine the next turn_index for the final message
# It should come after all tool calls
max_tool_turn = 0
@@ -539,9 +569,18 @@ def translate_assistant_message_to_packets(
if citation_info_list:
final_turn_index = max(final_turn_index, citation_turn_index)
# Determine stop reason - check if message indicates user cancelled
stop_reason: str | None = None
if chat_message.message:
if "Generation was stopped" in chat_message.message:
stop_reason = "user_cancelled"
# Add overall stop packet at the end
packet_list.append(
Packet(placement=Placement(turn_index=final_turn_index), obj=OverallStop())
Packet(
placement=Placement(turn_index=final_turn_index),
obj=OverallStop(stop_reason=stop_reason),
)
)
return packet_list

View File

@@ -410,7 +410,7 @@ def run_research_agent_call(
most_recent_reasoning = llm_step_result.reasoning
continue
else:
tool_responses, citation_mapping = run_tool_calls(
parallel_tool_call_results = run_tool_calls(
tool_calls=tool_calls,
tools=current_tools,
message_history=msg_history,
@@ -424,6 +424,10 @@ def run_research_agent_call(
# May be better to not do this step, hard to say, needs to be tested
skip_search_query_expansion=False,
)
tool_responses = parallel_tool_call_results.tool_responses
citation_mapping = (
parallel_tool_call_results.updated_citation_mapping
)
if tool_calls and not tool_responses:
failure_messages = create_tool_call_failure_messages(

View File

@@ -25,6 +25,17 @@ TOOL_CALL_MSG_FUNC_NAME = "function_name"
TOOL_CALL_MSG_ARGUMENTS = "arguments"
class ToolCallException(Exception):
    """Exception raised for errors during tool calls.

    Carries two messages: the full error (for logging/tracing) and a
    separate LLM-facing message that is surfaced back to the model in
    the tool response instead of terminating the flow.
    """

    def __init__(self, message: str, llm_facing_message: str) -> None:
        # The full error message, used for tracing/logging.
        super().__init__(message)
        # Failed LLM-made tool calls are acceptable and not flow-terminating;
        # this is the message that will populate the tool response.
        self.llm_facing_message = llm_facing_message
class SearchToolUsage(str, Enum):
DISABLED = "disabled"
ENABLED = "enabled"
@@ -77,6 +88,11 @@ class ToolResponse(BaseModel):
tool_call: ToolCallKickoff | None = None
class ParallelToolCallResponse(BaseModel):
    """Aggregate result of running a batch of tool calls in parallel."""

    # Responses for the successfully dispatched tool calls (each with its
    # originating tool_call attached).
    tool_responses: list[ToolResponse]
    # Citation number -> URL mapping, updated with any new citations
    # produced by search tool responses.
    updated_citation_mapping: dict[int, str]
class ToolRunnerResponse(BaseModel):
tool_run_kickoff: ToolCallKickoff | None = None
tool_response: ToolResponse | None = None

View File

@@ -34,6 +34,9 @@ from onyx.tools.tool_implementations.open_url.url_normalization import (
_default_url_normalizer,
)
from onyx.tools.tool_implementations.open_url.url_normalization import normalize_url
from onyx.tools.tool_implementations.open_url.utils import (
filter_web_contents_with_no_title_or_content,
)
from onyx.tools.tool_implementations.web_search.providers import (
get_default_content_provider,
)
@@ -520,6 +523,11 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
)
return ToolResponse(rich_response=None, llm_facing_response=failure_msg)
for section in inference_sections:
chunk = section.center_chunk
if not chunk.semantic_identifier and chunk.source_links:
chunk.semantic_identifier = chunk.source_links[0]
# Convert sections to search docs, preserving source information
search_docs = convert_inference_sections_to_search_docs(
inference_sections, is_internet=False
@@ -766,15 +774,23 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
if not urls:
return [], []
web_contents = self._provider.contents(urls)
raw_web_contents = self._provider.contents(urls)
# Treat "no title and no content" as a failure for that URL, but don't
# include the empty entry in downstream prompting/sections.
failed_urls: list[str] = [
content.link
for content in raw_web_contents
if not content.title.strip() and not content.full_content.strip()
]
web_contents = filter_web_contents_with_no_title_or_content(raw_web_contents)
sections: list[InferenceSection] = []
failed_urls: list[str] = []
for content in web_contents:
# Check if content is insufficient (e.g., "Loading..." or too short)
text_stripped = content.full_content.strip()
is_insufficient = (
not text_stripped
# TODO: Likely a behavior of our scraper, understand why this special pattern occurs
or text_stripped.lower() == "loading..."
or len(text_stripped) < 50
)
@@ -786,6 +802,9 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
):
sections.append(inference_section_from_internet_page_scrape(content))
else:
# TODO: Slight improvement - if failed URL reasons are passed back to the LLM
# for example, if it tries to crawl Reddit and fails, it should know (probably) that this error would
# happen again if it tried to crawl Reddit again.
failed_urls.append(content.link or "")
return sections, failed_urls

View File

@@ -0,0 +1,17 @@
from onyx.tools.tool_implementations.open_url.models import WebContent
def filter_web_contents_with_no_title_or_content(
    contents: list[WebContent],
) -> list[WebContent]:
    """Drop content entries that carry neither a title nor any extracted text.

    Some content providers can return placeholder/partial entries that only
    include a URL. Downstream code relies on these fields for display and
    prompting, so empty entries are removed here centrally rather than
    duplicating the check across provider clients.
    """
    return [
        content
        for content in contents
        if content.title.strip() or content.full_content.strip()
    ]

View File

@@ -252,14 +252,14 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Store session factory instead of session for thread-safety
# When tools are called in parallel, each thread needs its own session
# TODO ensure this works!!!
self._session_bind = db_session.get_bind()
self._session_factory = sessionmaker(bind=self._session_bind)
self._id = tool_id
def _get_thread_safe_session(self) -> Session:
"""Create a new database session for the current thread.
"""Create a new database session for the current thread. Note this is only safe for the ORM caches/identity maps,
pending objects, flush state, etc. But it is still using the same underlying database connection.
This ensures thread-safety when the search tool is called in parallel.
Each parallel execution gets its own isolated database session with

View File

@@ -19,7 +19,6 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
# TODO can probably break this up
class ExaClient(WebSearchProvider, WebContentProvider):
def __init__(self, api_key: str, num_results: int = 10) -> None:
self.exa = Exa(api_key=api_key)
@@ -41,20 +40,25 @@ class ExaClient(WebSearchProvider, WebContentProvider):
num_results=self._num_results,
)
return [
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
author=result.author,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
results: list[WebSearchResult] = []
for result in response.results:
title = (result.title or "").strip()
snippet = (result.highlights[0] if result.highlights else "").strip()
results.append(
WebSearchResult(
title=title,
link=result.url,
snippet=snippet,
author=result.author,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
)
for result in response.results
]
return results
def test_connection(self) -> dict[str, str]:
try:
@@ -93,16 +97,23 @@ class ExaClient(WebSearchProvider, WebContentProvider):
livecrawl="preferred",
)
return [
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
# Exa can return partial/empty content entries; skip those to avoid
# downstream prompt + UI pollution.
contents: list[WebContent] = []
for result in response.results:
title = (result.title or "").strip()
full_content = (result.text or "").strip()
contents.append(
WebContent(
title=title,
link=result.url,
full_content=full_content,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
)
for result in response.results
]
return contents

View File

@@ -47,20 +47,28 @@ class SerperClient(WebSearchProvider, WebContentProvider):
response.raise_for_status()
results = response.json()
organic_results = results["organic"]
organic_results = results.get("organic") or []
organic_results = filter(lambda result: "link" in result, organic_results)
validated_results: list[WebSearchResult] = []
for result in organic_results:
link = (result.get("link") or "").strip()
if not link:
continue
return [
WebSearchResult(
title=result.get("title", ""),
link=result.get("link"),
snippet=result.get("snippet", ""),
author=None,
published_date=None,
title = (result.get("title") or "").strip()
snippet = (result.get("snippet") or "").strip()
validated_results.append(
WebSearchResult(
title=title,
link=link,
snippet=snippet,
author=None,
published_date=None,
)
)
for result in organic_results
]
return validated_results
def test_connection(self) -> dict[str, str]:
try:

View File

@@ -6,6 +6,22 @@ from onyx.tools.tool_implementations.web_search.models import WEB_SEARCH_PREFIX
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
def filter_web_search_results_with_no_title_or_snippet(
    results: list[WebSearchResult],
) -> list[WebSearchResult]:
    """Drop search results that have neither a title nor a snippet.

    Some providers can return entries that only include a URL. Downstream
    code uses titles/snippets for display and prompting, so those empty
    entries are filtered out centrally here (rather than duplicating the
    check in each provider client).
    """
    return [
        result
        for result in results
        if result.title.strip() or result.snippet.strip()
    ]
def truncate_search_result_content(content: str, max_chars: int = 15000) -> str:
"""Truncate search result content to a maximum number of characters"""
if len(content) <= max_chars:

View File

@@ -1,3 +1,4 @@
import json
from typing import Any
from typing import cast
@@ -15,6 +16,7 @@ from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool_implementations.utils import (
@@ -25,6 +27,9 @@ from onyx.tools.tool_implementations.web_search.models import WebSearchResult
from onyx.tools.tool_implementations.web_search.providers import (
build_search_provider_from_config,
)
from onyx.tools.tool_implementations.web_search.utils import (
filter_web_search_results_with_no_title_or_snippet,
)
from onyx.tools.tool_implementations.web_search.utils import (
inference_section_from_internet_search_result,
)
@@ -124,13 +129,28 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
)
)
def _execute_single_search(
def _safe_execute_single_search(
self,
query: str,
provider: Any,
) -> list[WebSearchResult]:
"""Execute a single search query and return results."""
return list(provider.search(query))[:DEFAULT_MAX_RESULTS]
) -> tuple[list[WebSearchResult] | None, str | None]:
"""Execute a single search query and return results with error capture.
Returns:
A tuple of (results, error_message). If successful, error_message is None.
If failed, results is None and error_message contains the error.
"""
try:
raw_results = list(provider.search(query))
filtered_results = filter_web_search_results_with_no_title_or_snippet(
raw_results
)
results = filtered_results[:DEFAULT_MAX_RESULTS]
return (results, None)
except Exception as e:
error_msg = str(e)
logger.warning(f"Web search query '{query}' failed: {error_msg}")
return (None, error_msg)
def run(
self,
@@ -149,22 +169,46 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
)
)
# Perform searches in parallel
# Perform searches in parallel with error capture
functions_with_args = [
(self._execute_single_search, (query, self._provider)) for query in queries
(self._safe_execute_single_search, (query, self._provider))
for query in queries
]
search_results_per_query: list[list[WebSearchResult]] = (
run_functions_tuples_in_parallel(
functions_with_args,
allow_failures=True,
)
search_results_with_errors: list[
tuple[list[WebSearchResult] | None, str | None]
] = run_functions_tuples_in_parallel(
functions_with_args,
allow_failures=False, # Our wrapper handles errors internally
)
# Separate successful results from failures
valid_results: list[list[WebSearchResult]] = []
failed_queries: dict[str, str] = {}
for query, (results, error) in zip(queries, search_results_with_errors):
if error is not None:
failed_queries[query] = error
elif results is not None:
valid_results.append(results)
# Log partial failures but continue if we have at least one success
if failed_queries and valid_results:
logger.warning(
f"Web search partial failure: {len(failed_queries)}/{len(queries)} "
f"queries failed. Failed queries: {json.dumps(failed_queries)}"
)
# If all queries failed, raise ToolCallException with details
if not valid_results:
error_details = json.dumps(failed_queries, indent=2)
raise ToolCallException(
message=f"All web search queries failed: {error_details}",
llm_facing_message=(
f"All web search queries failed. Query failures:\n{error_details}"
),
)
# Interweave top results from each query in round-robin fashion
# Filter out None results from failures
valid_results = [
results for results in search_results_per_query if results is not None
]
all_search_results: list[WebSearchResult] = []
if valid_results:
@@ -191,8 +235,15 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
if not added_any:
break
# This should be a very rare case and is due to not failing loudly enough in the search provider implementation.
if not all_search_results:
raise RuntimeError("No search results found.")
raise ToolCallException(
message="Web search queries succeeded but returned no results",
llm_facing_message=(
"Web search completed but found no results for the given queries. "
"Try rephrasing or using different search terms."
),
)
# Convert search results to InferenceSections with rank-based scoring
inference_sections = [

View File

@@ -11,7 +11,9 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.interface import Tool
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import ParallelToolCallResponse
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
@@ -27,6 +29,7 @@ logger = setup_logger()
QUERIES_FIELD = "queries"
URLS_FIELD = "urls"
GENERIC_TOOL_ERROR_MESSAGE = "Tool failed with error: {error}"
# Mapping of tool name to the field that should be merged when multiple calls exist
MERGEABLE_TOOL_FIELDS: dict[str, str] = {
@@ -91,7 +94,7 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
return merged_calls
def _run_single_tool(
def _safe_run_single_tool(
tool: Tool,
tool_call: ToolCallKickoff,
override_kwargs: Any,
@@ -99,7 +102,18 @@ def _run_single_tool(
"""Execute a single tool and return its response.
This function is designed to be run in parallel via run_functions_tuples_in_parallel.
Exception handling:
- ToolCallException: Expected errors from tool execution (e.g., invalid input,
API failures). Uses the exception's llm_facing_message for LLM consumption.
- Other exceptions: Unexpected errors. Uses a generic error message.
In all cases (success or failure):
- SectionEnd packet is emitted to signal tool completion
- tool_call is set on the response for downstream processing
"""
tool_response: ToolResponse | None = None
with function_span(tool.name) as span_fn:
span_fn.span_data.input = str(tool_call.tool_args)
try:
@@ -109,19 +123,47 @@ def _run_single_tool(
**tool_call.tool_args,
)
span_fn.span_data.output = tool_response.llm_facing_response
except Exception as e:
logger.error(f"Error running tool {tool.name}: {e}")
except ToolCallException as e:
# ToolCallException is an expected error from tool execution
# Use llm_facing_message which is specifically designed for LLM consumption
logger.error(f"Tool call error for {tool.name}: {e}")
tool_response = ToolResponse(
rich_response=None,
llm_facing_response="Tool execution failed with: " + str(e),
llm_facing_response=GENERIC_TOOL_ERROR_MESSAGE.format(
error=e.llm_facing_message
),
)
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
message="Tool call error (expected)",
data={
"tool_name": tool.name,
"tool_call_id": tool_call.tool_call_id,
"tool_args": tool_call.tool_args,
"error": str(e),
"llm_facing_message": e.llm_facing_message,
"stack_trace": traceback.format_exc(),
"error_type": "ToolCallException",
},
)
)
except Exception as e:
# Unexpected error during tool execution
logger.error(f"Unexpected error running tool {tool.name}: {e}")
tool_response = ToolResponse(
rich_response=None,
llm_facing_response=GENERIC_TOOL_ERROR_MESSAGE.format(error=str(e)),
)
_error_tracing.attach_error_to_current_span(
SpanError(
message="Tool execution error (unexpected)",
data={
"tool_name": tool.name,
"tool_call_id": tool_call.tool_call_id,
"tool_args": tool_call.tool_args,
"error": str(e),
"stack_trace": traceback.format_exc(),
"error_type": type(e).__name__,
},
)
)
@@ -153,35 +195,52 @@ def run_tool_calls(
max_concurrent_tools: int | None = None,
# Skip query expansion for repeat search tool calls
skip_search_query_expansion: bool = False,
) -> tuple[list[ToolResponse], dict[int, str]]:
"""Run multiple tool calls in parallel and update citation mappings.
) -> ParallelToolCallResponse:
"""Run (optionally merged) tool calls in parallel and update citation mappings.
Merges tool calls for SearchTool, WebSearchTool, and OpenURLTool before execution.
All tools are executed in parallel, and citation mappings are updated
from search tool responses.
Before execution, tool calls for `SearchTool`, `WebSearchTool`, and `OpenURLTool`
are merged so repeated calls are collapsed into a single call per tool:
- `SearchTool` / `WebSearchTool`: merge the `queries` list
- `OpenURLTool`: merge the `urls` list
Tools are executed in parallel (threadpool). For tools that generate citations,
each tool call is assigned a **distinct** `starting_citation_num` range to avoid
citation number collisions when running concurrently (the range is advanced by
100 per tool call).
The provided `citation_mapping` may be mutated in-place: any new
`SearchDocsResponse.citation_mapping` entries are merged into it.
Args:
tool_calls: List of tool calls to execute
tools: List of available tools
message_history: Chat message history for context
memories: User memories, if available
user_info: User information string, if available
citation_mapping: Current citation number to URL mapping
next_citation_num: Next citation number to use
tool_calls: List of tool calls to execute.
tools: List of available tool instances.
message_history: Chat message history (used to find the most recent user query
for `SearchTool` override kwargs).
memories: User memories, if available (passed through to `SearchTool`).
user_info: User information string, if available (passed through to `SearchTool`).
citation_mapping: Current citation number to URL mapping. May be updated with
new citations produced by search tools.
next_citation_num: The next citation number to allocate from.
max_concurrent_tools: Max number of tools to run in this batch. If set, any
tool calls after this limit are dropped (not queued).
skip_search_query_expansion: Whether to skip query expansion for search tools
skip_search_query_expansion: Whether to skip query expansion for `SearchTool`
(intended for repeated search calls within the same chat turn).
Returns:
A tuple containing:
- List of ToolResponse objects (each with tool_call set)
- Updated citation mapping dictionary
A `ParallelToolCallResponse` containing:
- `tool_responses`: `ToolResponse` objects for successfully dispatched tool calls
(each has `tool_call` set). If a tool execution fails at the threadpool layer,
its entry will be omitted.
- `updated_citation_mapping`: The updated citation mapping dictionary.
"""
# Merge tool calls for SearchTool and WebSearchTool
# Merge tool calls for SearchTool, WebSearchTool, and OpenURLTool
merged_tool_calls = _merge_tool_calls(tool_calls)
if not merged_tool_calls:
return [], citation_mapping
return ParallelToolCallResponse(
tool_responses=[],
updated_citation_mapping=citation_mapping,
)
tools_by_name = {tool.name: tool for tool in tools}
@@ -196,7 +255,10 @@ def run_tool_calls(
# Apply safety cap (drop tool calls beyond the cap)
if max_concurrent_tools is not None:
if max_concurrent_tools <= 0:
return [], citation_mapping
return ParallelToolCallResponse(
tool_responses=[],
updated_citation_mapping=citation_mapping,
)
filtered_tool_calls = filtered_tool_calls[:max_concurrent_tools]
# Get starting citation number from citation processor to avoid conflicts with project files
@@ -269,24 +331,29 @@ def run_tool_calls(
# Run all tools in parallel
functions_with_args = [
(_run_single_tool, (tool, tool_call, override_kwargs))
(_safe_run_single_tool, (tool, tool_call, override_kwargs))
for tool, tool_call, override_kwargs in tool_run_params
]
tool_responses: list[ToolResponse] = run_functions_tuples_in_parallel(
tool_run_results: list[ToolResponse | None] = run_functions_tuples_in_parallel(
functions_with_args,
allow_failures=True, # Continue even if some tools fail
max_workers=max_concurrent_tools,
)
# Process results and update citation_mapping
for tool_response in tool_responses:
if tool_response and isinstance(
tool_response.rich_response, SearchDocsResponse
):
new_citations = tool_response.rich_response.citation_mapping
for result in tool_run_results:
if result is None:
continue
if result and isinstance(result.rich_response, SearchDocsResponse):
new_citations = result.rich_response.citation_mapping
if new_citations:
# Merge new citations into the existing mapping
citation_mapping.update(new_citations)
return tool_responses, citation_mapping
tool_responses = [result for result in tool_run_results if result is not None]
return ParallelToolCallResponse(
tool_responses=tool_responses,
updated_citation_mapping=citation_mapping,
)

View File

@@ -34,6 +34,7 @@ from scripts.tenant_cleanup.cleanup_utils import execute_control_plane_query
from scripts.tenant_cleanup.cleanup_utils import find_worker_pod
from scripts.tenant_cleanup.cleanup_utils import get_tenant_status
from scripts.tenant_cleanup.cleanup_utils import read_tenant_ids_from_csv
from scripts.tenant_cleanup.cleanup_utils import TenantNotFoundInControlPlaneError
def signal_handler(signum: int, frame: object) -> None:
@@ -418,6 +419,9 @@ def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
"""
print(f"Starting cleanup for tenant: {tenant_id}")
# Track if tenant was not found in control plane (for force mode)
tenant_not_found_in_control_plane = False
# Check tenant status first
print(f"\n{'=' * 80}")
try:
@@ -457,8 +461,25 @@ def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
if response.lower() != "yes":
print("Cleanup aborted - could not verify tenant status")
return False
except TenantNotFoundInControlPlaneError as e:
# Tenant/table not found in control plane
error_str = str(e)
print(f"⚠️ WARNING: Tenant not found in control plane: {error_str}")
tenant_not_found_in_control_plane = True
if force:
print(
"[FORCE MODE] Tenant not found in control plane - continuing with dataplane cleanup only"
)
else:
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
print("Cleanup aborted - tenant not found in control plane")
return False
except Exception as e:
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
# Other errors (not "not found")
error_str = str(e)
print(f"⚠️ WARNING: Failed to check tenant status: {error_str}")
if force:
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
@@ -516,8 +537,14 @@ def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
else:
print("Step 2 skipped by user")
# Step 3: Clean up control plane
if confirm_step(
# Step 3: Clean up control plane (skip if tenant not found in control plane with --force)
if tenant_not_found_in_control_plane:
print(f"\n{'=' * 80}")
print(
"Step 3/3: Skipping control plane cleanup (tenant not found in control plane)"
)
print(f"{'=' * 80}\n")
elif confirm_step(
"Step 3/3: Delete control plane records (tenant_notification, tenant_config, subscription, tenant)",
force,
):

View File

@@ -7,6 +7,10 @@ from dataclasses import dataclass
from pathlib import Path
class TenantNotFoundInControlPlaneError(Exception):
"""Exception raised when tenant/table is not found in control plane."""
@dataclass
class ControlPlaneConfig:
"""Configuration for connecting to the control plane database."""
@@ -136,6 +140,9 @@ def get_tenant_status(tenant_id: str) -> str | None:
Returns:
Tenant status string (e.g., 'GATED_ACCESS', 'ACTIVE') or None if not found
Raises:
TenantNotFoundInControlPlaneError: If the tenant table/relation does not exist
"""
print(f"Fetching tenant status for tenant: {tenant_id}")
@@ -152,15 +159,18 @@ def get_tenant_status(tenant_id: str) -> str | None:
return status
else:
print("⚠ Tenant not found in control plane")
return None
raise TenantNotFoundInControlPlaneError(
f"Tenant {tenant_id} not found in control plane database"
)
except TenantNotFoundInControlPlaneError:
# Re-raise without wrapping
raise
except subprocess.CalledProcessError as e:
error_msg = e.stderr if e.stderr else str(e)
print(
f"✗ Failed to get tenant status for {tenant_id}: {e}",
f"✗ Failed to get tenant status for {tenant_id}: {error_msg}",
file=sys.stderr,
)
if e.stderr:
print(f" Error details: {e.stderr}", file=sys.stderr)
return None

View File

@@ -5,10 +5,9 @@ All queries run directly from pods.
Supports two-cluster architecture (data plane and control plane in separate clusters).
Usage:
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py <tenant_id> [--force]
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py --csv <csv_file_path> [--force]
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py <tenant_id> \
--data-plane-context <context> --control-plane-context <context> [--force]
With explicit contexts:
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py --csv <csv_file_path> \
--data-plane-context <context> --control-plane-context <context> [--force]
"""
@@ -30,6 +29,10 @@ from scripts.tenant_cleanup.no_bastion_cleanup_utils import find_background_pod
from scripts.tenant_cleanup.no_bastion_cleanup_utils import find_worker_pod
from scripts.tenant_cleanup.no_bastion_cleanup_utils import get_tenant_status
from scripts.tenant_cleanup.no_bastion_cleanup_utils import read_tenant_ids_from_csv
from scripts.tenant_cleanup.no_bastion_cleanup_utils import (
TenantNotFoundInControlPlaneError,
)
# Global lock for thread-safe operations
_print_lock: Lock = Lock()
@@ -41,12 +44,12 @@ def signal_handler(signum: int, frame: object) -> None:
sys.exit(1)
def setup_scripts_on_pod(pod_name: str, context: str | None = None) -> None:
def setup_scripts_on_pod(pod_name: str, context: str) -> None:
"""Copy all required scripts to the pod once at the beginning.
Args:
pod_name: Pod to copy scripts to
context: Optional kubectl context
context: kubectl context for the cluster
"""
print("Setting up scripts on pod (one-time operation)...")
@@ -66,9 +69,7 @@ def setup_scripts_on_pod(pod_name: str, context: str | None = None) -> None:
if not local_file.exists():
raise FileNotFoundError(f"Script not found: {local_file}")
cmd_cp = ["kubectl", "cp"]
if context:
cmd_cp.extend(["--context", context])
cmd_cp = ["kubectl", "cp", "--context", context]
cmd_cp.extend([str(local_file), f"{pod_name}:{remote_path}"])
subprocess.run(cmd_cp, check=True, capture_output=True)
@@ -76,15 +77,13 @@ def setup_scripts_on_pod(pod_name: str, context: str | None = None) -> None:
print("✓ All scripts copied to pod")
def get_tenant_index_name(
pod_name: str, tenant_id: str, context: str | None = None
) -> str:
def get_tenant_index_name(pod_name: str, tenant_id: str, context: str) -> str:
"""Get the default index name for the given tenant by running script on pod.
Args:
pod_name: Data plane pod to execute on
tenant_id: Tenant ID to process
context: Optional kubectl context for data plane cluster
context: kubectl context for data plane cluster
"""
print(f"Getting default index name for tenant: {tenant_id}")
@@ -100,9 +99,7 @@ def get_tenant_index_name(
try:
# Copy script to pod
print(" Copying script to pod...")
cmd_cp = ["kubectl", "cp"]
if context:
cmd_cp.extend(["--context", context])
cmd_cp = ["kubectl", "cp", "--context", context]
cmd_cp.extend(
[
str(index_name_script),
@@ -118,12 +115,9 @@ def get_tenant_index_name(
# Execute script on pod
print(" Executing script on pod...")
cmd_exec = ["kubectl", "exec"]
if context:
cmd_exec.extend(["--context", context])
cmd_exec = ["kubectl", "exec", "--context", context, pod_name]
cmd_exec.extend(
[
pod_name,
"--",
"python",
"/tmp/get_tenant_index_name.py",
@@ -168,25 +162,20 @@ def get_tenant_index_name(
raise
def get_tenant_users(
pod_name: str, tenant_id: str, context: str | None = None
) -> list[str]:
def get_tenant_users(pod_name: str, tenant_id: str, context: str) -> list[str]:
"""Get list of user emails from the tenant's data plane schema.
Args:
pod_name: Data plane pod to execute on
tenant_id: Tenant ID to process
context: Optional kubectl context for data plane cluster
context: kubectl context for data plane cluster
"""
# Script is already on pod from setup_scripts_on_pod()
try:
# Execute script on pod
cmd_exec = ["kubectl", "exec"]
if context:
cmd_exec.extend(["--context", context])
cmd_exec = ["kubectl", "exec", "--context", context, pod_name]
cmd_exec.extend(
[
pod_name,
"--",
"python",
"/tmp/get_tenant_users.py",
@@ -233,25 +222,20 @@ def get_tenant_users(
return []
def check_documents_deleted(
pod_name: str, tenant_id: str, context: str | None = None
) -> None:
def check_documents_deleted(pod_name: str, tenant_id: str, context: str) -> None:
"""Check if all documents and connector credential pairs have been deleted.
Args:
pod_name: Data plane pod to execute on
tenant_id: Tenant ID to process
context: Optional kubectl context for data plane cluster
context: kubectl context for data plane cluster
"""
# Script is already on pod from setup_scripts_on_pod()
try:
# Execute script on pod
cmd_exec = ["kubectl", "exec"]
if context:
cmd_exec.extend(["--context", context])
cmd_exec = ["kubectl", "exec", "--context", context, pod_name]
cmd_exec.extend(
[
pod_name,
"--",
"python",
"/tmp/check_documents_deleted.py",
@@ -305,25 +289,20 @@ def check_documents_deleted(
raise
def drop_data_plane_schema(
pod_name: str, tenant_id: str, context: str | None = None
) -> None:
def drop_data_plane_schema(pod_name: str, tenant_id: str, context: str) -> None:
"""Drop the PostgreSQL schema for the given tenant by running script on pod.
Args:
pod_name: Data plane pod to execute on
tenant_id: Tenant ID to process
context: Optional kubectl context for data plane cluster
context: kubectl context for data plane cluster
"""
# Script is already on pod from setup_scripts_on_pod()
try:
# Execute script on pod
cmd_exec = ["kubectl", "exec"]
if context:
cmd_exec.extend(["--context", context])
cmd_exec = ["kubectl", "exec", "--context", context, pod_name]
cmd_exec.extend(
[
pod_name,
"--",
"python",
"/tmp/cleanup_tenant_schema.py",
@@ -366,14 +345,14 @@ def drop_data_plane_schema(
def cleanup_control_plane(
pod_name: str, tenant_id: str, context: str | None = None, force: bool = False
pod_name: str, tenant_id: str, context: str, force: bool = False
) -> None:
"""Clean up control plane data via pod queries.
Args:
pod_name: Control plane pod to execute on
tenant_id: Tenant ID to process
context: Optional kubectl context for control plane cluster
context: kubectl context for control plane cluster
force: Skip confirmations if True
"""
print(f"Cleaning up control plane data for tenant: {tenant_id}")
@@ -413,8 +392,8 @@ def cleanup_tenant(
tenant_id: str,
data_plane_pod: str,
control_plane_pod: str,
data_plane_context: str | None = None,
control_plane_context: str | None = None,
data_plane_context: str,
control_plane_context: str,
force: bool = False,
) -> bool:
"""Main cleanup function that orchestrates all cleanup steps.
@@ -423,12 +402,15 @@ def cleanup_tenant(
tenant_id: Tenant ID to process
data_plane_pod: Data plane pod for schema operations
control_plane_pod: Control plane pod for tenant record operations
data_plane_context: Optional kubectl context for data plane cluster
control_plane_context: Optional kubectl context for control plane cluster
data_plane_context: kubectl context for data plane cluster
control_plane_context: kubectl context for control plane cluster
force: Skip confirmations if True
"""
print(f"Starting cleanup for tenant: {tenant_id}")
# Track if tenant was not found in control plane (for force mode)
tenant_not_found_in_control_plane = False
# Check tenant status first (from control plane)
print(f"\n{'=' * 80}")
try:
@@ -470,8 +452,25 @@ def cleanup_tenant(
if response.lower() != "yes":
print("Cleanup aborted - could not verify tenant status")
return False
except TenantNotFoundInControlPlaneError as e:
# Tenant/table not found in control plane
error_str = str(e)
print(f"⚠️ WARNING: Tenant not found in control plane: {error_str}")
tenant_not_found_in_control_plane = True
if force:
print(
"[FORCE MODE] Tenant not found in control plane - continuing with dataplane cleanup only"
)
else:
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
print("Cleanup aborted - tenant not found in control plane")
return False
except Exception as e:
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
# Other errors (not "not found")
error_str = str(e)
print(f"⚠️ WARNING: Failed to check tenant status: {error_str}")
if force:
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
@@ -528,8 +527,14 @@ def cleanup_tenant(
else:
print("Step 2 skipped by user")
# Step 3: Clean up control plane
if confirm_step(
# Step 3: Clean up control plane (skip if tenant not found in control plane with --force)
if tenant_not_found_in_control_plane:
print(f"\n{'=' * 80}")
print(
"Step 3/3: Skipping control plane cleanup (tenant not found in control plane)"
)
print(f"{'=' * 80}\n")
elif confirm_step(
"Step 3/3: Delete control plane records (tenant_notification, tenant_config, subscription, tenant)",
force,
):
@@ -560,12 +565,11 @@ def main() -> None:
if len(sys.argv) < 2:
print(
"Usage: PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py <tenant_id> [--force]"
"Usage: PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py <tenant_id> \\"
)
print(
" PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py --csv <csv_file_path> [--force]"
" --data-plane-context <context> --control-plane-context <context> [--force]"
)
print("\nTwo-cluster architecture (with explicit contexts):")
print(
" PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py --csv <csv_file_path> \\"
)
@@ -575,20 +579,20 @@ def main() -> None:
print("\nThis version runs ALL operations from pods (no bastion required)")
print("\nArguments:")
print(
" tenant_id The tenant ID to clean up (required if not using --csv)"
" tenant_id The tenant ID to clean up (required if not using --csv)"
)
print(
" --csv PATH Path to CSV file containing tenant IDs to clean up"
" --csv PATH Path to CSV file containing tenant IDs to clean up"
)
print(" --force Skip all confirmation prompts (optional)")
print(" --force Skip all confirmation prompts (optional)")
print(
" --concurrency N Process N tenants concurrently (default: 1)"
" --concurrency N Process N tenants concurrently (default: 1)"
)
print(
" --data-plane-context CTX Kubectl context for data plane cluster (optional)"
" --data-plane-context CTX Kubectl context for data plane cluster (required)"
)
print(
" --control-plane-context CTX Kubectl context for control plane cluster (optional)"
" --control-plane-context CTX Kubectl context for control plane cluster (required)"
)
sys.exit(1)
@@ -620,7 +624,7 @@ def main() -> None:
)
sys.exit(1)
# Parse contexts
# Parse contexts (required)
data_plane_context: str | None = None
control_plane_context: str | None = None
@@ -650,6 +654,21 @@ def main() -> None:
except ValueError:
pass
# Validate required contexts
if not data_plane_context:
print(
"Error: --data-plane-context is required",
file=sys.stderr,
)
sys.exit(1)
if not control_plane_context:
print(
"Error: --control-plane-context is required",
file=sys.stderr,
)
sys.exit(1)
# Check for CSV mode
if "--csv" in sys.argv:
try:

View File

@@ -10,19 +10,19 @@ import sys
from pathlib import Path
def find_worker_pod(context: str | None = None) -> str:
class TenantNotFoundInControlPlaneError(Exception):
"""Exception raised when tenant/table is not found in control plane."""
def find_worker_pod(context: str) -> str:
"""Find a user file processing worker pod using kubectl.
Args:
context: Optional kubectl context to use
context: kubectl context to use
"""
print(
f"Finding user file processing worker pod{f' in context {context}' if context else ''}..."
)
print(f"Finding user file processing worker pod in context {context}...")
cmd = ["kubectl", "get", "po"]
if context:
cmd.extend(["--context", context])
cmd = ["kubectl", "get", "po", "--context", context]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
@@ -43,17 +43,15 @@ def find_worker_pod(context: str | None = None) -> str:
raise RuntimeError("No running user file processing worker pod found")
def find_background_pod(context: str | None = None) -> str:
"""Find a background/api-server pod for control plane operations.
def find_background_pod(context: str) -> str:
"""Find a pod for control plane operations.
Args:
context: Optional kubectl context to use
context: kubectl context to use
"""
print(f"Finding background/api pod{f' in context {context}' if context else ''}...")
print(f"Finding control plane pod in context {context}...")
cmd = ["kubectl", "get", "po"]
if context:
cmd.extend(["--context", context])
cmd = ["kubectl", "get", "po", "--context", context]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
@@ -65,16 +63,15 @@ def find_background_pod(context: str | None = None) -> str:
random.shuffle(lines)
# Try to find api-server, background worker, or any celery worker
# Try to find control plane pods
for line in lines:
if (
any(
name in line
for name in [
"api-server",
"celery-worker-light",
"celery-worker-primary",
"background",
"background-processing-deployment",
"subscription-deployment",
"tenants-deployment",
]
)
and "Running" in line
@@ -106,20 +103,23 @@ def confirm_step(message: str, force: bool = False) -> bool:
def execute_control_plane_query_from_pod(
pod_name: str, query: str, context: str | None = None
pod_name: str, query: str, context: str
) -> dict:
"""Execute a SQL query against control plane database from within a pod.
Args:
pod_name: The Kubernetes pod name to execute from
query: The SQL query to execute
context: Optional kubectl context for control plane cluster
context: kubectl context for control plane cluster
Returns:
Dict with 'success' bool, 'stdout' str, and optional 'error' str
"""
# Create a Python script to run the query
# This script tries multiple environment variable patterns
# NOTE: whuang 01/08/2026: POSTGRES_CONTROL_* don't exist. This uses pattern 2 currently.
query_script = f'''
import os
from sqlalchemy import create_engine, text
@@ -175,9 +175,7 @@ with engine.connect() as conn:
script_path = "/tmp/control_plane_query.py"
try:
cmd_write = ["kubectl", "exec", pod_name]
if context:
cmd_write.extend(["--context", context])
cmd_write = ["kubectl", "exec", "--context", context, pod_name]
cmd_write.extend(
[
"--",
@@ -194,9 +192,7 @@ with engine.connect() as conn:
)
# Execute the script
cmd_exec = ["kubectl", "exec", pod_name]
if context:
cmd_exec.extend(["--context", context])
cmd_exec = ["kubectl", "exec", "--context", context, pod_name]
cmd_exec.extend(["--", "python", script_path])
result = subprocess.run(
@@ -220,19 +216,20 @@ with engine.connect() as conn:
}
def get_tenant_status(
pod_name: str, tenant_id: str, context: str | None = None
) -> str | None:
def get_tenant_status(pod_name: str, tenant_id: str, context: str) -> str | None:
"""
Get tenant status from control plane database via pod.
Args:
pod_name: The pod to execute the query from
tenant_id: The tenant ID to look up
context: Optional kubectl context for control plane cluster
context: kubectl context for control plane cluster
Returns:
Tenant status string (e.g., 'GATED_ACCESS', 'ACTIVE') or None if not found
Raises:
TenantNotFoundInControlPlaneError: If the tenant record is not found in the table
"""
print(f"Fetching tenant status for tenant: {tenant_id}")
@@ -241,8 +238,9 @@ def get_tenant_status(
result = execute_control_plane_query_from_pod(pod_name, query, context)
if not result["success"]:
error_msg = result.get("error", "Unknown error")
print(
f"✗ Failed to get tenant status for {tenant_id}: {result.get('error', 'Unknown error')}",
f"✗ Failed to get tenant status for {tenant_id}: {error_msg}",
file=sys.stderr,
)
return None
@@ -257,23 +255,27 @@ def get_tenant_status(
print(f"✓ Tenant status: {status}")
return status
# Tenant record not found in control plane table
print("⚠ Tenant not found in control plane")
return None
raise TenantNotFoundInControlPlaneError(
f"Tenant {tenant_id} not found in control plane database"
)
except TenantNotFoundInControlPlaneError:
# Re-raise without wrapping
raise
except (json.JSONDecodeError, KeyError, IndexError) as e:
print(f"✗ Failed to parse tenant status: {e}", file=sys.stderr)
return None
def execute_control_plane_delete(
pod_name: str, query: str, context: str | None = None
) -> bool:
def execute_control_plane_delete(pod_name: str, query: str, context: str) -> bool:
"""Execute a DELETE query against control plane database from pod.
Args:
pod_name: The pod to execute the query from
query: The DELETE query to execute
context: Optional kubectl context for control plane cluster
context: kubectl context for control plane cluster
Returns:
True if successful, False otherwise

View File

@@ -5,10 +5,9 @@ All queries run directly from pods.
Supports two-cluster architecture (data plane and control plane in separate clusters).
Usage:
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_mark_connectors.py <tenant_id> [--force]
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_mark_connectors.py --csv <csv_file_path> [--force] [--concurrency N]
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_mark_connectors.py <tenant_id> \
--data-plane-context <context> --control-plane-context <context> [--force]
With explicit contexts:
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_mark_connectors.py --csv <csv_file_path> \
--data-plane-context <context> --control-plane-context <context> [--force] [--concurrency N]
"""
@@ -26,6 +25,9 @@ from scripts.tenant_cleanup.no_bastion_cleanup_utils import find_background_pod
from scripts.tenant_cleanup.no_bastion_cleanup_utils import find_worker_pod
from scripts.tenant_cleanup.no_bastion_cleanup_utils import get_tenant_status
from scripts.tenant_cleanup.no_bastion_cleanup_utils import read_tenant_ids_from_csv
from scripts.tenant_cleanup.no_bastion_cleanup_utils import (
TenantNotFoundInControlPlaneError,
)
# Global lock for thread-safe printing
_print_lock: Lock = Lock()
@@ -37,15 +39,13 @@ def safe_print(*args: Any, **kwargs: Any) -> None:
print(*args, **kwargs)
def run_connector_deletion(
pod_name: str, tenant_id: str, context: str | None = None
) -> None:
def run_connector_deletion(pod_name: str, tenant_id: str, context: str) -> None:
"""Mark all connector credential pairs for deletion.
Args:
pod_name: Data plane pod to execute deletion on
tenant_id: Tenant ID to process
context: Optional kubectl context for data plane cluster
context: kubectl context for data plane cluster
"""
safe_print(" Marking all connector credential pairs for deletion...")
@@ -62,9 +62,7 @@ def run_connector_deletion(
try:
# Copy script to pod
cmd_cp = ["kubectl", "cp"]
if context:
cmd_cp.extend(["--context", context])
cmd_cp = ["kubectl", "cp", "--context", context]
cmd_cp.extend(
[
str(mark_deletion_script),
@@ -79,12 +77,9 @@ def run_connector_deletion(
)
# Execute script on pod
cmd_exec = ["kubectl", "exec"]
if context:
cmd_exec.extend(["--context", context])
cmd_exec = ["kubectl", "exec", "--context", context, pod_name]
cmd_exec.extend(
[
pod_name,
"--",
"python",
"/tmp/execute_connector_deletion.py",
@@ -118,8 +113,8 @@ def mark_tenant_connectors_for_deletion(
tenant_id: str,
data_plane_pod: str,
control_plane_pod: str,
data_plane_context: str | None = None,
control_plane_context: str | None = None,
data_plane_context: str,
control_plane_context: str,
force: bool = False,
) -> None:
"""Main function to mark all connectors for a tenant for deletion.
@@ -128,8 +123,8 @@ def mark_tenant_connectors_for_deletion(
tenant_id: Tenant ID to process
data_plane_pod: Data plane pod for connector operations
control_plane_pod: Control plane pod for status checks
data_plane_context: Optional kubectl context for data plane cluster
control_plane_context: Optional kubectl context for control plane cluster
data_plane_context: kubectl context for data plane cluster
control_plane_context: kubectl context for control plane cluster
force: Skip confirmations if True
"""
safe_print(f"Processing connectors for tenant: {tenant_id}")
@@ -174,6 +169,23 @@ def mark_tenant_connectors_for_deletion(
)
else:
raise RuntimeError(f"Could not verify tenant status for {tenant_id}")
except TenantNotFoundInControlPlaneError as e:
# Tenant/table not found in control plane
error_str = str(e)
safe_print(f"⚠️ WARNING: Tenant not found in control plane: {error_str}")
if force:
safe_print(
"[FORCE MODE] Tenant not found in control plane - continuing with connector deletion anyway"
)
else:
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
safe_print("Operation aborted - tenant not found in control plane")
raise RuntimeError(f"Tenant {tenant_id} not found in control plane")
except RuntimeError:
# Re-raise RuntimeError (from status checks above) without wrapping
raise
except Exception as e:
safe_print(f"⚠️ WARNING: Failed to check tenant status: {e}")
if not force:
@@ -205,16 +217,14 @@ def main() -> None:
if len(sys.argv) < 2:
print(
"Usage: PYTHONPATH=. python scripts/tenant_cleanup/"
"no_bastion_mark_connectors.py <tenant_id> [--force] [--concurrency N]"
"no_bastion_mark_connectors.py <tenant_id> \\"
)
print(
" --data-plane-context <context> --control-plane-context <context> [--force]"
)
print(
" PYTHONPATH=. python scripts/tenant_cleanup/"
"no_bastion_mark_connectors.py --csv <csv_file_path> "
"[--force] [--concurrency N]"
)
print("\nTwo-cluster architecture (with explicit contexts):")
print(
" PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_mark_connectors.py --csv <csv_file_path> \\"
"no_bastion_mark_connectors.py --csv <csv_file_path> \\"
)
print(
" --data-plane-context <context> --control-plane-context <context> [--force] [--concurrency N]"
@@ -222,20 +232,20 @@ def main() -> None:
print("\nThis version runs ALL operations from pods (no bastion required)")
print("\nArguments:")
print(
" tenant_id The tenant ID to process (required if not using --csv)"
" tenant_id The tenant ID to process (required if not using --csv)"
)
print(
" --csv PATH Path to CSV file containing tenant IDs to process"
" --csv PATH Path to CSV file containing tenant IDs to process"
)
print(" --force Skip all confirmation prompts (optional)")
print(" --force Skip all confirmation prompts (optional)")
print(
" --concurrency N Process N tenants concurrently (default: 1)"
" --concurrency N Process N tenants concurrently (default: 1)"
)
print(
" --data-plane-context CTX Kubectl context for data plane cluster (optional)"
" --data-plane-context CTX Kubectl context for data plane cluster (required)"
)
print(
" --control-plane-context CTX Kubectl context for control plane cluster (optional)"
" --control-plane-context CTX Kubectl context for control plane cluster (required)"
)
sys.exit(1)
@@ -243,7 +253,7 @@ def main() -> None:
force = "--force" in sys.argv
tenant_ids: list[str] = []
# Parse contexts
# Parse contexts (required)
data_plane_context: str | None = None
control_plane_context: str | None = None
@@ -273,6 +283,21 @@ def main() -> None:
except ValueError:
pass
# Validate required contexts
if not data_plane_context:
print(
"Error: --data-plane-context is required",
file=sys.stderr,
)
sys.exit(1)
if not control_plane_context:
print(
"Error: --control-plane-context is required",
file=sys.stderr,
)
sys.exit(1)
# Parse concurrency
concurrency: int = 1
if "--concurrency" in sys.argv:

View File

@@ -236,10 +236,10 @@ USAGE_LIMIT_LLM_COST_CENTS_PAID = int(
# Per-week chunks indexed limits
USAGE_LIMIT_CHUNKS_INDEXED_TRIAL = int(
os.environ.get("USAGE_LIMIT_CHUNKS_INDEXED_TRIAL", "10000")
os.environ.get("USAGE_LIMIT_CHUNKS_INDEXED_TRIAL", 100_000)
)
USAGE_LIMIT_CHUNKS_INDEXED_PAID = int(
os.environ.get("USAGE_LIMIT_CHUNKS_INDEXED_PAID", "50000")
os.environ.get("USAGE_LIMIT_CHUNKS_INDEXED_PAID", 1_000_000)
)
# Per-week API calls using API keys or Personal Access Tokens

View File

@@ -397,6 +397,7 @@ def test_anthropic_prompt_caching_reduces_costs(
not os.environ.get(VERTEX_LOCATION_ENV),
reason="VERTEX_LOCATION required for Vertex AI context caching (e.g., 'us-central1')",
)
@pytest.mark.skip(reason="Vertex AI prompt caching is disabled for now")
def test_google_genai_prompt_caching_reduces_costs(
db_session: Session,
) -> None:

View File

@@ -164,6 +164,87 @@ class ChatSessionManager:
return streamed_response
@staticmethod
def send_message_with_disconnect(
chat_session_id: UUID,
message: str,
disconnect_after_packets: int = 0,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
search_doc_ids: list[int] | None = None,
retrieval_options: RetrievalDetails | None = None,
query_override: str | None = None,
regenerate: bool | None = None,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
forced_tool_ids: list[int] | None = None,
) -> None:
"""
Send a message and simulate client disconnect before stream completes.
This is useful for testing how the server handles client disconnections
during streaming responses.
Args:
chat_session_id: The chat session ID
message: The message to send
disconnect_after_packets: Disconnect after receiving this many packets.
If None, disconnect_after_type must be specified.
disconnect_after_type: Disconnect after receiving a packet of this type
(e.g., "message_start", "search_tool_start"). If None,
disconnect_after_packets must be specified.
... (other standard message parameters)
Returns:
StreamedResponse containing data received before disconnect,
with is_disconnected=True flag set.
"""
chat_message_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
message=message,
file_descriptors=file_descriptors or [],
search_doc_ids=search_doc_ids or [],
retrieval_options=retrieval_options,
rerank_settings=None,
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
forced_tool_ids=forced_tool_ids,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
packets_received = 0
with requests.post(
f"{API_SERVER_URL}/chat/send-message",
json=chat_message_req.model_dump(),
headers=headers,
stream=True,
cookies=cookies,
) as response:
for line in response.iter_lines():
if not line:
continue
packets_received += 1
if packets_received > disconnect_after_packets:
break
return None
@staticmethod
def analyze_response(response: Response) -> StreamedResponse:
response_data = cast(

View File

@@ -18,25 +18,15 @@ import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
# Skip all tests in this module
pytestmark = pytest.mark.skip(reason="Auto LLM update tests temporarily disabled")
# How long to wait for the celery task to run and sync models
# This should be longer than AUTO_LLM_UPDATE_INTERVAL_SECONDS
MAX_WAIT_TIME_SECONDS = 60
MAX_WAIT_TIME_SECONDS = 120
POLL_INTERVAL_SECONDS = 5
@pytest.fixture(scope="module", autouse=True)
def reset_for_module() -> None:
    """Reset all data once before running any tests in this module.

    module scope + autouse: runs exactly once, before the first test in this
    file, without any test needing to request the fixture explicitly.
    """
    reset_all()
def _create_provider_with_api(
admin_user: DATestUser,
name: str,
@@ -142,6 +132,7 @@ def wait_for_model_sync(
def test_auto_mode_provider_gets_synced_from_github_config(
reset: None,
admin_user: DATestUser,
) -> None:
"""
@@ -156,7 +147,7 @@ def test_auto_mode_provider_gets_synced_from_github_config(
# First, get the GitHub config to know what models we should expect
github_config = get_auto_config(admin_user)
if github_config is None:
pytest.skip("GitHub config not found")
pytest.fail("GitHub config not found")
# Get expected models for OpenAI from the config
if "openai" not in github_config.get("providers", {}):
@@ -207,17 +198,26 @@ def test_auto_mode_provider_gets_synced_from_github_config(
)
# Verify the models were synced
synced_model_names = {m["name"] for m in synced_provider["model_configurations"]}
synced_model_configs = synced_provider["model_configurations"]
synced_model_names = {m["name"] for m in synced_model_configs}
print(f"Synced models: {synced_model_names}")
assert expected_models.issubset(
synced_model_names
), f"Expected models {expected_models} not found in synced models {synced_model_names}"
# Verify the outdated model was removed
# Verify the outdated model still exists but is not visible
# (Auto mode marks removed models as not visible, it doesn't delete them)
outdated_model = next(
(m for m in synced_model_configs if m["name"] == "outdated-model-name"),
None,
)
assert (
"outdated-model-name" not in synced_model_names
), "Outdated model should have been removed by sync"
outdated_model is not None
), "Outdated model should still exist after sync (marked invisible, not deleted)"
assert not outdated_model[
"is_visible"
], "Outdated model should not be visible after sync"
# Verify default model was set from GitHub config
expected_default = (
@@ -230,6 +230,7 @@ def test_auto_mode_provider_gets_synced_from_github_config(
def test_manual_mode_provider_not_affected_by_auto_sync(
reset: None,
admin_user: DATestUser,
) -> None:
"""

View File

@@ -1,8 +1,17 @@
import time
from onyx.configs.constants import MessageType
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.conftest import DocumentBuilderType
# Placeholder texts that can appear as an assistant message after a client
# disconnect; the disconnect test below polls until the message no longer
# matches either of these. NOTE(review): these must stay byte-identical to the
# server-side strings — confirm against the backend if they ever change.
TERMINATED_RESPONSE_MESSAGE = (
    "Response was terminated prior to completion, try regenerating."
)
LOADING_RESPONSE_MESSAGE = "Message is loading... Please refresh the page soon."
def test_send_two_messages(basic_user: DATestUser) -> None:
# Create a chat session
@@ -104,3 +113,59 @@ def test_send_message__basic_searches(
# short doc should be more relevant and thus first
assert response.top_documents[0].document_id == short_doc.id
assert response.top_documents[1].document_id == long_doc.id
def test_send_message_disconnect_and_cleanup(
    reset: None, admin_user: DATestUser
) -> None:
    """
    Test that when a client disconnects mid-stream:
    1. Client sends a message and disconnects after receiving just 1 packet
    2. Client checks to see that their message ends up completed

    Note: There is an interim period (between disconnect and checkup) where we
    expect to see some sort of 'loading' message, so we poll until the
    assistant message settles on real content (or the timeout elapses).
    """
    LLMProviderManager.create(user_performing_action=admin_user)
    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)

    # Send a message and disconnect after receiving just 1 packet
    ChatSessionManager.send_message_with_disconnect(
        chat_session_id=test_chat_session.id,
        message="What are some important events that happened today?",
        user_performing_action=admin_user,
        disconnect_after_packets=1,
    )

    # Poll every second, for up to a minute, waiting for the server-side
    # cleanup to replace the placeholder text with a real response.
    # (Comment previously said "every 5 seconds", which did not match the
    # 1-second increment actually used.)
    increment_seconds = 1
    max_seconds = 60
    msg = TERMINATED_RESPONSE_MESSAGE
    for _ in range(max_seconds // increment_seconds):
        time.sleep(increment_seconds)

        # Get the latest chat history for the session
        chat_history = ChatSessionManager.get_chat_history(
            chat_session=test_chat_session,
            user_performing_action=admin_user,
        )

        # Find the first assistant message in the history
        assistant_message = next(
            (
                chat_obj
                for chat_obj in chat_history
                if chat_obj.message_type == MessageType.ASSISTANT
            ),
            None,
        )
        assert assistant_message is not None, "Assistant message should exist"

        msg = assistant_message.message
        if msg != TERMINATED_RESPONSE_MESSAGE and msg != LOADING_RESPONSE_MESSAGE:
            break

    # The condition rejects BOTH placeholders, so the failure message names
    # both (the old message mentioned only the terminated placeholder).
    assert msg != TERMINATED_RESPONSE_MESSAGE and msg != LOADING_RESPONSE_MESSAGE, (
        f"Assistant message should no longer be the terminated or loading "
        f"placeholder message after cleanup, got: {msg}"
    )

View File

@@ -71,10 +71,10 @@ class TestOnyxWebCrawler:
assert response.status_code == 200, response.text
data = response.json()
# Should return a result but with empty content
assert len(data["results"]) == 1
result = data["results"][0]
assert result["content"] == ""
assert data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value
# The API filters out docs with no title/content, so unreachable domains return no results
assert data["results"] == []
def test_handles_404_page(self, admin_user: DATestUser) -> None:
"""Test that the crawler handles 404 responses gracefully."""
@@ -86,8 +86,10 @@ class TestOnyxWebCrawler:
assert response.status_code == 200, response.text
data = response.json()
# Should return a result (possibly with empty content for 404)
assert len(data["results"]) == 1
assert data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value
# Non-200 responses are treated as non-content and filtered out
assert data["results"] == []
def test_https_url_with_path(self, admin_user: DATestUser) -> None:
"""Test that the crawler handles HTTPS URLs with paths correctly."""

View File

@@ -0,0 +1,309 @@
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_runner import _merge_tool_calls
def _make_tool_call(
    tool_name: str,
    tool_args: dict,
    tool_call_id: str = "call_1",
    turn_index: int = 0,
    tab_index: int = 0,
) -> ToolCallKickoff:
    """Build a ToolCallKickoff with sensible test defaults.

    Only the tool name and args are required; the call id and placement
    coordinates default to fixed values so most tests can ignore them.
    """
    placement = Placement(turn_index=turn_index, tab_index=tab_index)
    return ToolCallKickoff(
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        tool_args=tool_args,
        placement=placement,
    )
class TestMergeToolCalls:
    """Tests for _merge_tool_calls function.

    Covers: pass-through of single calls, merging of the list-valued merge
    field ("queries" / "urls") across calls to the same mergeable tool,
    non-mergeable tools staying separate, and edge cases around the merge
    field (empty list, missing key, scalar string value).
    """

    def test_empty_list(self) -> None:
        """Empty input returns empty output."""
        result = _merge_tool_calls([])
        assert result == []

    def test_single_search_tool_call_not_merged(self) -> None:
        """A single SearchTool call is returned as-is (no merging needed)."""
        call = _make_tool_call(
            tool_name="internal_search",
            tool_args={"queries": ["query1"]},
            tool_call_id="call_1",
        )
        result = _merge_tool_calls([call])
        assert len(result) == 1
        assert result[0].tool_name == "internal_search"
        assert result[0].tool_args == {"queries": ["query1"]}
        assert result[0].tool_call_id == "call_1"

    def test_single_web_search_tool_call_not_merged(self) -> None:
        """A single WebSearchTool call is returned as-is."""
        call = _make_tool_call(
            tool_name="web_search",
            tool_args={"queries": ["web query"]},
        )
        result = _merge_tool_calls([call])
        assert len(result) == 1
        assert result[0].tool_name == "web_search"
        assert result[0].tool_args == {"queries": ["web query"]}

    def test_single_open_url_tool_call_not_merged(self) -> None:
        """A single OpenURLTool call is returned as-is."""
        call = _make_tool_call(
            tool_name="open_url",
            tool_args={"urls": ["https://example.com"]},
        )
        result = _merge_tool_calls([call])
        assert len(result) == 1
        assert result[0].tool_name == "open_url"
        assert result[0].tool_args == {"urls": ["https://example.com"]}

    def test_multiple_search_tool_calls_merged(self) -> None:
        """Multiple SearchTool calls have their queries merged into one call."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["query1", "query2"]},
                tool_call_id="call_1",
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["query3"]},
                tool_call_id="call_2",
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        assert result[0].tool_name == "internal_search"
        # Queries are concatenated in input order
        assert result[0].tool_args["queries"] == ["query1", "query2", "query3"]
        # Uses first call's ID
        assert result[0].tool_call_id == "call_1"

    def test_multiple_web_search_tool_calls_merged(self) -> None:
        """Multiple WebSearchTool calls have their queries merged."""
        calls = [
            _make_tool_call(
                tool_name="web_search",
                tool_args={"queries": ["web1"]},
                tool_call_id="call_1",
            ),
            _make_tool_call(
                tool_name="web_search",
                tool_args={"queries": ["web2", "web3"]},
                tool_call_id="call_2",
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        assert result[0].tool_name == "web_search"
        assert result[0].tool_args["queries"] == ["web1", "web2", "web3"]

    def test_multiple_open_url_tool_calls_merged(self) -> None:
        """Multiple OpenURLTool calls have their urls merged."""
        calls = [
            _make_tool_call(
                tool_name="open_url",
                tool_args={"urls": ["https://a.com"]},
                tool_call_id="call_1",
            ),
            _make_tool_call(
                tool_name="open_url",
                tool_args={"urls": ["https://b.com", "https://c.com"]},
                tool_call_id="call_2",
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        assert result[0].tool_name == "open_url"
        # open_url merges on "urls" rather than "queries"
        assert result[0].tool_args["urls"] == [
            "https://a.com",
            "https://b.com",
            "https://c.com",
        ]

    def test_non_mergeable_tool_not_merged(self) -> None:
        """Non-mergeable tools (e.g., python) are returned as separate calls."""
        calls = [
            _make_tool_call(
                tool_name="python",
                tool_args={"code": "print(1)"},
                tool_call_id="call_1",
            ),
            _make_tool_call(
                tool_name="python",
                tool_args={"code": "print(2)"},
                tool_call_id="call_2",
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 2
        assert result[0].tool_args["code"] == "print(1)"
        assert result[1].tool_args["code"] == "print(2)"

    def test_mixed_mergeable_and_non_mergeable(self) -> None:
        """Mix of mergeable and non-mergeable tools handles correctly."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q1"]},
                tool_call_id="search_1",
            ),
            _make_tool_call(
                tool_name="python",
                tool_args={"code": "x = 1"},
                tool_call_id="python_1",
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q2"]},
                tool_call_id="search_2",
            ),
        ]
        result = _merge_tool_calls(calls)
        # Should have 2 calls: merged search + python
        assert len(result) == 2
        tool_names = {r.tool_name for r in result}
        assert tool_names == {"internal_search", "python"}
        search_result = next(r for r in result if r.tool_name == "internal_search")
        # Search calls merged across the intervening python call
        assert search_result.tool_args["queries"] == ["q1", "q2"]
        python_result = next(r for r in result if r.tool_name == "python")
        assert python_result.tool_args["code"] == "x = 1"

    def test_multiple_different_mergeable_tools(self) -> None:
        """Multiple different mergeable tools each get merged separately."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["search1"]},
            ),
            _make_tool_call(
                tool_name="web_search",
                tool_args={"queries": ["web1"]},
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["search2"]},
            ),
            _make_tool_call(
                tool_name="web_search",
                tool_args={"queries": ["web2"]},
            ),
        ]
        result = _merge_tool_calls(calls)
        # Should have 2 merged calls
        assert len(result) == 2
        search_result = next(r for r in result if r.tool_name == "internal_search")
        assert search_result.tool_args["queries"] == ["search1", "search2"]
        web_result = next(r for r in result if r.tool_name == "web_search")
        assert web_result.tool_args["queries"] == ["web1", "web2"]

    def test_preserves_first_call_placement(self) -> None:
        """Merged call uses the placement from the first call."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q1"]},
                turn_index=1,
                tab_index=2,
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q2"]},
                turn_index=3,
                tab_index=4,
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        # Second call's placement (3, 4) is discarded
        assert result[0].placement.turn_index == 1
        assert result[0].placement.tab_index == 2

    def test_preserves_other_args_from_first_call(self) -> None:
        """Merged call preserves non-merge-field args from the first call."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q1"], "other_param": "value1"},
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q2"], "other_param": "value2"},
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        assert result[0].tool_args["queries"] == ["q1", "q2"]
        # Other params from first call are preserved
        assert result[0].tool_args["other_param"] == "value1"

    def test_handles_empty_queries_list(self) -> None:
        """Handles calls with empty queries lists."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": []},
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q1"]},
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        assert result[0].tool_args["queries"] == ["q1"]

    def test_handles_missing_merge_field(self) -> None:
        """Handles calls where the merge field is missing entirely."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={},  # No queries field
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q1"]},
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        assert result[0].tool_args["queries"] == ["q1"]

    def test_handles_string_value_instead_of_list(self) -> None:
        """Handles edge case where merge field is a string instead of list."""
        calls = [
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": "single_query"},  # String instead of list
            ),
            _make_tool_call(
                tool_name="internal_search",
                tool_args={"queries": ["q2"]},
            ),
        ]
        result = _merge_tool_calls(calls)
        assert len(result) == 1
        # String should be converted to list item
        assert result[0].tool_args["queries"] == ["single_query", "q2"]

View File

@@ -140,6 +140,7 @@ module.exports = {
"**/src/**/codeUtils.test.ts",
"**/src/lib/**/*.test.ts",
"**/src/app/**/services/*.test.ts",
"**/src/refresh-components/**/*.test.ts",
// Add more patterns here as you add more unit tests
],
},

11
web/package-lock.json generated
View File

@@ -40,6 +40,7 @@
"@sentry/nextjs": "^10.22.0",
"@sentry/tracing": "^7.120.3",
"@stripe/stripe-js": "^4.6.0",
"@tailwindcss/container-queries": "^0.1.1",
"@tanstack/react-table": "^8.21.3",
"autoprefixer": "^10.4.22",
"class-variance-authority": "^0.7.0",
@@ -5875,6 +5876,15 @@
"tslib": "^2.8.0"
}
},
"node_modules/@tailwindcss/container-queries": {
"version": "0.1.1",
"resolved": "https://registry.npmjs.org/@tailwindcss/container-queries/-/container-queries-0.1.1.tgz",
"integrity": "sha512-p18dswChx6WnTSaJCSGx6lTmrGzNNvm2FtXmiO6AuA1V4U5REyoqwmT6kgAsIMdjo07QdAfYXHJ4hnMtfHzWgA==",
"license": "MIT",
"peerDependencies": {
"tailwindcss": ">=3.2.0"
}
},
"node_modules/@tailwindcss/typography": {
"version": "0.5.19",
"dev": true,
@@ -10298,6 +10308,7 @@
},
"node_modules/fsevents": {
"version": "2.3.2",
"dev": true,
"license": "MIT",
"optional": true,
"os": [

View File

@@ -56,6 +56,7 @@
"@sentry/nextjs": "^10.22.0",
"@sentry/tracing": "^7.120.3",
"@stripe/stripe-js": "^4.6.0",
"@tailwindcss/container-queries": "^0.1.1",
"@tanstack/react-table": "^8.21.3",
"autoprefixer": "^10.4.22",
"class-variance-authority": "^0.7.0",

View File

@@ -83,11 +83,7 @@ export default function OnyxApiKeyForm({
can be added or changed later!
</Text>
<TextFormField
name="name"
label="Name (optional):"
autoCompleteDisabled={true}
/>
<TextFormField name="name" label="Name (optional):" />
<SelectorFormField
// defaultValue is managed by Formik

View File

@@ -294,7 +294,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
}
isSubmitting={isSubmitting}
>
<Form className="flex flex-col gap-0 bg-background-tint-01">
<Form className="flex flex-col gap-0 bg-background-tint-01 w-full">
<div className="flex flex-col gap-4 w-full">
{children(childProps)}
</div>

View File

@@ -6,7 +6,7 @@ import {
ProviderFormContext,
} from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { ApiKeyField } from "./components/ApiKeyField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
buildDefaultInitialValues,
@@ -94,7 +94,7 @@ export function AnthropicForm({
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<ApiKeyField />
<PasswordInputTypeInField name="api_key" label="API Key" />
<DisplayModels
modelConfigurations={modelConfigurations}

View File

@@ -7,7 +7,7 @@ import {
ProviderFormContext,
} from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { ApiKeyField } from "./components/ApiKeyField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
buildDefaultInitialValues,
@@ -140,7 +140,7 @@ export function AzureForm({
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<ApiKeyField />
<PasswordInputTypeInField name="api_key" label="API Key" />
<TextFormField
name="target_uri"

View File

@@ -1,6 +1,7 @@
import { useState, useEffect } from "react";
import { Form, Formik, FormikProps } from "formik";
import { SelectorFormField, TextFormField } from "@/components/Field";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderView,
@@ -193,11 +194,10 @@ function BedrockFormInternals({
label="AWS Access Key ID"
placeholder="AKIAIOSFODNN7EXAMPLE"
/>
<TextFormField
<PasswordInputTypeInField
name={FIELD_AWS_SECRET_ACCESS_KEY}
label="AWS Secret Access Key"
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
type="password"
/>
</div>
</TabsContent>
@@ -210,11 +210,10 @@ function BedrockFormInternals({
)}
>
<div className="flex flex-col gap-4">
<TextFormField
<PasswordInputTypeInField
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
label="AWS Bedrock Long-term API Key"
placeholder="Your long-term API key"
type="password"
/>
</div>
</TabsContent>

View File

@@ -11,7 +11,7 @@ import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { ApiKeyField } from "./components/ApiKeyField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
submitLLMProvider,
@@ -190,7 +190,10 @@ export function CustomForm({
determine which fields are required.
</Text>
<ApiKeyField label="[Optional] API Key" />
<PasswordInputTypeInField
name="api_key"
label="[Optional] API Key"
/>
<TextFormField
name="api_base"

View File

@@ -1,5 +1,6 @@
import { Form, Formik, FormikProps } from "formik";
import { TextFormField } from "@/components/Field";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderView,
@@ -91,13 +92,10 @@ function OllamaFormContent({
placeholder={DEFAULT_API_BASE}
/>
<TextFormField
<PasswordInputTypeInField
name="custom_config.OLLAMA_API_KEY"
label="API Key (Optional)"
subtext="Optional API key for Ollama Cloud (https://ollama.com). Leave blank for local instances."
placeholder=""
type="password"
showPasswordToggle
/>
<DisplayModels

View File

@@ -4,7 +4,7 @@ import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { ApiKeyField } from "./components/ApiKeyField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
buildDefaultInitialValues,
@@ -91,7 +91,7 @@ export function OpenAIForm({
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<ApiKeyField />
<PasswordInputTypeInField name="api_key" label="API Key" />
<DisplayModels
modelConfigurations={modelConfigurations}

View File

@@ -12,7 +12,7 @@ import {
ProviderFormContext,
} from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { ApiKeyField } from "./components/ApiKeyField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
buildDefaultInitialValues,
@@ -172,7 +172,7 @@ export function OpenRouterForm({
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<ApiKeyField />
<PasswordInputTypeInField name="api_key" label="API Key" />
<TextFormField
name="api_base"

View File

@@ -1,12 +0,0 @@
import { TextFormField } from "@/components/Field";
/**
 * Masked text input bound to the Formik field `api_key`.
 *
 * Falls back to the label "API Key" when no label is given; note that an
 * empty-string label also falls back, since `||` (not `??`) is used.
 */
export function ApiKeyField({ label }: { label?: string }) {
  return (
    <TextFormField
      name="api_key"
      label={label || "API Key"}
      placeholder="API Key"
      type="password"
    />
  );
}

View File

@@ -120,9 +120,7 @@ export function ProviderFormEntrypointWrapper({
title={`Setup ${providerName}`}
onClose={onClose}
/>
<Modal.Body className="max-h-[70vh] overflow-y-auto">
{children(context)}
</Modal.Body>
<Modal.Body>{children(context)}</Modal.Body>
</Modal.Content>
</Modal>
)}
@@ -208,9 +206,7 @@ export function ProviderFormEntrypointWrapper({
}`}
onClose={onClose}
/>
<Modal.Body className="max-h-[70vh] overflow-y-auto">
{children(context)}
</Modal.Body>
<Modal.Body>{children(context)}</Modal.Body>
</Modal.Content>
</Modal>
)}

View File

@@ -130,20 +130,20 @@ export const filterModelConfigurations = (
};
// Helper to get model configurations for auto mode
// In auto mode, we include ALL models but preserve their visibility status
// Models in the auto config are visible, others are created but not visible
export const getAutoModeModelConfigurations = (
modelConfigurations: ModelConfiguration[]
): ModelConfiguration[] => {
return modelConfigurations
.filter((m) => m.is_visible)
.map(
(modelConfiguration): ModelConfiguration => ({
name: modelConfiguration.name,
is_visible: true,
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
supports_image_input: modelConfiguration.supports_image_input,
display_name: modelConfiguration.display_name,
})
);
return modelConfigurations.map(
(modelConfiguration): ModelConfiguration => ({
name: modelConfiguration.name,
is_visible: modelConfiguration.is_visible,
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
supports_image_input: modelConfiguration.supports_image_input,
display_name: modelConfiguration.display_name,
})
);
};
export const submitLLMProvider = async <T extends BaseLLMFormValues>({

View File

@@ -142,7 +142,7 @@ export default function UpgradingPage({
previous model and all progress will be lost.
</div>
</Modal.Body>
<Modal.Footer className="p-4 flex gap-x-2 w-full justify-end">
<Modal.Footer>
<Button onClick={onCancel}>Confirm</Button>
<Button onClick={() => setIsCancelling(false)} secondary>
Cancel

View File

@@ -19,6 +19,11 @@ export type WebProviderSetupModalProps = {
description: string;
apiKeyValue: string;
onApiKeyChange: (value: string) => void;
/**
* When true, the API key is a stored/masked value from the backend
* that cannot actually be revealed. The reveal toggle will be disabled.
*/
isStoredApiKey?: boolean;
optionalField?: {
label: string;
value: string;
@@ -45,6 +50,7 @@ export const WebProviderSetupModal = memo(
description,
apiKeyValue,
onApiKeyChange,
isStoredApiKey = false,
optionalField,
helperMessage,
helperClass,
@@ -126,8 +132,9 @@ export const WebProviderSetupModal = memo(
placeholder="Enter API key"
value={apiKeyValue}
autoFocus={apiKeyAutoFocus}
isNonRevealable={isStoredApiKey}
onFocus={(e) => {
if (apiKeyValue === "••••••••••••••••") {
if (isStoredApiKey) {
e.target.select();
}
}}
@@ -239,7 +246,7 @@ export const WebProviderSetupModal = memo(
</FormField>
)}
</Modal.Body>
<Modal.Footer className="gap-2">
<Modal.Footer>
<Button type="button" main secondary onClick={onClose}>
Cancel
</Button>

View File

@@ -935,7 +935,6 @@ export default function Page() {
provider
);
}}
className="h-6 w-6 opacity-70 hover:opacity-100"
aria-label={`Edit ${label}`}
/>
)}
@@ -1134,7 +1133,6 @@ export default function Page() {
provider
);
}}
className="h-6 w-6 opacity-70 hover:opacity-100"
aria-label={`Edit ${label}`}
/>
)}
@@ -1211,6 +1209,7 @@ export default function Page() {
onApiKeyChange={(value) =>
dispatchSearchModal({ type: "SET_API_KEY", value })
}
isStoredApiKey={searchModal.apiKeyValue === MASKED_API_KEY_PLACEHOLDER}
optionalField={
selectedProviderType === "google_pse"
? {
@@ -1336,6 +1335,7 @@ export default function Page() {
onApiKeyChange={(value) =>
dispatchContentModal({ type: "SET_API_KEY", value })
}
isStoredApiKey={contentModal.apiKeyValue === MASKED_API_KEY_PLACEHOLDER}
optionalField={
selectedContentProviderType === "firecrawl"
? {

View File

@@ -136,7 +136,7 @@ export default function IndexAttemptErrorsModal({
}
onClose={onClose}
/>
<Modal.Body className="flex flex-col gap-4 min-h-0">
<Modal.Body>
{!isResolvingErrors && (
<div className="flex flex-col gap-2 flex-shrink-0">
<Text as="p">

View File

@@ -37,7 +37,7 @@ interface InlineFileManagementProps {
onRefresh: () => void;
}
export function InlineFileManagement({
export default function InlineFileManagement({
connectorId,
onRefresh,
}: InlineFileManagementProps) {
@@ -360,7 +360,7 @@ export function InlineFileManagement({
description="When you save these changes, the following will happen:"
/>
<Modal.Body className="px-6 space-y-3">
<Modal.Body>
{selectedFilesToRemove.size > 0 && (
<div className="p-3 bg-red-50 dark:bg-red-900/10 rounded-md">
<Text
@@ -402,7 +402,7 @@ export function InlineFileManagement({
)}
</Modal.Body>
<Modal.Footer className="p-6 pt-4">
<Modal.Footer>
<Button
onClick={() => setShowSaveConfirm(false)}
secondary
@@ -410,7 +410,7 @@ export function InlineFileManagement({
>
Cancel
</Button>
<Button onClick={handleConfirmSave} primary disabled={isSaving}>
<Button onClick={handleConfirmSave} disabled={isSaving}>
{isSaving ? "Saving..." : "Confirm & Save"}
</Button>
</Modal.Footer>

View File

@@ -25,7 +25,7 @@ import {
} from "./ConfigDisplay";
import DeletionErrorStatus from "./DeletionErrorStatus";
import { IndexAttemptsTable } from "./IndexAttemptsTable";
import { InlineFileManagement } from "./InlineFileManagement";
import InlineFileManagement from "./InlineFileManagement";
import { buildCCPairInfoUrl, triggerIndexing } from "./lib";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import {

View File

@@ -178,13 +178,11 @@ export const DocumentSetCreationForm = ({
name="name"
label="Name:"
placeholder="A name for the document set"
autoCompleteDisabled={true}
/>
<TextFormField
name="description"
label="Description:"
placeholder="Describe what the document set represents"
autoCompleteDisabled={true}
optional={true}
/>

View File

@@ -339,7 +339,7 @@ const RerankingDetailsForm = forwardRef<
better performance.
</p>
</Modal.Body>
<Modal.Footer className="p-4 flex justify-end">
<Modal.Footer>
<Button
onClick={() => setShowGpuWarningModalModel(null)}
>
@@ -433,7 +433,7 @@ const RerankingDetailsForm = forwardRef<
/>
</div>
</Modal.Body>
<Modal.Footer className="p-4 flex w-full justify-end">
<Modal.Footer>
<Button
onClick={() => {
setShowLiteLLMConfigurationModal(false);
@@ -513,7 +513,7 @@ const RerankingDetailsForm = forwardRef<
/>
</div>
</Modal.Body>
<Modal.Footer className="p-4 flex w-full justify-end">
<Modal.Footer>
<Button onClick={() => setIsApiKeyModalOpen(false)}>
Update
</Button>

View File

@@ -29,7 +29,7 @@ export default function InstantSwitchConfirmModal({
<strong>This is not reversible.</strong>
</Text>
</Modal.Body>
<Modal.Footer className="p-4 gap-2">
<Modal.Footer>
<Button onClick={onConfirm}>Confirm</Button>
<Button secondary onClick={onClose}>
Cancel

View File

@@ -55,7 +55,7 @@ export default function ModelSelectionConfirmationModal({
</Callout>
)}
</Modal.Body>
<Modal.Footer className="p-4 gap-2 justify-end">
<Modal.Footer>
<Button onClick={onConfirm}>Confirm</Button>
<Button secondary onClick={onCancel}>
Cancel

View File

@@ -30,7 +30,7 @@ export default function SelectModelModal({
you will need to undergo a complete re-indexing. Are you sure?
</Text>
</Modal.Body>
<Modal.Footer className="p-4 gap-2 justify-end">
<Modal.Footer>
<Button onClick={onConfirm}>Confirm</Button>
<Button secondary onClick={onCancel}>
Cancel

View File

@@ -1,273 +0,0 @@
import { Form, Formik } from "formik";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { TextFormField } from "@/components/Field";
import Button from "@/refresh-components/buttons/Button";
import Separator from "@/refresh-components/Separator";
import { Callout } from "@/components/ui/callout";
import Text from "@/refresh-components/texts/Text";
import Modal from "@/refresh-components/Modal";
import { SvgKey } from "@opal/icons";
import {
OAuthConfig,
OAuthConfigCreate,
OAuthConfigUpdate,
} from "@/lib/tools/interfaces";
import { createOAuthConfig, updateOAuthConfig } from "@/lib/oauth/api";
import * as Yup from "yup";
interface OAuthConfigFormProps {
onClose: () => void;
setPopup: (popupSpec: PopupSpec | null) => void;
config?: OAuthConfig;
onConfigSubmitted?: (config: OAuthConfig) => void;
}
const OAuthConfigSchema = Yup.object().shape({
name: Yup.string().required("Name is required"),
authorization_url: Yup.string()
.url("Must be a valid URL")
.required("Authorization URL is required"),
token_url: Yup.string()
.url("Must be a valid URL")
.required("Token URL is required"),
client_id: Yup.string().when("isUpdate", {
is: false,
then: (schema) => schema.required("Client ID is required"),
otherwise: (schema) => schema,
}),
client_secret: Yup.string().when("isUpdate", {
is: false,
then: (schema) => schema.required("Client Secret is required"),
otherwise: (schema) => schema,
}),
scopes: Yup.string(),
});
/**
 * Modal form for creating or updating a shared OAuth provider configuration
 * that custom tools can reference.
 *
 * Mode is inferred from `config`: when provided the form updates that config,
 * otherwise it creates a new one. In update mode the client ID/secret inputs
 * are optional — leaving them blank keeps the credentials already stored on
 * the server (they are never echoed back to the client, so both fields always
 * start empty).
 */
export const OAuthConfigForm = ({
  onClose,
  setPopup,
  config,
  onConfigSubmitted,
}: OAuthConfigFormProps) => {
  const isUpdate = config !== undefined;

  return (
    <Modal
      open
      onOpenChange={(open) => {
        if (!open) {
          onClose();
        }
      }}
    >
      <Modal.Content medium className="w-[60%] max-h-[80vh]">
        <Modal.Header
          icon={SvgKey}
          title={
            isUpdate
              ? "Update OAuth Configuration"
              : "Create OAuth Configuration"
          }
          onClose={onClose}
        />
        <Formik
          initialValues={{
            name: config?.name || "",
            authorization_url: config?.authorization_url || "",
            token_url: config?.token_url || "",
            // Credentials intentionally start empty — see component docs.
            client_id: "",
            client_secret: "",
            scopes: config?.scopes?.join(", ") || "",
            // Hidden flag consumed by OAuthConfigSchema's `when` conditions.
            isUpdate,
          }}
          validationSchema={OAuthConfigSchema}
          onSubmit={async (values, formikHelpers) => {
            formikHelpers.setSubmitting(true);
            try {
              // Parse scopes from a comma-separated string, dropping blanks.
              const scopesArray = values.scopes
                .split(",")
                .map((s) => s.trim())
                .filter((s) => s.length > 0);

              if (isUpdate && config) {
                // Build the update payload; credentials are only included
                // when the user typed new ones, so omitted fields keep the
                // values already stored server-side.
                const updatePayload: OAuthConfigUpdate = {
                  name: values.name,
                  authorization_url: values.authorization_url,
                  token_url: values.token_url,
                  scopes: scopesArray,
                };
                if (values.client_id) {
                  updatePayload.client_id = values.client_id;
                }
                if (values.client_secret) {
                  updatePayload.client_secret = values.client_secret;
                }

                const updatedConfig = await updateOAuthConfig(
                  config.id,
                  updatePayload
                );
                setPopup({
                  message: "Successfully updated OAuth configuration!",
                  type: "success",
                });
                // Let the parent refresh its list.
                if (onConfigSubmitted) {
                  onConfigSubmitted(updatedConfig);
                }
              } else {
                // Create a brand-new config; all fields are required here
                // (enforced by OAuthConfigSchema).
                const createPayload: OAuthConfigCreate = {
                  name: values.name,
                  authorization_url: values.authorization_url,
                  token_url: values.token_url,
                  client_id: values.client_id,
                  client_secret: values.client_secret,
                  scopes: scopesArray,
                };

                const createdConfig = await createOAuthConfig(createPayload);
                setPopup({
                  message: "Successfully created OAuth configuration!",
                  type: "success",
                });
                if (onConfigSubmitted && createdConfig) {
                  onConfigSubmitted(createdConfig);
                }
              }
              onClose();
            } catch (error) {
              // `error` is `unknown` under strict TS; narrow before reading
              // `.message` instead of casting to `any` (a thrown non-Error
              // would otherwise render "undefined" or crash).
              const message =
                error instanceof Error ? error.message : String(error);
              setPopup({
                message: isUpdate
                  ? `Error updating OAuth configuration - ${message}`
                  : `Error creating OAuth configuration - ${message}`,
                type: "error",
              });
            } finally {
              formikHelpers.setSubmitting(false);
            }
          }}
        >
          {({ isSubmitting }) => (
            <Form className="w-full overflow-visible">
              <Modal.Body className="overflow-y-auto px-6 w-full">
                <Separator noPadding />
                <Text>
                  Configure an OAuth provider that can be shared across multiple
                  custom tools. Users will authenticate with this provider when
                  using tools that require it.
                </Text>
                <Callout
                  type="notice"
                  icon="📋"
                  title="Redirect URI for OAuth App Configuration"
                  className="my-0"
                >
                  <Text as="p" className="text-sm mb-2">
                    When configuring your OAuth application in the
                    provider&apos;s dashboard, use this redirect URI:
                  </Text>
                  <code className="block p-2 bg-background-100 rounded text-sm font-mono">
                    {/* window is unavailable during SSR; show a placeholder. */}
                    {typeof window !== "undefined"
                      ? `${window.location.origin}/oauth-config/callback`
                      : "{YOUR_DOMAIN}/oauth-config/callback"}
                  </code>
                </Callout>
                <TextFormField
                  name="name"
                  label="Configuration Name:"
                  subtext="A friendly name to identify this OAuth configuration (e.g., 'GitHub OAuth', 'Google OAuth')"
                  placeholder="e.g., GitHub OAuth"
                  autoCompleteDisabled={true}
                />
                <TextFormField
                  name="authorization_url"
                  label="Authorization URL:"
                  subtext="The OAuth provider's authorization endpoint"
                  placeholder="e.g., https://github.com/login/oauth/authorize"
                  autoCompleteDisabled={true}
                />
                <TextFormField
                  name="token_url"
                  label="Token URL:"
                  subtext="The OAuth provider's token exchange endpoint"
                  placeholder="e.g., https://github.com/login/oauth/access_token"
                  autoCompleteDisabled={true}
                />
                <TextFormField
                  name="client_id"
                  label={isUpdate ? "Client ID (optional):" : "Client ID:"}
                  subtext={
                    isUpdate
                      ? "Leave empty to keep existing client ID"
                      : "Your OAuth application's client ID"
                  }
                  placeholder={
                    isUpdate
                      ? "Enter new client ID to update"
                      : "Your client ID"
                  }
                  autoCompleteDisabled={true}
                />
                <TextFormField
                  name="client_secret"
                  label={
                    isUpdate ? "Client Secret (optional):" : "Client Secret:"
                  }
                  subtext={
                    isUpdate
                      ? "Leave empty to keep existing client secret"
                      : "Your OAuth application's client secret"
                  }
                  placeholder={
                    isUpdate
                      ? "Enter new client secret to update"
                      : "Your client secret"
                  }
                  type="password"
                  autoCompleteDisabled={true}
                />
                <TextFormField
                  name="scopes"
                  label="Scopes (optional):"
                  subtext="Comma-separated list of OAuth scopes to request (e.g., 'repo, user')"
                  placeholder="e.g., repo, user"
                  autoCompleteDisabled={true}
                />
              </Modal.Body>
              <Modal.Footer className="w-full">
                <Button
                  type="button"
                  onClick={onClose}
                  disabled={isSubmitting}
                  secondary
                >
                  Cancel
                </Button>
                <Button type="submit" disabled={isSubmitting} primary>
                  {isUpdate ? "Update" : "Create"}
                </Button>
              </Modal.Footer>
            </Form>
          )}
        </Formik>
      </Modal.Content>
    </Modal>
  );
};

View File

@@ -5,10 +5,14 @@ import { useContext, useState } from "react";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import Button from "@/refresh-components/buttons/Button";
import { Input } from "@/components/ui/input";
import { ThreeDotsLoader } from "@/components/Loading";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { SvgCopy } from "@opal/icons";
import { Card } from "@/refresh-components/cards";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import * as GeneralLayouts from "@/layouts/general-layouts";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
export function AnonymousUserPath({
setPopup,
}: {
@@ -38,15 +42,8 @@ export function AnonymousUserPath({
async function handleCustomPathUpdate() {
try {
if (!customPath) {
setPopup({
message: "Custom path cannot be empty",
type: "error",
});
return;
}
// Validate custom path
if (!customPath.trim()) {
if (!customPath || !customPath.trim()) {
setPopup({
message: "Custom path cannot be empty",
type: "error",
@@ -95,56 +92,50 @@ export function AnonymousUserPath({
}
return (
<div className="mt-4 ml-6 max-w-xl p-6 bg-white shadow-lg border border-background-200 rounded-lg">
<h4 className="font-semibold text-lg text-text-800 mb-3">
Anonymous User Access
</h4>
<p className="text-text-600 text-sm mb-4">
Enable this to allow non-authenticated users to access all documents
indexed by public connectors in your workspace.
{anonymousUserPath
? "Customize the access path for anonymous users."
: "Set a custom access path for anonymous users."}{" "}
Anonymous users will only be able to view and search public documents,
but cannot access private or restricted content. The path will always
start with &quot;/anonymous/&quot;.
</p>
{isLoading ? (
<ThreeDotsLoader />
) : (
<div className="flex flex-col gap-2 justify-center items-start">
<div className="w-full flex-grow flex items-center rounded-md shadow-sm">
<span className="inline-flex items-center rounded-l-md border border-r-0 border-background-300 bg-background-50 px-3 text-text-500 sm:text-sm h-10">
{settings?.webDomain}/anonymous/
</span>
<Input
type="text"
className="block w-full flex-grow flex-1 rounded-none rounded-r-md border-background-300 focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm h-10"
placeholder="your-custom-path"
value={customPath ?? anonymousUserPath ?? ""}
onChange={(e) => setCustomPath(e.target.value)}
/>
</div>
<div className="flex flex-row gap-2">
<Button onClick={handleCustomPathUpdate}>Update Path</Button>
<Button
secondary
onClick={() => {
navigator.clipboard.writeText(
`${settings?.webDomain}/anonymous/${anonymousUserPath}`
);
setPopup({
message: "Invite link copied!",
type: "success",
});
}}
leftIcon={SvgCopy}
<div className="max-w-xl">
<Card gap={0}>
<GeneralLayouts.Section alignItems="start" gap={0.5}>
<Text headingH3>Anonymous User Access</Text>
<Text secondaryBody text03>
Enable this to allow anonymous users to access all public connectors
in your workspace. Anonymous users will not be able to access
private or restricted content.
</Text>
</GeneralLayouts.Section>
{isLoading ? (
<SimpleLoader className="self-center animate-spin mt-4" />
) : (
<>
<GeneralLayouts.Section flexDirection="row" gap={0.5}>
<Text mainContentBody text03>
{settings?.webDomain}/anonymous/
</Text>
<InputTypeIn
placeholder="your-custom-path"
value={customPath ?? anonymousUserPath ?? ""}
onChange={(e) => setCustomPath(e.target.value)}
showClearButton={false}
/>
</GeneralLayouts.Section>
<GeneralLayouts.Section
flexDirection="row"
gap={0.5}
justifyContent="start"
>
Copy
</Button>
</div>
</div>
)}
<Button onClick={handleCustomPathUpdate}>Update Path</Button>
<CopyIconButton
getCopyText={() =>
`${settings?.webDomain}/anonymous/${anonymousUserPath ?? ""}`
}
tooltip="Copy invite link"
secondary
/>
</GeneralLayouts.Section>
</>
)}
</Card>
</div>
);
}

View File

@@ -320,7 +320,7 @@ export function SettingsForm() {
anyone to use Onyx without signing in.
</p>
</Modal.Body>
<Modal.Footer className="p-4 flex justify-end gap-2">
<Modal.Footer>
<Button secondary onClick={() => setShowConfirmModal(false)}>
Cancel
</Button>

View File

@@ -45,6 +45,8 @@ export enum NotificationType {
PERSONA_SHARED = "persona_shared",
REINDEX = "reindex",
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending",
ASSISTANT_FILES_READY = "assistant_files_ready",
RELEASE_NOTES = "release_notes",
}
export interface Notification {
@@ -53,9 +55,12 @@ export interface Notification {
title: string;
description: string | null;
dismissed: boolean;
first_shown: string;
last_shown: string;
additional_data?: {
persona_id?: number;
link?: string;
version?: string; // For release notes notifications
[key: string]: any;
};
}

View File

@@ -44,7 +44,7 @@ describe("Email/Password Login Workflow", () => {
// User fills out the form using placeholder text
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
const passwordInput = screen.getByPlaceholderText(//);
await user.type(emailInput, "test@example.com");
await user.type(passwordInput, "password123");
@@ -90,7 +90,7 @@ describe("Email/Password Login Workflow", () => {
// User fills out form with invalid credentials
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
const passwordInput = screen.getByPlaceholderText(//);
await user.type(emailInput, "wrong@example.com");
await user.type(passwordInput, "wrongpassword");
@@ -142,7 +142,7 @@ describe("Email/Password Signup Workflow", () => {
// User fills out the signup form
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
const passwordInput = screen.getByPlaceholderText(//);
await user.type(emailInput, "newuser@example.com");
await user.type(passwordInput, "securepassword123");
@@ -208,7 +208,7 @@ describe("Email/Password Signup Workflow", () => {
// User fills out form with existing email
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
const passwordInput = screen.getByPlaceholderText(//);
await user.type(emailInput, "existing@example.com");
await user.type(passwordInput, "password123");
@@ -243,7 +243,7 @@ describe("Email/Password Signup Workflow", () => {
// User fills out form
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
const passwordInput = screen.getByPlaceholderText(//);
await user.type(emailInput, "user@example.com");
await user.type(passwordInput, "password123");

View File

@@ -224,7 +224,7 @@ export default function EmailPasswordForm({
}
field.onChange(e);
}}
placeholder="**************"
placeholder=""
onClear={() => helper.setValue("")}
data-testid="password"
error={apiStatus === "error"}

View File

@@ -204,6 +204,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
hasAnyProvider: llmManager.hasAnyProvider,
isLoadingChatSessions,
chatSessionsCount: chatSessions.length,
userId: user?.id,
});
const noAssistants = liveAssistant === null || liveAssistant === undefined;
@@ -672,6 +673,7 @@ export default function ChatPage({ firstMessage }: ChatPageProps) {
currentProjectId === null && (
<OnboardingFlow
handleHideOnboarding={hideOnboarding}
handleFinishOnboarding={finishOnboarding}
state={onboardingState}
actions={onboardingActions}
llmDescriptors={llmDescriptors}

View File

@@ -79,7 +79,7 @@ export function ChatPopup() {
icon={headerIcon}
title={popupTitle || "Welcome to Onyx!"}
/>
<Modal.Body className="bg-background-tint-01 py-4">
<Modal.Body>
<div className="overflow-y-auto text-left">
<ReactMarkdown
className="prose prose-neutral dark:prose-invert max-w-full"

View File

@@ -1,10 +1,7 @@
"use client";
import Logo from "@/refresh-components/Logo";
import {
GREETING_MESSAGES,
getRandomGreeting,
} from "@/lib/chat/greetingMessages";
import { getRandomGreeting } from "@/lib/chat/greetingMessages";
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
import Text from "@/refresh-components/texts/Text";
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
@@ -42,7 +39,7 @@ export default function WelcomeMessage({
);
} else if (agent) {
content = (
<div className="flex flex-col items-center gap-3 w-full max-w-[50rem]">
<>
<div
data-testid="assistant-name-display"
className="flex flex-row items-center gap-3"
@@ -57,7 +54,7 @@ export default function WelcomeMessage({
{agent.description}
</Text>
)}
</div>
</>
);
}
@@ -68,7 +65,7 @@ export default function WelcomeMessage({
return (
<div
data-testid="chat-intro"
className="flex flex-col items-center justify-center"
className="flex flex-col items-center justify-center gap-3 max-w-[50rem]"
>
{content}
</div>

View File

@@ -210,6 +210,12 @@ const ChatInputBar = React.memo(
}
}, [message]);
useEffect(() => {
if (initialMessage) {
setMessage(initialMessage);
}
}, [initialMessage]);
// Detect height changes and notify parent for scroll adjustment
useEffect(() => {
if (!containerRef.current) return;

View File

@@ -1,241 +0,0 @@
import React, { useEffect } from "react";
import { FiPlusCircle } from "react-icons/fi";
import { ChatInputOption } from "./ChatInputOption";
import { FilterManager } from "@/lib/hooks";
import { ChatFileType, FileDescriptor } from "@/app/chat/interfaces";
import {
InputBarPreview,
InputBarPreviewImageProvider,
} from "@/app/chat/components/files/InputBarPreview";
import { HorizontalSourceSelector } from "@/components/search/filtering/HorizontalSourceSelector";
import { Tag } from "@/lib/types";
import IconButton from "@/refresh-components/buttons/IconButton";
import { SvgArrowUp } from "@opal/icons";
const MAX_INPUT_HEIGHT = 200;
// Props for the simplified chat input bar (textarea + attached-file previews
// + optional source filters + send button).
export interface ChatInputBarProps {
  // Current draft message (controlled value of the textarea).
  message: string;
  // Updates the draft message.
  setMessage: (message: string) => void;
  // Submits the current message (Enter key or send button).
  onSubmit: () => void;
  // Files attached to the draft, rendered as previews above the textarea.
  files: FileDescriptor[];
  // Replaces the attached-file list (used to remove individual files).
  setFiles: (files: FileDescriptor[]) => void;
  // Receives newly selected or pasted files for upload.
  handleFileUpload: (files: File[]) => void;
  // Ref to the textarea; also used by the auto-resize effect.
  textAreaRef: React.RefObject<HTMLTextAreaElement | null>;
  // When present, renders the source/time/tag filter row.
  filterManager?: FilterManager;
  existingSources: string[];
  availableDocumentSets: { name: string }[];
  availableTags: Tag[];
}
/**
 * A pared-down chat input bar: an auto-growing textarea with paste-to-upload,
 * a file picker, optional source/time filters, and a send button.
 *
 * Submission is triggered by Enter (without Shift and outside IME
 * composition) or the send button, and only when the message is non-empty.
 */
export default function SimplifiedChatInputBar({
  message,
  setMessage,
  onSubmit,
  files,
  setFiles,
  handleFileUpload,
  textAreaRef,
  filterManager,
  existingSources,
  availableDocumentSets,
  availableTags,
}: ChatInputBarProps) {
  // Auto-size the textarea to its content, capped at MAX_INPUT_HEIGHT.
  useEffect(() => {
    const textarea = textAreaRef.current;
    if (textarea) {
      // Collapse first so shrinking content also shrinks the box.
      textarea.style.height = "0px";
      textarea.style.height = `${Math.min(
        textarea.scrollHeight,
        MAX_INPUT_HEIGHT
      )}px`;
    }
  }, [message, textAreaRef]);

  // Intercept pasted files (e.g. screenshots) and route them to the upload
  // handler instead of letting them land in the textarea.
  const handlePaste = (event: React.ClipboardEvent) => {
    const items = event.clipboardData?.items;
    if (items) {
      const pastedFiles = [];
      for (let i = 0; i < items.length; i++) {
        const item = items[i];
        if (item && item.kind === "file") {
          const file = item.getAsFile();
          if (file) pastedFiles.push(file);
        }
      }
      if (pastedFiles.length > 0) {
        event.preventDefault();
        handleFileUpload(pastedFiles);
      }
    }
  };

  const handleInputChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
    setMessage(event.target.value);
  };

  return (
    <div
      id="onyx-chat-input"
      className="
        w-full
        relative
        mx-auto
      "
    >
      <div
        className="
          opacity-100
          w-full
          h-fit
          flex
          flex-col
          border
          border-background-200
          rounded-lg
          relative
          text-text-chatbar
          bg-white
          [&:has(textarea:focus)]:ring-1
          [&:has(textarea:focus)]:ring-black
        "
        /* NOTE: the previous `]::ring-1` double-colon form was a typo —
           Tailwind arbitrary variants use a single colon, so the focus ring
           never rendered. */
      >
        {files.length > 0 && (
          <div className="flex gap-x-2 px-2 pt-2">
            <div className="flex gap-x-1 px-2 overflow-visible overflow-x-scroll items-end miniscroll">
              {files.map((file) => (
                <div className="flex-none" key={file.id}>
                  {file.type === ChatFileType.IMAGE ? (
                    <InputBarPreviewImageProvider
                      file={file}
                      onDelete={() => {
                        setFiles(
                          files.filter(
                            (fileInFilter) => fileInFilter.id !== file.id
                          )
                        );
                      }}
                      isUploading={file.isUploading || false}
                    />
                  ) : (
                    <InputBarPreview
                      file={file}
                      onDelete={() => {
                        setFiles(
                          files.filter(
                            (fileInFilter) => fileInFilter.id !== file.id
                          )
                        );
                      }}
                      isUploading={file.isUploading || false}
                    />
                  )}
                </div>
              ))}
            </div>
          </div>
        )}
        {/* Native <textarea> already exposes ARIA role "textbox" and
            multiline semantics, so no explicit role/aria-multiline is set
            (the old role="textarea" was not a valid ARIA role). */}
        <textarea
          onPaste={handlePaste}
          onChange={handleInputChange}
          ref={textAreaRef}
          className={`
            m-0
            w-full
            shrink
            resize-none
            rounded-lg
            border-0
            bg-white
            placeholder:text-text-chatbar-subtle
            ${
              textAreaRef.current &&
              textAreaRef.current.scrollHeight > MAX_INPUT_HEIGHT
                ? "overflow-y-auto mt-2"
                : ""
            }
            whitespace-pre-wrap
            break-word
            overscroll-contain
            outline-none
            placeholder-subtle
            px-5
            py-4
            h-14
          `}
          autoFocus
          style={{ scrollbarWidth: "thin" }}
          placeholder="Ask me anything..."
          value={message}
          onKeyDown={(event) => {
            // Submit on plain Enter; Shift+Enter inserts a newline and
            // Enter during IME composition is ignored.
            if (
              event.key === "Enter" &&
              !event.shiftKey &&
              !event.nativeEvent.isComposing
            ) {
              event.preventDefault();
              if (message) {
                onSubmit();
              }
            }
          }}
        />
        <div className="flex items-center space-x-3 mr-12 px-4 pb-2">
          <ChatInputOption
            flexPriority="stiff"
            name="File"
            Icon={FiPlusCircle}
            onClick={() => {
              // Open a transient file picker; selected files go straight to
              // the upload handler.
              const input = document.createElement("input");
              input.type = "file";
              input.multiple = true; // Allow multiple files
              input.onchange = () => {
                const selectedFiles = Array.from(input.files ?? []);
                if (selectedFiles.length > 0) {
                  handleFileUpload(selectedFiles);
                }
              };
              input.click();
            }}
          />
          {filterManager && (
            <HorizontalSourceSelector
              timeRange={filterManager.timeRange}
              setTimeRange={filterManager.setTimeRange}
              selectedSources={filterManager.selectedSources}
              setSelectedSources={filterManager.setSelectedSources}
              selectedDocumentSets={filterManager.selectedDocumentSets}
              setSelectedDocumentSets={filterManager.setSelectedDocumentSets}
              selectedTags={filterManager.selectedTags}
              setSelectedTags={filterManager.setSelectedTags}
              existingSources={existingSources}
              availableDocumentSets={availableDocumentSets}
              availableTags={availableTags}
            />
          )}
        </div>
      </div>
      <div className="absolute bottom-2 mobile:right-4 desktop:right-4">
        <IconButton
          id="onyx-chat-input-send-button"
          icon={SvgArrowUp}
          onClick={() => {
            if (message) {
              onSubmit();
            }
          }}
          disabled={!message}
        />
      </div>
    </div>
  );
}

Some files were not shown because too many files have changed in this diff Show More