Compare commits

...

2 Commits

Author SHA1 Message Date
Yuhong Sun
3921d105dc ok 2025-12-04 14:29:59 -08:00
Yuhong Sun
008e0021f1 mypy 2025-12-04 13:39:15 -08:00
45 changed files with 995 additions and 3095 deletions

View File

@@ -15,8 +15,6 @@ from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
@@ -77,40 +75,17 @@ def handle_simplified_chat_message(
chat_session_id=chat_session_id, db_session=db_session
)
if (
chat_message_req.retrieval_options is None
and chat_message_req.search_doc_ids is None
):
retrieval_options: RetrievalDetails | None = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
)
else:
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
search_doc_ids=chat_message_req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=chat_message_req.query_override,
# Currently only applies to search flow not chat
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
enforce_chat_session_id_for_search_docs=False,
)
return gather_stream(packets)
@@ -123,8 +98,7 @@ def handle_send_message_simple_with_history(
db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
"""This is a Non-Streaming version that only gives back a minimal set of information.
takes in chat history maintained by the caller
and does query rephrasing similar to answer-with-quote"""
takes in chat history maintained by the caller"""
if len(req.messages) == 0:
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
@@ -194,41 +168,22 @@ def handle_send_message_simple_with_history(
llm_tokenizer=llm_tokenizer,
)
rephrased_query = req.query_override or thread_based_query_rephrase(
rephrased_query = thread_based_query_rephrase(
user_query=query,
history_str=history_str,
)
if req.retrieval_options is None and req.search_doc_ids is None:
retrieval_options: RetrievalDetails | None = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
)
else:
retrieval_options = req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session.id,
parent_message_id=chat_message.id,
message=query,
file_descriptors=[],
search_doc_ids=req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=rephrased_query,
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
message=rephrased_query,
structured_response_format=req.structured_response_format,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
enforce_chat_session_id_for_search_docs=False,
)
return gather_stream(packets)

View File

@@ -12,7 +12,6 @@ from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
@@ -43,20 +42,10 @@ class BasicCreateChatMessageRequest(ChunkContext):
persona_id: int | None = None
# New message contents
message: str
# Defaults to using retrieval with no additional filters
retrieval_options: RetrievalDetails | None = None
# Allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
@@ -68,16 +57,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
messages: list[ThreadMessage]
persona_id: int
retrieval_options: RetrievalDetails | None = None
query_override: str | None = None
skip_rerank: bool | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):

View File

@@ -4,11 +4,9 @@ from collections.abc import Callable
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from onyx.auth.users import is_user_admin
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
try_creating_kg_processing_task,
)
@@ -17,24 +15,19 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
)
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import BaseFilters
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.search_settings import get_current_search_settings
@@ -51,9 +44,6 @@ from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
@@ -64,15 +54,10 @@ logger = setup_logger()
def prepare_chat_message_request(
message_text: str,
user: User | None,
filters: BaseFilters | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
llm_override: LLMOverride | None = None,
allowed_tool_ids: list[int] | None = None,
) -> CreateChatMessageRequest:
@@ -91,15 +76,7 @@ def prepare_chat_message_request(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
filters=filters,
llm_override=llm_override,
allowed_tool_ids=allowed_tool_ids,
)
@@ -355,68 +332,69 @@ def extract_headers(
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
# TODO in case it needs to be referenced later
# def create_temporary_persona(
# persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
# ) -> Persona:
# if not is_user_admin(user):
# raise HTTPException(
# status_code=403,
# detail="User is not authorized to create a persona in one shot queries",
# )
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
# """Create a temporary Persona object from the provided configuration."""
# persona = Persona(
# name=persona_config.name,
# description=persona_config.description,
# num_chunks=persona_config.num_chunks,
# llm_relevance_filter=persona_config.llm_relevance_filter,
# llm_filter_extraction=persona_config.llm_filter_extraction,
# recency_bias=persona_config.recency_bias,
# llm_model_provider_override=persona_config.llm_model_provider_override,
# llm_model_version_override=persona_config.llm_model_version_override,
# )
if persona_config.prompts:
# Use the first prompt from the override config for embedded prompt fields
first_prompt = persona_config.prompts[0]
persona.system_prompt = first_prompt.system_prompt
persona.task_prompt = first_prompt.task_prompt
persona.datetime_aware = first_prompt.datetime_aware
# if persona_config.prompts:
# # Use the first prompt from the override config for embedded prompt fields
# first_prompt = persona_config.prompts[0]
# persona.system_prompt = first_prompt.system_prompt
# persona.task_prompt = first_prompt.task_prompt
# persona.datetime_aware = first_prompt.datetime_aware
persona.tools = []
if persona_config.custom_tools_openapi:
from onyx.chat.emitter import get_default_emitter
# persona.tools = []
# if persona_config.custom_tools_openapi:
# from onyx.chat.emitter import get_default_emitter
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=schema,
emitter=get_default_emitter(),
),
)
persona.tools.extend(tools)
# for schema in persona_config.custom_tools_openapi:
# tools = cast(
# list[Tool],
# build_custom_tools_from_openapi_schema_and_headers(
# tool_id=0, # dummy tool id
# openapi_schema=schema,
# emitter=get_default_emitter(),
# ),
# )
# persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
# if persona_config.tools:
# tool_ids = [tool.id for tool in persona_config.tools]
# persona.tools.extend(
# fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
# )
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
# if persona_config.tool_ids:
# persona.tools.extend(
# fetch_existing_tools(
# db_session=db_session, tool_ids=persona_config.tool_ids
# )
# )
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
# fetched_docs = fetch_existing_doc_sets(
# db_session=db_session, doc_ids=persona_config.document_set_ids
# )
# persona.document_sets = fetched_docs
return persona
# return persona
def process_kg_commands(

View File

@@ -5,12 +5,10 @@ from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
@@ -127,35 +125,6 @@ class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
datetime_aware: bool = True
include_citations: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
# Note: prompt_ids removed - prompts are now embedded in personas
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo

View File

@@ -26,8 +26,6 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
from onyx.chat.save_chat import save_chat_turn
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.context.search.models import CitationDocInfo
@@ -260,17 +258,10 @@ def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
# For flow with search, don't include as many chunks as possible since we need to leave space
# for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
# Additional context that should be included in the chat history, for example:
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
@@ -298,10 +289,7 @@ def stream_chat_message_objects(
message_text = new_msg_req.message
chat_session_id = new_msg_req.chat_session_id
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
new_msg_req.alternate_assistant_id
user_selected_filters = retrieval_options.filters if retrieval_options else None
user_selected_filters = new_msg_req.filters
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
@@ -316,11 +304,6 @@ def stream_chat_message_objects(
db_session=db_session,
)
if reference_doc_ids is None and retrieval_options is None:
raise RuntimeError(
"Must specify a set of documents for chat or specify search options"
)
llm, fast_llm = get_llms_for_persona(
persona=persona,
user=user,

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm.session import SessionTransaction
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.users import get_user_by_email
from onyx.evals.models import EvalationAck
@@ -73,15 +72,11 @@ def _get_answer(
request = prepare_chat_message_request(
message_text=eval_input["message"],
user=user,
filters=None,
persona_id=None,
persona_override_config=full_configuration.persona_override_config,
message_ts_to_respond_to=None,
retrieval_details=RetrievalDetails(),
rerank_settings=None,
db_session=db_session,
skip_gen_ai_answer_generation=False,
llm_override=full_configuration.llm,
use_agentic_search=False,
allowed_tool_ids=full_configuration.allowed_tool_ids,
)
packets = stream_chat_message_objects(

View File

@@ -6,9 +6,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import PromptOverrideConfig
from onyx.chat.models import ToolConfig
from onyx.db.tools import get_builtin_tool
from onyx.llm.override_models import LLMOverride
from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP
@@ -16,7 +13,6 @@ from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP
class EvalConfiguration(BaseModel):
builtin_tool_types: list[str] = Field(default_factory=list)
persona_override_config: PersonaOverrideConfig | None = None
llm: LLMOverride = Field(default_factory=LLMOverride)
search_permissions_email: str | None = None
allowed_tool_ids: list[int]
@@ -24,7 +20,6 @@ class EvalConfiguration(BaseModel):
class EvalConfigurationOptions(BaseModel):
builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys())
persona_override_config: PersonaOverrideConfig | None = None
llm: LLMOverride = LLMOverride(
model_provider="Default",
model_version="gpt-4.1",
@@ -35,25 +30,7 @@ class EvalConfigurationOptions(BaseModel):
no_send_logs: bool = False
def get_configuration(self, db_session: Session) -> EvalConfiguration:
persona_override_config = self.persona_override_config or PersonaOverrideConfig(
name="Eval",
description="A persona for evaluation",
tools=[
ToolConfig(id=get_builtin_tool(db_session, BUILT_IN_TOOL_MAP[tool]).id)
for tool in self.builtin_tool_types
],
prompts=[
PromptOverrideConfig(
name="Default",
description="Default prompt for evaluation",
system_prompt="You are a helpful assistant.",
task_prompt="",
datetime_aware=True,
)
],
)
return EvalConfiguration(
persona_override_config=persona_override_config,
llm=self.llm,
search_permissions_email=self.search_permissions_email,
allowed_tool_ids=[

View File

@@ -2,7 +2,6 @@ from collections.abc import Callable
from sqlalchemy.orm import Session
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -110,7 +109,7 @@ def get_llm_config_for_persona(
def get_llms_for_persona(
persona: Persona | PersonaOverrideConfig | None,
persona: Persona | None,
user: User | None,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
@@ -137,22 +136,18 @@ def get_llms_for_persona(
if not provider_model:
raise ValueError("No LLM provider found")
# Only check access control for database Persona entities, not PersonaOverrideConfig
# PersonaOverrideConfig is used for temporary overrides and doesn't have access restrictions
persona_model = persona if isinstance(persona, Persona) else None
# Fetch user group IDs for access control check
user_group_ids = fetch_user_group_ids(db_session, user)
if not can_user_access_llm_provider(
provider_model,
user_group_ids,
persona_model,
persona,
):
logger.warning(
"User %s with persona %s cannot access provider %s. Falling back to default provider.",
getattr(user, "id", None),
getattr(persona_model, "id", None),
getattr(persona, "id", None),
provider_model.name,
)
return get_default_llms(

View File

@@ -113,9 +113,6 @@ from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from onyx.server.pat.api import router as pat_router
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.chat_backend_v0 import router as chat_v0_router
@@ -411,9 +408,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, standard_oauth_router)

View File

@@ -19,9 +19,7 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER
from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
@@ -210,31 +208,13 @@ def handle_regular_answer(
time_cutoff=None,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
# Commenting this out because this is only available to the slackbot for now
# later we plan to implement this at the persona level where this will get
# commented back in
# auto_detect_filters = (
# persona.llm_filter_extraction if persona is not None else True
# )
auto_detect_filters = slack_channel_config.enable_auto_filters
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
with get_session_with_current_tenant() as db_session:
answer_request = prepare_chat_message_request(
message_text=user_message.message,
user=user,
filters=filters,
persona_id=persona.id,
# This is not used in the Slack flow, only in the answer API
persona_override_config=None,
message_ts_to_respond_to=message_ts_to_respond_to,
retrieval_details=retrieval_details,
rerank_settings=None, # Rerank customization supported in Slack flow
db_session=db_session,
)

View File

@@ -15,6 +15,12 @@ from onyx.server.features.web_search.models import WebSearchToolResponse
from onyx.server.features.web_search.models import WebSearchWithContentResponse
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.tools.models import (
LlmOpenUrlResult,
)
from onyx.tools.models import (
LlmWebSearchResult,
)
from onyx.tools.tool_implementations.open_url.models import WebContentProvider
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
OnyxWebCrawler,
@@ -30,12 +36,6 @@ from onyx.tools.tool_implementations.web_search.providers import (
from onyx.tools.tool_implementations.web_search.utils import (
truncate_search_result_content,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmWebSearchResult,
)
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType

View File

@@ -2,10 +2,10 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.tools.tool_implementations_v2.tool_result_models import (
from onyx.tools.models import (
LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
from onyx.tools.models import (
LlmWebSearchResult,
)
from shared_configs.enums import WebContentProviderType

View File

@@ -1,253 +0,0 @@
from typing import Any
from typing import Optional
from uuid import uuid4
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_raw_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.tools import get_tool_by_name
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/assistants")
# Base models
class AssistantObject(BaseModel):
id: int
object: str = "assistant"
created_at: int
name: Optional[str] = None
description: Optional[str] = None
model: str
instructions: Optional[str] = None
tools: list[dict[str, Any]]
file_ids: list[str]
metadata: Optional[dict[str, Any]] = None
class CreateAssistantRequest(BaseModel):
model: str
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
icon_name: Optional[str] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class ModifyAssistantRequest(BaseModel):
model: Optional[str] = None
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class DeleteAssistantResponse(BaseModel):
id: int
object: str = "assistant.deleted"
deleted: bool
class ListAssistantsResponse(BaseModel):
object: str = "list"
data: list[AssistantObject]
first_id: Optional[int] = None
last_id: Optional[int] = None
has_more: bool
def persona_to_assistant(persona: Persona) -> AssistantObject:
return AssistantObject(
id=persona.id,
created_at=0,
name=persona.name,
description=persona.description,
model=persona.llm_model_version_override or "gpt-3.5-turbo",
instructions=persona.system_prompt,
tools=[
{
"type": tool.display_name,
"function": {
"name": tool.name,
"description": tool.description,
"schema": tool.openapi_schema,
},
}
for tool in persona.tools
],
file_ids=[], # Assuming no file support for now
metadata={}, # Assuming no metadata for now
)
# API endpoints
@router.post("")
def create_assistant(
request: CreateAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
# No separate Prompt entity; instructions map to persona.system_prompt
tool_ids = []
for tool in request.tools or []:
tool_type = tool.get("type")
if not tool_type:
continue
try:
tool_db = get_tool_by_name(tool_type, db_session)
tool_ids.append(tool_db.id)
except ValueError:
# Skip tools that don't exist in the database
logger.error(f"Tool {tool_type} not found in database")
raise HTTPException(
status_code=404, detail=f"Tool {tool_type} not found in database"
)
persona = upsert_persona(
user=user,
name=request.name or f"Assistant-{uuid4()}",
description=request.description or "",
num_chunks=25,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=None,
llm_model_version_override=request.model,
starter_messages=None,
is_public=False,
db_session=db_session,
document_set_ids=[],
tool_ids=tool_ids,
icon_name=request.icon_name,
is_visible=True,
system_prompt=request.instructions or "",
task_prompt="",
datetime_aware=True,
)
return persona_to_assistant(persona)
@router.get("/{assistant_id}")
def retrieve_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
try:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
except ValueError:
persona = None
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
return persona_to_assistant(persona)
@router.post("/{assistant_id}")
def modify_assistant(
assistant_id: int,
request: ModifyAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=True,
)
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = request.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(persona, key, value)
if "instructions" in update_data:
persona.system_prompt = update_data["instructions"]
db_session.commit()
return persona_to_assistant(persona)
@router.delete("/{assistant_id}")
def delete_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
try:
mark_persona_as_deleted(
persona_id=int(assistant_id),
user=user,
db_session=db_session,
)
return DeleteAssistantResponse(id=assistant_id, deleted=True)
except ValueError:
raise HTTPException(status_code=404, detail="Assistant not found")
@router.get("")
def list_assistants(
limit: int = Query(20, le=100),
order: str = Query("desc", regex="^(asc|desc)$"),
after: Optional[int] = None,
before: Optional[int] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
persona_snapshots = list(
get_raw_personas_for_user(
user=user,
db_session=db_session,
get_editable=False,
)
)
# Apply filtering based on after and before
if after:
persona_snapshots = [p for p in persona_snapshots if p.id > int(after)]
if before:
persona_snapshots = [p for p in persona_snapshots if p.id < int(before)]
# Apply ordering
persona_snapshots.sort(key=lambda p: p.id, reverse=(order == "desc"))
# Apply limit
persona_snapshots = persona_snapshots[:limit]
assistants = [persona_to_assistant(p) for p in persona_snapshots]
return ListAssistantsResponse(
data=assistants,
first_id=assistants[0].id if assistants else None,
last_id=assistants[-1].id if assistants else None,
has_more=len(persona_snapshots) == limit,
)

View File

@@ -1,19 +0,0 @@
from fastapi import APIRouter
from onyx.server.openai_assistants_api.assistants_api import (
router as assistants_router,
)
from onyx.server.openai_assistants_api.messages_api import router as messages_router
from onyx.server.openai_assistants_api.runs_api import router as runs_router
from onyx.server.openai_assistants_api.threads_api import router as threads_router
def get_full_openai_assistants_api_router() -> APIRouter:
    """Assemble the combined OpenAI-compatible Assistants API router."""
    combined = APIRouter(prefix="/openai-assistants")
    # Keep the original registration order (assistants, runs, threads, messages).
    for sub_router in (assistants_router, runs_router, threads_router, messages_router):
        combined.include_router(sub_router)
    return combined

View File

@@ -1,234 +0,0 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.configs.constants import MessageType
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.utils import check_number_of_tokens
router = APIRouter(prefix="")
# Message roles representable in this API; only these two values are accepted.
Role = Literal["user", "assistant"]
class MessageContent(BaseModel):
    """One content part of a message; only plain text is supported."""

    type: Literal["text"]
    text: str
class Message(BaseModel):
    """OpenAI-compatible view of one chat message within a thread."""

    # Defaults mirror the OpenAI Assistants API object shape.
    id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
    object: Literal["thread.message"] = "thread.message"
    # Creation time as a Unix timestamp in whole seconds.
    created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    thread_id: str
    role: Role
    content: list[MessageContent]
    file_ids: list[str] = []
    assistant_id: Optional[str] = None
    run_id: Optional[str] = None
    # Arbitrary client-supplied key/value metadata.
    metadata: Optional[dict[str, Any]] = None
class CreateMessageRequest(BaseModel):
    """Request body for appending a message to a thread."""

    role: Role
    content: str
    file_ids: list[str] = []
    metadata: Optional[dict] = None
class ListMessagesResponse(BaseModel):
    """OpenAI-style list envelope for thread messages."""

    object: Literal["list"] = "list"
    data: list[Message]
    # Cursor endpoints of the returned page.
    first_id: str
    last_id: str
    has_more: bool
@router.post("/threads/{thread_id}/messages")
def create_message(
    thread_id: str,
    message: CreateMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    """Append a message to an existing thread (chat session).

    The message is persisted as a child of the latest message in the session
    (or of the root message for an empty session). Responds 404 when the
    thread does not exist or is not visible to the caller.
    """
    user_id = user.id if user else None
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Chat session not found")

    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user_id,
        db_session=db_session,
    )
    # Chain onto the newest message; an empty session chains onto the root.
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )

    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message=message.content,
        token_count=check_number_of_tokens(message.content),
        message_type=(
            MessageType.USER if message.role == "user" else MessageType.ASSISTANT
        ),
        db_session=db_session,
    )

    return Message(
        id=str(new_message.id),
        thread_id=thread_id,
        # Bug fix: echo the role that was actually stored instead of
        # hard-coding "user" (assistant-authored messages were mislabeled).
        role=message.role,
        content=[MessageContent(type="text", text=message.content)],
        file_ids=message.file_ids,
        metadata=message.metadata,
    )
@router.get("/threads/{thread_id}/messages")
def list_messages(
    thread_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
    """List the messages of a thread with cursor pagination.

    `after`/`before` are message ids (as strings, matching the ids the API
    returns). Responds 404 for unknown/inaccessible threads and 400 for a
    malformed cursor.
    """
    user_id = user.id if user else None
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Chat session not found")

    messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user_id,
        db_session=db_session,
    )

    # Parse cursors up front; ids must be compared numerically. The previous
    # string comparison was lexicographic (e.g. "10" < "9") and inclusive,
    # which re-returned the cursor row itself.
    try:
        after_id = int(after) if after else None
        before_id = int(before) if before else None
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid cursor")

    if after_id is not None:
        messages = [m for m in messages if m.id > after_id]
    if before_id is not None:
        messages = [m for m in messages if m.id < before_id]

    # Order by id, then take one page.
    messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
    messages = messages[:limit]

    data = [
        Message(
            id=str(m.id),
            thread_id=thread_id,
            role="user" if m.message_type == "user" else "assistant",
            content=[MessageContent(type="text", text=m.message)],
            created_at=int(m.time_sent.timestamp()),
        )
        for m in messages
    ]
    return ListMessagesResponse(
        data=data,
        # Message.id is already a str; no extra conversion needed.
        first_id=data[0].id if data else "",
        last_id=data[-1].id if data else "",
        has_more=len(messages) == limit,
    )
@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
    thread_id: str,
    message_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    """Fetch a single message by id, or 404 if it is missing/inaccessible."""
    try:
        db_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    is_user_authored = db_message.message_type == "user"
    return Message(
        id=str(db_message.id),
        thread_id=thread_id,
        role="user" if is_user_authored else "assistant",
        content=[MessageContent(type="text", text=db_message.message)],
        created_at=int(db_message.time_sent.timestamp()),
    )
class ModifyMessageRequest(BaseModel):
    """Request body for updating a message's metadata."""

    metadata: dict
@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
    thread_id: str,
    message_id: int,
    request: ModifyMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    """Modify a message's metadata (OpenAI-compatible endpoint).

    NOTE: metadata is currently NOT persisted (see TODO below); the request
    metadata is only echoed back in the response.
    """
    user_id = user.id if user else None
    try:
        chat_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")
    # Update metadata
    # TODO: Uncomment this once we have metadata in the chat message
    # chat_message.metadata = request.metadata
    # db_session.commit()
    return Message(
        id=str(chat_message.id),
        thread_id=thread_id,
        role="user" if chat_message.message_type == "user" else "assistant",
        content=[MessageContent(type="text", text=chat_message.message)],
        created_at=int(chat_message.time_sent.timestamp()),
        metadata=request.metadata,
    )

View File

@@ -1,344 +0,0 @@
from typing import Literal
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.constants import MessageType
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ChatMessage
from onyx.db.models import User
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter()
class RunRequest(BaseModel):
    """Request body for creating a run on a thread."""

    # Assistant (persona) id used to generate the answer.
    assistant_id: int
    model: Optional[str] = None
    instructions: Optional[str] = None
    additional_instructions: Optional[str] = None
    # Tool configs; a SearchTool entry may carry "retrieval_details".
    tools: Optional[list[dict]] = None
    metadata: Optional[dict] = None
# OpenAI-compatible lifecycle states for a run.
RunStatus = Literal[
    "queued",
    "in_progress",
    "requires_action",
    "cancelling",
    "cancelled",
    "failed",
    "completed",
    "expired",
]
class RunResponse(BaseModel):
    """OpenAI-compatible representation of a run (backed by a chat message)."""

    id: str
    object: Literal["thread.run"]
    created_at: int
    assistant_id: int
    thread_id: UUID
    status: RunStatus
    # Lifecycle timestamps (Unix seconds); unset when not applicable.
    started_at: Optional[int] = None
    expires_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    failed_at: Optional[int] = None
    completed_at: Optional[int] = None
    last_error: Optional[dict] = None
    model: str
    instructions: str
    tools: list[dict]
    file_ids: list[str]
    metadata: Optional[dict] = None
def process_run_in_background(
    message_id: int,
    parent_message_id: int,
    chat_session_id: UUID,
    assistant_id: int,
    instructions: str,
    tools: list[dict],
    user: User | None,
    db_session: Session,
) -> None:
    """Drive LLM generation for a run and stream results into its message.

    `message_id` is the placeholder assistant message created by create_run;
    its content is updated as packets stream in. Cancellation is cooperative:
    an `error` value set on the message (e.g. by cancel_run) stops processing
    between packets.
    """
    # Fetch the session only to validate access; the result is unused.
    _ = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    # Pull retrieval settings from the first SearchTool config, if supplied.
    search_tool_retrieval_details = RetrievalDetails()
    for tool in tools:
        if tool["type"] == SearchTool.__name__ and (
            retrieval_details := tool.get("retrieval_details")
        ):
            search_tool_retrieval_details = RetrievalDetails.model_validate(
                retrieval_details
            )
            break
    new_msg_req = CreateChatMessageRequest(
        chat_session_id=chat_session_id,
        parent_message_id=int(parent_message_id) if parent_message_id else None,
        message=instructions,
        file_descriptors=[],
        search_doc_ids=None,
        retrieval_options=search_tool_retrieval_details,  # Adjust as needed
        rerank_settings=None,
        query_override=None,
        regenerate=None,
        llm_override=None,
        prompt_override=None,
        alternate_assistant_id=assistant_id,
        # Reuse the placeholder assistant message instead of creating new rows.
        use_existing_user_message=True,
        existing_assistant_message_id=message_id,
    )
    run_message = get_chat_message(message_id, user.id if user else None, db_session)
    try:
        for packet in stream_chat_message_objects(
            new_msg_req=new_msg_req,
            user=user,
            db_session=db_session,
        ):
            if isinstance(packet, ChatMessageDetail):
                # Re-fetch to pick up concurrent changes before writing.
                run_message = get_chat_message(
                    message_id, user.id if user else None, db_session
                )
                if run_message:
                    # this handles cancelling
                    if run_message.error:
                        return
                    run_message.message = packet.message
                    run_message.message_type = MessageType.ASSISTANT
                    db_session.commit()
    except Exception as e:
        logger.exception("Error processing run in background")
        # Surface the failure on the run message so retrieve_run reports it.
        run_message.error = str(e)
        db_session.commit()
        return
    db_session.refresh(run_message)
    # An empty generation is surfaced as an error state on the run.
    if run_message.token_count == 0:
        run_message.error = "No tokens generated"
        db_session.commit()
@router.post("/threads/{thread_id}/runs")
def create_run(
    thread_id: UUID,
    run_request: RunRequest,
    background_tasks: BackgroundTasks,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Start a run: create a placeholder assistant message and schedule
    background generation that fills it in.

    Returns immediately with status "queued"; poll retrieve_run for progress.
    Responds 404 when the thread is missing or inaccessible.
    """
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")
    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    # Chain onto the newest message; an empty session chains onto the root.
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )
    # Create a new "run" (chat message) in the session
    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message="",
        token_count=0,
        message_type=MessageType.ASSISTANT,
        db_session=db_session,
        commit=False,
    )
    # Flush so the new row gets its id before being linked as latest child.
    db_session.flush()
    latest_message.latest_child_message = new_message
    db_session.commit()
    # Schedule the background task
    background_tasks.add_task(
        process_run_in_background,
        new_message.id,
        latest_message.id,
        chat_session.id,
        run_request.assistant_id,
        run_request.instructions or "",
        run_request.tools or [],
        user,
        db_session,
    )
    return RunResponse(
        id=str(new_message.id),
        object="thread.run",
        created_at=int(new_message.time_sent.timestamp()),
        assistant_id=run_request.assistant_id,
        thread_id=chat_session.id,
        status="queued",
        model=run_request.model or "default_model",
        instructions=run_request.instructions or "",
        tools=run_request.tools or [],
        file_ids=[],
        metadata=run_request.metadata,
    )
@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Return the current state of a run (a chat message) in OpenAI shape."""
    # Retrieve the chat message (which represents a "run" in DAnswer)
    chat_message = get_chat_message(
        chat_message_id=int(run_id),  # Convert string run_id to int
        user_id=user.id if user else None,
        db_session=db_session,
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")
    chat_session = chat_message.chat_session
    # Map DAnswer message state onto OpenAI run statuses. Later checks take
    # precedence: error > token_count > message content > queued.
    run_status: RunStatus = "queued"
    if chat_message.message:
        run_status = "in_progress"
    if chat_message.token_count != 0:
        run_status = "completed"
    if chat_message.error:
        run_status = "cancelled"
    return RunResponse(
        id=run_id,
        object="thread.run",
        created_at=int(chat_message.time_sent.timestamp()),
        assistant_id=chat_session.persona_id or 0,
        thread_id=chat_session.id,
        status=run_status,
        started_at=int(chat_message.time_sent.timestamp()),
        completed_at=(
            int(chat_message.time_sent.timestamp()) if chat_message.message else None
        ),
        model=chat_session.current_alternate_model or "default_model",
        instructions="",  # DAnswer doesn't store per-message instructions
        tools=[],  # DAnswer doesn't have a direct equivalent for tools
        file_ids=(
            [file["id"] for file in chat_message.files] if chat_message.files else []
        ),
        metadata=None,  # DAnswer doesn't store metadata for individual messages
    )
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Cancel a run by marking its backing chat message as errored.

    The background processor checks `error` between streamed packets and
    stops; there is no hard interrupt. Returns the run in its updated state.
    """
    # Parse the id up front: the column is an integer PK, so compare with an
    # int (previously the raw string was compared), and map malformed ids to
    # 404 instead of an unhandled error.
    try:
        message_id = int(run_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found")

    # NOTE(review): unlike retrieve_run's get_chat_message call, this lookup
    # is not scoped to the requesting user — confirm whether cross-user
    # cancellation should be possible.
    chat_message = (
        db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_message.error = "Cancelled"
    db_session.commit()

    return retrieve_run(thread_id, run_id, user, db_session)
@router.get("/threads/{thread_id}/runs")
def list_runs(
    thread_id: UUID,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[RunResponse]:
    """List the runs of a thread with cursor pagination.

    Each message in the chat session is treated as a "run"; `after`/`before`
    are run (message) ids. Responds 400 for a malformed cursor.
    """
    chat_messages = get_chat_messages_by_session(
        chat_session_id=thread_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    # Parse cursors up front and compare ids numerically: the previous string
    # comparison was lexicographic, so e.g. "10" sorted before "9".
    try:
        after_id = int(after) if after else None
        before_id = int(before) if before else None
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid cursor")

    if after_id is not None:
        chat_messages = [msg for msg in chat_messages if msg.id > after_id]
    if before_id is not None:
        chat_messages = [msg for msg in chat_messages if msg.id < before_id]

    # Order by send time, then take one page.
    chat_messages = sorted(
        chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
    )
    chat_messages = chat_messages[:limit]

    return [
        retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
    ]
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
    thread_id: UUID,
    run_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[dict]:  # You may want to create a specific model for run steps
    # DAnswer has no equivalent to OpenAI run steps, so an empty list is
    # always returned purely for API compatibility; the pagination
    # parameters are accepted but ignored.
    return []
# Additional helper functions can be added here if needed

View File

@@ -1,157 +0,0 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.chat import create_chat_session
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import update_chat_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter(prefix="/threads")
# Models
class Thread(BaseModel):
    """OpenAI-compatible view of a chat session."""

    id: UUID
    object: str = "thread"
    # Unix timestamp (seconds) of session creation.
    created_at: int
    metadata: Optional[dict[str, str]] = None
class CreateThreadRequest(BaseModel):
    """Request body for creating a thread."""

    # NOTE(review): initial messages are accepted but never read by
    # create_thread — confirm whether they should seed the session.
    messages: Optional[list[dict]] = None
    metadata: Optional[dict[str, str]] = None
class ModifyThreadRequest(BaseModel):
    """Request body for updating a thread's metadata."""

    metadata: Optional[dict[str, str]] = None
# API Endpoints
@router.post("")
def create_thread(
    request: CreateThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    """Create a new thread (chat session) for the caller.

    NOTE: `request.messages` is currently ignored, and `request.metadata`
    is only echoed back in the response, not persisted.
    """
    user_id = user.id if user else None
    new_chat_session = create_chat_session(
        db_session=db_session,
        description="",  # Leave the naming till later to prevent delay
        user_id=user_id,
        # persona_id=0 — presumably the default persona; per-request
        # persona selection is not supported here. TODO confirm.
        persona_id=0,
    )
    return Thread(
        id=new_chat_session.id,
        created_at=int(new_chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )
@router.get("/{thread_id}")
def retrieve_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    """Fetch a single thread, or 404 if it is missing/inaccessible."""
    try:
        session_record = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")
    return Thread(
        id=session_record.id,
        created_at=int(session_record.time_created.timestamp()),
        metadata=None,  # metadata is not stored for chat sessions
    )
@router.post("/{thread_id}")
def modify_thread(
    thread_id: UUID,
    request: ModifyThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    """Modify a thread (OpenAI-compatible endpoint).

    NOTE: metadata is only echoed back, not persisted — the underlying
    session update changes no fields (it mainly validates existence/access).
    """
    user_id = user.id if user else None
    try:
        chat_session = update_chat_session(
            db_session=db_session,
            user_id=user_id,
            chat_session_id=thread_id,
            description=None,  # Not updating description
            sharing_status=None,  # Not updating sharing status
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")
    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )
@router.delete("/{thread_id}")
def delete_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict:
    """Delete a thread; returns an OpenAI-style deletion acknowledgement."""
    try:
        delete_chat_session(
            user_id=user.id if user else None,
            chat_session_id=thread_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")
    return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
@router.get("")
def list_threads(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    """List all non-deleted chat sessions (threads) for the caller.

    Unlike the other endpoints in this module, this returns the internal
    ChatSessionsResponse shape rather than an OpenAI-style list envelope.
    """
    user_id = user.id if user else None
    chat_sessions = get_chat_sessions_by_user(
        user_id=user_id,
        deleted=False,
        db_session=db_session,
    )
    return ChatSessionsResponse(
        sessions=[
            ChatSessionDetails(
                id=chat.id,
                name=chat.description,
                persona_id=chat.persona_id,
                time_created=chat.time_created.isoformat(),
                time_updated=chat.time_updated.isoformat(),
                shared_status=chat.shared_status,
                current_alternate_model=chat.current_alternate_model,
                current_temperature_override=chat.temperature_override,
            )
            for chat in chat_sessions
        ]
    )

View File

@@ -78,7 +78,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
@@ -636,13 +635,9 @@ class ChatSeedRequest(BaseModel):
# overrides / seeding
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
description: str | None = None
message: str | None = None
# TODO: support this
# initial_message_retrieval_options: RetrievalDetails | None = None
class ChatSeedResponse(BaseModel):
redirect_url: str
@@ -666,7 +661,6 @@ def seed_chat(
user_id=None, # this chat session is "unassigned" until a user visits the web UI
persona_id=chat_seed_request.persona_id,
llm_override=chat_seed_request.llm_override,
prompt_override=chat_seed_request.prompt_override,
)
except Exception as e:
logger.exception(e)

View File

@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
@@ -28,7 +26,6 @@ from onyx.db.enums import ChatSessionSharedStatus
from onyx.db.models import ChatSession
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import Packet
@@ -81,61 +78,30 @@ class ChatFeedbackRequest(BaseModel):
class CreateChatMessageRequest(ChunkContext):
"""Before creating messages, be sure to create a chat_session and get an id"""
# NOTE: Double check before adding fields to this class, it has historically gotten really
# bloated and hard to maintain.
# Identifying where the message is in the chat session history
chat_session_id: UUID
# This is the primary-key (unique identifier) for the previous message of the tree
parent_message_id: int | None
# New message contents
message: str
filters: BaseFilters | None = None
# Files that we should attach to this message
file_descriptors: list[FileDescriptor] = []
# Prompts are embedded in personas, so no separate prompt_id needed
# If search_doc_ids provided, it should use those docs explicitly
search_doc_ids: list[int] | None
retrieval_options: RetrievalDetails | None
# Useable via the APIs but not recommended for most flows
rerank_settings: RerankingDetails | None = None
# allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
# TODO: this is for the selecting documents functionality
search_doc_ids: list[int] | None = None
# enables additional handling to ensure that we regenerate with a given user message ID
regenerate: bool | None = None
# allows the caller to override the Persona / Prompt
# these do not persist in the chat thread details
# Let's the message be processed with some different LLM than the usual
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
# Allows the caller to override the temperature for the chat session
# this does persist in the chat thread details
temperature_override: float | None = None
# allow user to specify an alternate assistant
alternate_assistant_id: int | None = None
# This takes the priority over the prompt_override
# This won't be a type that's passed in directly from the API
persona_override_config: PersonaOverrideConfig | None = None
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# used for "OpenAI Assistants API"
existing_assistant_message_id: int | None = None
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If true, ignores most of the search options and uses pro search instead.
# TODO: decide how many of the above options we want to pass through to pro search
use_agentic_search: bool = False
skip_gen_ai_answer_generation: bool = False
# List of allowed tool IDs to restrict tool usage. If not provided, all tools available to the persona will be used.
allowed_tool_ids: list[int] | None = None
@@ -143,13 +109,13 @@ class CreateChatMessageRequest(ChunkContext):
# TODO: make this a single one since unclear how to force this for multiple at a time.
forced_tool_ids: list[int] | None = None
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
raise ValueError(
"Either search_doc_ids or retrieval_options must be provided, but not both or neither."
)
return self
# NOTE: the fields below are less used and typically should not be set in normal flows.
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
data = super().model_dump(*args, **kwargs)
@@ -332,33 +298,9 @@ class DocumentSearchRequest(ChunkContext):
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits
# Easier APIs to work with for developers
persona_override_config: PersonaOverrideConfig | None = None
persona_id: int | None = None
persona_id: int
messages: list[ThreadMessage]
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:
raise ValueError("Exactly one of persona_config or persona_id must be set")
elif self.persona_override_config is not None and (self.persona_id is not None):
raise ValueError(
"If persona_override_config is set, persona_id cannot be set"
)
return self
filters: BaseFilters | None = None
class OneShotQAResponse(BaseModel):

View File

@@ -13,7 +13,6 @@ from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import AnswerStream
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -41,7 +40,6 @@ from onyx.db.chat import get_valid_messages_from_query_sessions
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.search_settings import get_current_search_settings
@@ -212,22 +210,12 @@ def get_answer_stream(
query = query_request.messages[0].message
logger.notice(f"Received query for Answer API: {query}")
if (
query_request.persona_override_config is None
and query_request.persona_id is None
):
raise KeyError("Must provide persona ID or Persona Config")
persona_info: Persona | PersonaOverrideConfig | None = None
if query_request.persona_override_config is not None:
persona_info = query_request.persona_override_config
elif query_request.persona_id is not None:
persona_info = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
persona_info = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona_info, user=user))
@@ -249,15 +237,11 @@ def get_answer_stream(
# Also creates a new chat session
request = prepare_chat_message_request(
message_text=combined_message,
filters=query_request.filters,
user=user,
persona_id=query_request.persona_id,
persona_override_config=query_request.persona_override_config,
message_ts_to_respond_to=None,
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -166,3 +167,56 @@ class ToolCallInfo(BaseModel):
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
class BaseCiteableToolResult(BaseModel):
"""Base class for tool results that can be cited."""
document_citation_number: int
unique_identifier_to_strip_away: str | None = None
type: str
class LlmInternalSearchResult(BaseCiteableToolResult):
"""Result from an internal search query"""
type: Literal["internal_search"] = "internal_search"
title: str
excerpt: str
metadata: dict[str, Any]
class LlmWebSearchResult(BaseCiteableToolResult):
"""Result from a web search query"""
type: Literal["web_search"] = "web_search"
url: str
title: str
snippet: str
class LlmOpenUrlResult(BaseCiteableToolResult):
"""Result from opening/fetching a URL"""
type: Literal["open_url"] = "open_url"
content: str
class PythonExecutionFile(BaseModel):
"""File generated during Python execution"""
filename: str
file_link: str
class LlmPythonExecutionResult(BaseModel):
"""Result from Python code execution"""
type: Literal["python_execution"] = "python_execution"
stdout: str
stderr: str
exit_code: int | None
timed_out: bool
generated_files: list[PythonExecutionFile]
error: str | None = None

View File

@@ -17,17 +17,17 @@ from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.tools.models import (
LlmPythonExecutionResult,
)
from onyx.tools.models import PythonExecutionFile
from onyx.tools.models import PythonToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations_v2.code_interpreter_client import (
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations_v2.code_interpreter_client import FileInput
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmPythonExecutionResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import PythonExecutionFile
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.utils.logger import setup_logger

View File

@@ -1,59 +0,0 @@
"""Base models for tool results with citation support."""
from typing import Any
from typing import Literal
from pydantic import BaseModel
class BaseCiteableToolResult(BaseModel):
"""Base class for tool results that can be cited."""
document_citation_number: int
unique_identifier_to_strip_away: str | None = None
type: str
class LlmInternalSearchResult(BaseCiteableToolResult):
"""Result from an internal search query"""
type: Literal["internal_search"] = "internal_search"
title: str
excerpt: str
metadata: dict[str, Any]
class LlmWebSearchResult(BaseCiteableToolResult):
"""Result from a web search query"""
type: Literal["web_search"] = "web_search"
url: str
title: str
snippet: str
class LlmOpenUrlResult(BaseCiteableToolResult):
"""Result from opening/fetching a URL"""
type: Literal["open_url"] = "open_url"
content: str
class PythonExecutionFile(BaseModel):
"""File generated during Python execution"""
filename: str
file_link: str
class LlmPythonExecutionResult(BaseModel):
"""Result from Python code execution"""
type: Literal["python_execution"] = "python_execution"
stdout: str
stderr: str
exit_code: int | None
timed_out: bool
generated_files: list[PythonExecutionFile]
error: str | None = None

View File

@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
@@ -73,9 +72,6 @@ def test_answer_with_only_anthropic_provider(
chat_session_id=chat_session.id,
parent_message_id=None,
message="hello",
file_descriptors=[],
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
)
response_stream: list[AnswerStreamPart] = []

View File

@@ -7,7 +7,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -45,10 +44,7 @@ def test_stream_chat_current_date_response(
parent_message_id=None,
message="Please respond only with the current date in the format 'Weekday Month DD, YYYY'.",
file_descriptors=[],
prompt_override=None,
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
query_override=None,
filters=None,
)
gen = stream_chat_message_objects(

View File

@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.models import RecencyBiasSetting
from onyx.db.models import User
@@ -104,11 +103,6 @@ def test_stream_chat_message_objects_without_web_search(
chat_session_id=chat_session.id,
parent_message_id=None,
message="run a web search for 'Onyx'",
file_descriptors=[],
prompt_override=None,
search_doc_ids=None,
retrieval_options=RetrievalDetails(),
query_override=None,
)
# Call stream_chat_message_objects
response_generator = stream_chat_message_objects(

View File

@@ -8,12 +8,10 @@ from uuid import UUID
import requests
from requests.models import Response
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import StreamingType
@@ -92,14 +90,8 @@ class ChatSessionManager:
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
search_doc_ids: list[int] | None = None,
retrieval_options: RetrievalDetails | None = None,
query_override: str | None = None,
regenerate: bool | None = None,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
use_agentic_search: bool = False,
forced_tool_ids: list[int] | None = None,
chat_session: DATestChatSession | None = None,
) -> StreamedResponse:
@@ -109,15 +101,8 @@ class ChatSessionManager:
message=message,
file_descriptors=file_descriptors or [],
search_doc_ids=search_doc_ids or [],
retrieval_options=retrieval_options,
rerank_settings=None, # Can be added if needed
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
use_agentic_search=use_agentic_search,
forced_tool_ids=forced_tool_ids,
)

View File

@@ -1,23 +0,0 @@
from typing import Optional
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
# Create a thread to use in the tests
response = requests.post(
f"{BASE_URL}/threads", # Updated endpoint path
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return UUID(response.json()["id"])

View File

@@ -1,151 +0,0 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
def test_create_assistant(admin_user: DATestUser | None) -> None:
response = requests.post(
ASSISTANTS_URL,
json={
"model": "gpt-3.5-turbo",
"name": "Test Assistant",
"description": "A test assistant",
"instructions": "You are a helpful assistant.",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Test Assistant"
assert data["description"] == "A test assistant"
assert data["model"] == "gpt-3.5-turbo"
assert data["instructions"] == "You are a helpful assistant."
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, retrieve the assistant
response = requests.get(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Retrieve Test"
def test_modify_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, modify the assistant
response = requests.post(
f"{ASSISTANTS_URL}/{assistant_id}",
json={"name": "Modified Assistant", "instructions": "New instructions"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Modified Assistant"
assert data["instructions"] == "New instructions"
def test_delete_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, delete the assistant
response = requests.delete(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["deleted"] is True
def test_list_assistants(admin_user: DATestUser | None) -> None:
# Create multiple assistants
for i in range(3):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the assistants
response = requests.get(
ASSISTANTS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) >= 3 # At least the 3 we just created
assert all(assistant["object"] == "assistant" for assistant in data["data"])
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
# Create 5 assistants
for i in range(5):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# List assistants with limit
response = requests.get(
f"{ASSISTANTS_URL}?limit=2",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
assert data["has_more"] is True
# Get next page
before = data["last_id"]
response = requests.get(
f"{ASSISTANTS_URL}?limit=2&before={before}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
non_existent_id = -99
response = requests.get(
f"{ASSISTANTS_URL}/{non_existent_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -1,133 +0,0 @@
import uuid
from typing import Optional
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> str:
response = requests.post(
BASE_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
response = requests.post(
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
json={
"role": "user",
"content": "Hello, world!",
"file_ids": [],
"metadata": {"key": "value"},
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
assert response_json["metadata"] == {"key": "value"}
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the messages
response = requests.get(
f"{BASE_URL}/{thread_id}/messages",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["object"] == "list"
assert isinstance(response_json["data"], list)
assert len(response_json["data"]) > 0
assert "first_id" in response_json
assert "last_id" in response_json
assert "has_more" in response_json
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, retrieve the message
response = requests.get(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, modify the message
response = requests.post(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
non_existent_thread_id = str(uuid.uuid4())
non_existent_message_id = -99
# Test with non-existent thread
response = requests.post(
f"{BASE_URL}/{non_existent_thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404
# Test with non-existent message
response = requests.get(
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -1,137 +0,0 @@
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
"""Create a run and return its ID."""
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_run(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
"model": "gpt-3.5-turbo",
"instructions": "Test instructions",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert response_json["assistant_id"] == 0
assert UUID(response_json["thread_id"]) == thread_id
assert response_json["status"] == "queued"
assert response_json["model"] == "gpt-3.5-turbo"
assert response_json["instructions"] == "Test instructions"
def test_retrieve_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
retrieve_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == run_id
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert UUID(response_json["thread_id"]) == thread_id
def test_cancel_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
cancel_response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert cancel_response.status_code == 200
response_json = cancel_response.json()
assert response_json["id"] == run_id
assert response_json["status"] == "cancelled"
def test_list_runs(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
# Create a few runs
for _ in range(3):
requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the runs
list_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert isinstance(response_json, list)
assert len(response_json) >= 3
for run in response_json:
assert "id" in run
assert run["object"] == "thread.run"
assert "created_at" in run
assert UUID(run["thread_id"]) == thread_id
assert "status" in run
assert "model" in run
def test_list_run_steps(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
steps_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert steps_response.status_code == 200
response_json = steps_response.json()
assert isinstance(response_json, list)
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
assert len(response_json) == 0

View File

@@ -1,131 +0,0 @@
from uuid import UUID
import requests
from onyx.db.models import ChatSessionSharedStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
def test_create_thread(admin_user: DATestUser | None) -> None:
response = requests.post(
THREADS_URL,
json={"messages": None, "metadata": {"key": "value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread"
assert "created_at" in response_json
assert response_json["metadata"] == {"key": "value"}
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, retrieve the thread
retrieve_response = requests.get(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread"
assert "created_at" in response_json
def test_modify_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, modify the thread
modify_response = requests.post(
f"{THREADS_URL}/{thread_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert modify_response.status_code == 200
response_json = modify_response.json()
assert response_json["id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_delete_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, delete the thread
delete_response = requests.delete(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert delete_response.status_code == 200
response_json = delete_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread.deleted"
assert response_json["deleted"] is True
def test_list_threads(admin_user: DATestUser | None) -> None:
# Create a few threads
for _ in range(3):
requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the threads
list_response = requests.get(
THREADS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert "sessions" in response_json
assert len(response_json["sessions"]) >= 3
for session in response_json["sessions"]:
assert "id" in session
assert "name" in session
assert "persona_id" in session
assert "time_created" in session
assert "shared_status" in session
assert "current_alternate_model" in session
# Validate UUID
UUID(session["id"])
# Validate shared_status
assert session["shared_status"] in [
status.value for status in ChatSessionSharedStatus
]

View File

@@ -155,7 +155,6 @@ def test_soft_delete_with_agentic_search(
chat_session_id=test_chat_session.id,
message="What are the key principles of software engineering?",
user_performing_action=basic_user,
use_agentic_search=True,
)
# Verify that the message was processed successfully
@@ -206,7 +205,6 @@ def test_hard_delete_with_agentic_search(
chat_session_id=test_chat_session.id,
message="What are the key principles of software engineering?",
user_performing_action=basic_user,
use_agentic_search=True,
)
# Verify that the message was processed successfully

View File

@@ -5,9 +5,7 @@ from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.connectors.models import InputType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.enums import IndexingStatus
from onyx.server.documents.models import ConnectorBase
from onyx.server.query_and_chat.models import OneShotQARequest
@@ -24,9 +22,7 @@ def _api_url_builder(env_name: str, api_path: str) -> str:
@retry(tries=5, delay=5)
def get_answer_from_query(
query: str, only_retrieve_docs: bool, env_name: str
) -> tuple[list[str], str]:
def get_answer_from_query(query: str, env_name: str) -> tuple[list[str], str]:
filters = IndexFilters(
source_type=None,
document_set=None,
@@ -40,13 +36,7 @@ def get_answer_from_query(
new_message_request = OneShotQARequest(
messages=messages,
persona_id=0,
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
),
skip_gen_ai_answer_generation=only_retrieve_docs,
filters=filters,
)
url = _api_url_builder(env_name, "/query/answer-with-citation/")

View File

@@ -116,7 +116,6 @@ def _process_question(question_data: dict, config: dict, question_number: int) -
query = question_data["question"]
context_data_list, answer = get_answer_from_query(
query=query,
only_retrieve_docs=config["only_retrieve_docs"],
env_name=config["env_name"],
)
print(f"On question number {question_number}")

View File

@@ -45,9 +45,7 @@ from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.configs.constants import MessageType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -432,14 +430,7 @@ class SearchAnswerAnalyzer:
qa_request = OneShotQARequest(
messages=messages,
persona_id=0, # default persona
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
limit=self.config.max_search_results,
),
skip_gen_ai_answer_generation=self.config.search_only,
filters=filters,
)
# send the request

View File

@@ -130,16 +130,9 @@ async function* sendMessage({
chat_session_id: chatSessionId,
parent_message_id: parentMessageId || null,
message: message,
prompt_id: null,
search_doc_ids: null,
file_descriptors: [],
// checkout https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/search/models.py#L105 for
// all available options
retrieval_options: {
run_search: "always",
filters: null,
},
query_override: null,
filters: null,
}),
});
if (!sendMessageResponse.ok) {

View File

@@ -174,8 +174,6 @@ export default function ChatPage({
if (message) {
onSubmit({
message,
currentMessageFiles,
useAgentSearch: deepResearchEnabled,
});
}
}
@@ -603,8 +601,6 @@ export default function ChatPage({
// We call onSubmit, passing a `messageOverride`
onSubmit({
message: lastUserMsg.message,
currentMessageFiles: currentMessageFiles,
useAgentSearch: deepResearchEnabled,
messageIdToResend: lastUserMsg.messageId,
});
}
@@ -632,11 +628,9 @@ export default function ChatPage({
const handleChatInputSubmit = useCallback(() => {
onSubmit({
message: message,
currentMessageFiles: currentMessageFiles,
useAgentSearch: deepResearchEnabled,
});
setShowOnboarding(false);
}, [message, onSubmit, currentMessageFiles, deepResearchEnabled]);
}, [message, onSubmit]);
// Memoized callbacks for DocumentResults
const handleMobileDocumentSidebarClose = useCallback(() => {

View File

@@ -22,16 +22,11 @@ interface MessagesDisplayProps {
onSubmit: (args: {
message: string;
messageIdToResend?: number;
currentMessageFiles: ProjectFile[];
useAgentSearch: boolean;
modelOverride?: LlmDescriptor;
regenerationRequest?: {
messageId: number;
parentMessage: Message;
forceSearch?: boolean;
};
forceSearch?: boolean;
queryOverride?: string;
isSeededChat?: boolean;
overrideFileDescriptors?: FileDescriptor[];
}) => Promise<void>;
@@ -77,24 +72,17 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
const emptyDocs = useMemo<OnyxDocument[]>(() => [], []);
const emptyChildrenIds = useMemo<number[]>(() => [], []);
const createRegenerator = useCallback(
(regenerationRequest: {
messageId: number;
parentMessage: Message;
forceSearch?: boolean;
}) => {
(regenerationRequest: { messageId: number; parentMessage: Message }) => {
return async function (modelOverride: LlmDescriptor) {
return await onSubmit({
message: regenerationRequest.parentMessage.message,
currentMessageFiles,
useAgentSearch: deepResearchEnabled,
modelOverride,
messageIdToResend: regenerationRequest.parentMessage.messageId,
regenerationRequest,
forceSearch: regenerationRequest.forceSearch,
});
};
},
[onSubmit, deepResearchEnabled, currentMessageFiles]
[onSubmit]
);
const handleEditWithMessageId = useCallback(
@@ -102,11 +90,9 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
onSubmit({
message: editedContent,
messageIdToResend: msgId,
currentMessageFiles: [],
useAgentSearch: deepResearchEnabled,
});
},
[onSubmit, deepResearchEnabled]
[onSubmit]
);
// require assistant to be present before rendering

View File

@@ -81,16 +81,8 @@ const SYSTEM_MESSAGE_ID = -3;
export interface OnSubmitProps {
message: string;
//from chat input bar
currentMessageFiles: ProjectFile[];
// from the chat bar???
useAgentSearch: boolean;
// optional params
messageIdToResend?: number;
queryOverride?: string;
forceSearch?: boolean;
isSeededChat?: boolean;
modelOverride?: LlmDescriptor;
regenerationRequest?: RegenerationRequest | null;
@@ -100,7 +92,6 @@ export interface OnSubmitProps {
interface RegenerationRequest {
messageId: number;
parentMessage: Message;
forceSearch?: boolean;
}
interface UseChatControllerProps {
@@ -142,8 +133,13 @@ export function useChatController({
const { refreshChatSessions } = useChatSessions();
const { assistantPreferences } = useAssistantPreferences();
const { forcedToolIds } = useForcedTools();
const { fetchProjects, uploadFiles, setCurrentMessageFiles, beginUpload } =
useProjectsContext();
const {
fetchProjects,
uploadFiles,
setCurrentMessageFiles,
beginUpload,
currentMessageFiles,
} = useProjectsContext();
const posthog = usePostHog();
// Use selectors to access only the specific fields we need
@@ -387,11 +383,7 @@ export function useChatController({
const onSubmit = useCallback(
async ({
message,
currentMessageFiles,
useAgentSearch,
messageIdToResend,
queryOverride,
forceSearch,
isSeededChat,
modelOverride,
regenerationRequest,
@@ -663,7 +655,6 @@ export function useChatController({
updateCurrentMessageFIFO(stack, {
signal: controller.signal,
message: currMessage,
alternateAssistantId: liveAssistant?.id,
fileDescriptors: overrideFileDescriptors,
parentMessageId: (() => {
const parentId =
@@ -687,15 +678,6 @@ export function useChatController({
document.db_doc_id !== undefined && document.db_doc_id !== null
)
.map((document) => document.db_doc_id as number),
queryOverride,
forceSearch,
currentMessageFiles: currentMessageFiles.map((file) => ({
id: file.file_id,
type: file.chat_file_type,
name: file.name,
user_file_id: file.id,
})),
regenerate: regenerationRequest !== undefined,
modelProvider:
modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion:
@@ -704,10 +686,7 @@ export function useChatController({
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
temperature: llmManager.temperature || undefined,
systemPromptOverride:
searchParams?.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
useAgentSearch,
enabledToolIds:
disabledToolIds && liveAssistant
? liveAssistant.tools

View File

@@ -57,8 +57,6 @@ interface UseChatSessionControllerProps {
refreshChatSessions: () => void;
onSubmit: (params: {
message: string;
currentMessageFiles: ProjectFile[];
useAgentSearch: boolean;
isSeededChat?: boolean;
}) => Promise<void>;
}
@@ -170,8 +168,6 @@ export function useChatSessionController({
submitOnLoadPerformed.current = true;
await onSubmit({
message: firstMessage || "",
currentMessageFiles: [],
useAgentSearch: false,
});
}
return;
@@ -293,8 +289,6 @@ export function useChatSessionController({
await onSubmit({
message: seededMessage,
isSeededChat: true,
currentMessageFiles: [],
useAgentSearch: false,
});
// Force re-name if the chat session doesn't have one
if (!chatSession.description) {

View File

@@ -160,78 +160,45 @@ export type PacketType =
| Packet;
export interface SendMessageParams {
regenerate: boolean;
message: string;
fileDescriptors?: FileDescriptor[];
parentMessageId: number | null;
chatSessionId: string;
filters: Filters | null;
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
modelProvider?: string;
modelVersion?: string;
temperature?: number;
systemPromptOverride?: string;
useExistingUserMessage?: boolean;
alternateAssistantId?: number;
signal?: AbortSignal;
currentMessageFiles?: FileDescriptor[];
useAgentSearch?: boolean;
enabledToolIds?: number[];
forcedToolIds?: number[];
}
export async function* sendMessage({
regenerate,
message,
fileDescriptors,
currentMessageFiles,
parentMessageId,
chatSessionId,
filters,
selectedDocumentIds,
queryOverride,
forceSearch,
modelProvider,
modelVersion,
temperature,
systemPromptOverride,
useExistingUserMessage,
alternateAssistantId,
signal,
useAgentSearch,
enabledToolIds,
forcedToolIds,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
const payload = {
alternate_assistant_id: alternateAssistantId,
chat_session_id: chatSessionId,
parent_message_id: parentMessageId,
message: message,
// just use the default prompt for the assistant.
// should remove this in the future, as we don't support multiple prompts for a
// single assistant anyways
prompt_id: null,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_descriptors: fileDescriptors,
current_message_files: currentMessageFiles,
regenerate,
retrieval_options: !documentsAreSelected
? {
run_search: queryOverride || forceSearch ? "always" : "auto",
real_time: true,
filters: filters,
}
: null,
query_override: queryOverride,
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
: null,
filters: !documentsAreSelected ? filters : null,
llm_override:
temperature || modelVersion
? {
@@ -241,7 +208,6 @@ export async function* sendMessage({
}
: null,
use_existing_user_message: useExistingUserMessage,
use_agentic_search: useAgentSearch ?? false,
allowed_tool_ids: enabledToolIds,
forced_tool_ids: forcedToolIds,
};

View File

@@ -22,8 +22,6 @@ export function Suggestions({ onSubmit }: SuggestionsProps) {
const handleSuggestionClick = (suggestion: string) => {
onSubmit({
message: suggestion,
currentMessageFiles: [],
useAgentSearch: false,
});
};