Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-01 13:45:44 +00:00)

Compare commits: v2.11.0-cl...interfaces (2 commits: 3921d105dc, 008e0021f1)
@@ -15,8 +15,6 @@ from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
@@ -77,40 +75,17 @@ def handle_simplified_chat_message(
        chat_session_id=chat_session_id, db_session=db_session
    )

    if (
        chat_message_req.retrieval_options is None
        and chat_message_req.search_doc_ids is None
    ):
        retrieval_options: RetrievalDetails | None = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
        )
    else:
        retrieval_options = chat_message_req.retrieval_options

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_session_id,
        parent_message_id=parent_message.id,
        message=chat_message_req.message,
        file_descriptors=[],
        search_doc_ids=chat_message_req.search_doc_ids,
        retrieval_options=retrieval_options,
        # Simple API does not support reranking, hide complexity from user
        rerank_settings=None,
        query_override=chat_message_req.query_override,
        # Currently only applies to search flow not chat
        chunks_above=0,
        chunks_below=0,
        full_doc=chat_message_req.full_doc,
        structured_response_format=chat_message_req.structured_response_format,
        use_agentic_search=chat_message_req.use_agentic_search,
    )

    packets = stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
        enforce_chat_session_id_for_search_docs=False,
    )

    return gather_stream(packets)
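The defaulting rule in this hunk is worth calling out: when the caller supplies neither `retrieval_options` nor pinned `search_doc_ids`, the simple API opts into a full, non-realtime search. A minimal sketch of that rule using the imports shown above (the helper name `_default_retrieval_options` is ours, not from the diff):

```python
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails


def _default_retrieval_options(
    retrieval_options: RetrievalDetails | None,
    search_doc_ids: list[int] | None,
) -> RetrievalDetails | None:
    # Neither explicit options nor pinned docs: always run a non-realtime search.
    if retrieval_options is None and search_doc_ids is None:
        return RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
        )
    # Otherwise honor whatever the caller passed (possibly None with pinned docs).
    return retrieval_options
```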
@@ -123,8 +98,7 @@ def handle_send_message_simple_with_history(
    db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
    """This is a Non-Streaming version that only gives back a minimal set of information.
    takes in chat history maintained by the caller
    and does query rephrasing similar to answer-with-quote"""
    takes in chat history maintained by the caller"""

    if len(req.messages) == 0:
        raise HTTPException(status_code=400, detail="Messages cannot be zero length")
@@ -194,41 +168,22 @@ def handle_send_message_simple_with_history(
        llm_tokenizer=llm_tokenizer,
    )

    rephrased_query = req.query_override or thread_based_query_rephrase(
    rephrased_query = thread_based_query_rephrase(
        user_query=query,
        history_str=history_str,
    )

    if req.retrieval_options is None and req.search_doc_ids is None:
        retrieval_options: RetrievalDetails | None = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
        )
    else:
        retrieval_options = req.retrieval_options

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_session.id,
        parent_message_id=chat_message.id,
        message=query,
        file_descriptors=[],
        search_doc_ids=req.search_doc_ids,
        retrieval_options=retrieval_options,
        # Simple API does not support reranking, hide complexity from user
        rerank_settings=None,
        query_override=rephrased_query,
        chunks_above=0,
        chunks_below=0,
        full_doc=req.full_doc,
        message=rephrased_query,
        structured_response_format=req.structured_response_format,
        use_agentic_search=req.use_agentic_search,
    )

    packets = stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
        enforce_chat_session_id_for_search_docs=False,
    )

    return gather_stream(packets)
@@ -12,7 +12,6 @@ from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer

@@ -43,20 +42,10 @@ class BasicCreateChatMessageRequest(ChunkContext):
    persona_id: int | None = None
    # New message contents
    message: str
    # Defaults to using retrieval with no additional filters
    retrieval_options: RetrievalDetails | None = None
    # Allows the caller to specify the exact search query they want to use
    # will disable Query Rewording if specified
    query_override: str | None = None
    # If search_doc_ids provided, then retrieval options are unused
    search_doc_ids: list[int] | None = None
    # only works if using an OpenAI model. See the following for more details:
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False

    @model_validator(mode="after")
    def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
        if self.chat_session_id is None and self.persona_id is None:
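To make the request shape concrete, here is a hypothetical caller-side payload for the model above (all values are illustrative, not from the diff):

```python
# Illustrative payload for BasicCreateChatMessageRequest; the comments
# restate the defaulting behavior documented on the model.
payload = {
    "persona_id": 0,                    # either this or chat_session_id must be set
    "message": "Summarize the latest billing changes",
    "retrieval_options": None,          # None + no search_doc_ids => search runs ALWAYS
    "query_override": None,             # setting this disables query rewording
    "search_doc_ids": None,             # pinning ids would make retrieval_options unused
    "structured_response_format": None, # OpenAI structured-outputs schema, if any
    "use_agentic_search": False,
}
```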
@@ -68,16 +57,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # Last element is the new query. All previous elements are historical context
    messages: list[ThreadMessage]
    persona_id: int
    retrieval_options: RetrievalDetails | None = None
    query_override: str | None = None
    skip_rerank: bool | None = None
    # If search_doc_ids provided, then retrieval options are unused
    search_doc_ids: list[int] | None = None
    # only works if using an OpenAI model. See the following for more details:
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None
    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False


class SimpleDoc(BaseModel):
@@ -4,11 +4,9 @@ from collections.abc import Callable
from typing import cast
from uuid import UUID

from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session

from onyx.auth.users import is_user_admin
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
    try_creating_kg_processing_task,
)
@@ -17,24 +15,19 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
)
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import BaseFilters
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.search_settings import get_current_search_settings
@@ -51,9 +44,6 @@ from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool_implementations.custom.custom_tool import (
    build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
@@ -64,15 +54,10 @@ logger = setup_logger()
def prepare_chat_message_request(
    message_text: str,
    user: User | None,
    filters: BaseFilters | None,
    persona_id: int | None,
    # Does the question need to have a persona override
    persona_override_config: PersonaOverrideConfig | None,
    message_ts_to_respond_to: str | None,
    retrieval_details: RetrievalDetails | None,
    rerank_settings: RerankingDetails | None,
    db_session: Session,
    use_agentic_search: bool = False,
    skip_gen_ai_answer_generation: bool = False,
    llm_override: LLMOverride | None = None,
    allowed_tool_ids: list[int] | None = None,
) -> CreateChatMessageRequest:
@@ -91,15 +76,7 @@ def prepare_chat_message_request(
        chat_session_id=new_chat_session.id,
        parent_message_id=None,  # It's a standalone chat session each time
        message=message_text,
        file_descriptors=[],  # Currently SlackBot/answer api do not support files in the context
        # Can always override the persona for the single query, if it's a normal persona
        # then it will be treated the same
        persona_override_config=persona_override_config,
        search_doc_ids=None,
        retrieval_options=retrieval_details,
        rerank_settings=rerank_settings,
        use_agentic_search=use_agentic_search,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
        filters=filters,
        llm_override=llm_override,
        allowed_tool_ids=allowed_tool_ids,
    )
@@ -355,68 +332,69 @@ def extract_headers(
    return extracted_headers


def create_temporary_persona(
    persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
    if not is_user_admin(user):
        raise HTTPException(
            status_code=403,
            detail="User is not authorized to create a persona in one shot queries",
        )
# TODO in case it needs to be referenced later
# def create_temporary_persona(
#     persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
# ) -> Persona:
#     if not is_user_admin(user):
#         raise HTTPException(
#             status_code=403,
#             detail="User is not authorized to create a persona in one shot queries",
#         )

    """Create a temporary Persona object from the provided configuration."""
    persona = Persona(
        name=persona_config.name,
        description=persona_config.description,
        num_chunks=persona_config.num_chunks,
        llm_relevance_filter=persona_config.llm_relevance_filter,
        llm_filter_extraction=persona_config.llm_filter_extraction,
        recency_bias=persona_config.recency_bias,
        llm_model_provider_override=persona_config.llm_model_provider_override,
        llm_model_version_override=persona_config.llm_model_version_override,
    )
# """Create a temporary Persona object from the provided configuration."""
# persona = Persona(
#     name=persona_config.name,
#     description=persona_config.description,
#     num_chunks=persona_config.num_chunks,
#     llm_relevance_filter=persona_config.llm_relevance_filter,
#     llm_filter_extraction=persona_config.llm_filter_extraction,
#     recency_bias=persona_config.recency_bias,
#     llm_model_provider_override=persona_config.llm_model_provider_override,
#     llm_model_version_override=persona_config.llm_model_version_override,
# )

    if persona_config.prompts:
        # Use the first prompt from the override config for embedded prompt fields
        first_prompt = persona_config.prompts[0]
        persona.system_prompt = first_prompt.system_prompt
        persona.task_prompt = first_prompt.task_prompt
        persona.datetime_aware = first_prompt.datetime_aware
# if persona_config.prompts:
#     # Use the first prompt from the override config for embedded prompt fields
#     first_prompt = persona_config.prompts[0]
#     persona.system_prompt = first_prompt.system_prompt
#     persona.task_prompt = first_prompt.task_prompt
#     persona.datetime_aware = first_prompt.datetime_aware

    persona.tools = []
    if persona_config.custom_tools_openapi:
        from onyx.chat.emitter import get_default_emitter
# persona.tools = []
# if persona_config.custom_tools_openapi:
#     from onyx.chat.emitter import get_default_emitter

        for schema in persona_config.custom_tools_openapi:
            tools = cast(
                list[Tool],
                build_custom_tools_from_openapi_schema_and_headers(
                    tool_id=0,  # dummy tool id
                    openapi_schema=schema,
                    emitter=get_default_emitter(),
                ),
            )
            persona.tools.extend(tools)
#     for schema in persona_config.custom_tools_openapi:
#         tools = cast(
#             list[Tool],
#             build_custom_tools_from_openapi_schema_and_headers(
#                 tool_id=0,  # dummy tool id
#                 openapi_schema=schema,
#                 emitter=get_default_emitter(),
#             ),
#         )
#         persona.tools.extend(tools)

    if persona_config.tools:
        tool_ids = [tool.id for tool in persona_config.tools]
        persona.tools.extend(
            fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
        )
# if persona_config.tools:
#     tool_ids = [tool.id for tool in persona_config.tools]
#     persona.tools.extend(
#         fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
#     )

    if persona_config.tool_ids:
        persona.tools.extend(
            fetch_existing_tools(
                db_session=db_session, tool_ids=persona_config.tool_ids
            )
        )
# if persona_config.tool_ids:
#     persona.tools.extend(
#         fetch_existing_tools(
#             db_session=db_session, tool_ids=persona_config.tool_ids
#         )
#     )

    fetched_docs = fetch_existing_doc_sets(
        db_session=db_session, doc_ids=persona_config.document_set_ids
    )
    persona.document_sets = fetched_docs
# fetched_docs = fetch_existing_doc_sets(
#     db_session=db_session, doc_ids=persona_config.document_set_ids
# )
# persona.document_sets = fetched_docs

    return persona
# return persona


def process_kg_commands(
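For context on what the now-commented function did, here is a hypothetical call site as it might have looked before removal (the variable names and field values are ours; `PersonaOverrideConfig` fields match the model shown later in this diff):

```python
# Hypothetical usage of create_temporary_persona prior to its commenting-out.
override = PersonaOverrideConfig(
    name="One-off analyst",
    description="Temporary persona for a single query",
    tool_ids=[],          # resolved via fetch_existing_tools when non-empty
    document_set_ids=[],  # resolved via fetch_existing_doc_sets
)
persona = create_temporary_persona(
    persona_config=override,
    db_session=db_session,  # an active SQLAlchemy session (assumed in scope)
    user=admin_user,        # must pass is_user_admin or an HTTP 403 is raised
)
```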
@@ -5,12 +5,10 @@ from enum import Enum
from typing import Any

from pydantic import BaseModel
from pydantic import Field

from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
@@ -127,35 +125,6 @@ class ToolConfig(BaseModel):
    id: int


class PromptOverrideConfig(BaseModel):
    name: str
    description: str = ""
    system_prompt: str
    task_prompt: str = ""
    datetime_aware: bool = True
    include_citations: bool = True


class PersonaOverrideConfig(BaseModel):
    name: str
    description: str
    search_type: SearchType = SearchType.SEMANTIC
    num_chunks: float | None = None
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None

    prompts: list[PromptOverrideConfig] = Field(default_factory=list)
    # Note: prompt_ids removed - prompts are now embedded in personas

    document_set_ids: list[int] = Field(default_factory=list)
    tools: list[ToolConfig] = Field(default_factory=list)
    tool_ids: list[int] = Field(default_factory=list)
    custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)


AnswerQuestionPossibleReturn = (
    OnyxAnswerPiece
    | CitationInfo
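Since prompts are now embedded in the override config rather than referenced by id, an override carries its prompt inline. A small illustrative construction (values invented):

```python
# Illustrative only: a PersonaOverrideConfig with one embedded prompt.
override = PersonaOverrideConfig(
    name="Support Bot",
    description="Answers support questions",
    prompts=[
        PromptOverrideConfig(
            name="Default",
            system_prompt="You are a helpful support assistant.",
            # description, task_prompt, datetime_aware, include_citations
            # all fall back to the defaults declared on the model above
        )
    ],
)
```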
@@ -26,8 +26,6 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
from onyx.chat.save_chat import save_chat_turn
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.context.search.models import CitationDocInfo
@@ -260,17 +258,10 @@ def stream_chat_message_objects(
    new_msg_req: CreateChatMessageRequest,
    user: User | None,
    db_session: Session,
    # Needed to translate persona num_chunks to tokens to the LLM
    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
    # For flow with search, don't include as many chunks as possible since we need to leave space
    # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
    # if specified, uses the last user message and does not create a new user message based
    # on the `new_msg_req.message`. Currently, requires a state where the last message is a
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    is_connected: Callable[[], bool] | None = None,
    enforce_chat_session_id_for_search_docs: bool = True,
    bypass_acl: bool = False,
    # Additional context that should be included in the chat history, for example:
    # Slack threads where the conversation cannot be represented by a chain of User/Assistant
@@ -298,10 +289,7 @@ def stream_chat_message_objects(
    message_text = new_msg_req.message
    chat_session_id = new_msg_req.chat_session_id
    parent_id = new_msg_req.parent_message_id
    reference_doc_ids = new_msg_req.search_doc_ids
    retrieval_options = new_msg_req.retrieval_options
    new_msg_req.alternate_assistant_id
    user_selected_filters = retrieval_options.filters if retrieval_options else None
    user_selected_filters = new_msg_req.filters

    # permanent "log" store, used primarily for debugging
    long_term_logger = LongTermLogger(
@@ -316,11 +304,6 @@ def stream_chat_message_objects(
        db_session=db_session,
    )

    if reference_doc_ids is None and retrieval_options is None:
        raise RuntimeError(
            "Must specify a set of documents for chat or specify search options"
        )

    llm, fast_llm = get_llms_for_persona(
        persona=persona,
        user=user,
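The guard above means every `CreateChatMessageRequest` must carry either pinned docs or retrieval options. A minimal sketch of a valid call, mirroring the handler code earlier in this diff (the session, root message, and user are assumed to exist):

```python
# Sketch only: satisfies the RuntimeError guard by supplying retrieval_options.
request = CreateChatMessageRequest(
    chat_session_id=chat_session_id,    # assumed: an existing session's UUID
    parent_message_id=root_message.id,  # assumed: via get_or_create_root_message
    message="Hello",
    file_descriptors=[],
    search_doc_ids=None,
    retrieval_options=RetrievalDetails(),
    rerank_settings=None,
    query_override=None,
)
answer = gather_stream(
    stream_chat_message_objects(new_msg_req=request, user=user, db_session=db_session)
)
```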
@@ -11,7 +11,6 @@ from sqlalchemy.orm.session import SessionTransaction
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.users import get_user_by_email
from onyx.evals.models import EvalationAck
@@ -73,15 +72,11 @@ def _get_answer(
    request = prepare_chat_message_request(
        message_text=eval_input["message"],
        user=user,
        filters=None,
        persona_id=None,
        persona_override_config=full_configuration.persona_override_config,
        message_ts_to_respond_to=None,
        retrieval_details=RetrievalDetails(),
        rerank_settings=None,
        db_session=db_session,
        skip_gen_ai_answer_generation=False,
        llm_override=full_configuration.llm,
        use_agentic_search=False,
        allowed_tool_ids=full_configuration.allowed_tool_ids,
    )
    packets = stream_chat_message_objects(
@@ -6,9 +6,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import PromptOverrideConfig
from onyx.chat.models import ToolConfig
from onyx.db.tools import get_builtin_tool
from onyx.llm.override_models import LLMOverride
from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP
@@ -16,7 +13,6 @@ from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP

class EvalConfiguration(BaseModel):
    builtin_tool_types: list[str] = Field(default_factory=list)
    persona_override_config: PersonaOverrideConfig | None = None
    llm: LLMOverride = Field(default_factory=LLMOverride)
    search_permissions_email: str | None = None
    allowed_tool_ids: list[int]
@@ -24,7 +20,6 @@ class EvalConfiguration(BaseModel):

class EvalConfigurationOptions(BaseModel):
    builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys())
    persona_override_config: PersonaOverrideConfig | None = None
    llm: LLMOverride = LLMOverride(
        model_provider="Default",
        model_version="gpt-4.1",
@@ -35,25 +30,7 @@ class EvalConfigurationOptions(BaseModel):
    no_send_logs: bool = False

    def get_configuration(self, db_session: Session) -> EvalConfiguration:
        persona_override_config = self.persona_override_config or PersonaOverrideConfig(
            name="Eval",
            description="A persona for evaluation",
            tools=[
                ToolConfig(id=get_builtin_tool(db_session, BUILT_IN_TOOL_MAP[tool]).id)
                for tool in self.builtin_tool_types
            ],
            prompts=[
                PromptOverrideConfig(
                    name="Default",
                    description="Default prompt for evaluation",
                    system_prompt="You are a helpful assistant.",
                    task_prompt="",
                    datetime_aware=True,
                )
            ],
        )
        return EvalConfiguration(
            persona_override_config=persona_override_config,
            llm=self.llm,
            search_permissions_email=self.search_permissions_email,
            allowed_tool_ids=[
@@ -2,7 +2,6 @@ from collections.abc import Callable

from sqlalchemy.orm import Session

from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -110,7 +109,7 @@ def get_llm_config_for_persona(


def get_llms_for_persona(
    persona: Persona | PersonaOverrideConfig | None,
    persona: Persona | None,
    user: User | None,
    llm_override: LLMOverride | None = None,
    additional_headers: dict[str, str] | None = None,
@@ -137,22 +136,18 @@ def get_llms_for_persona(
    if not provider_model:
        raise ValueError("No LLM provider found")

    # Only check access control for database Persona entities, not PersonaOverrideConfig
    # PersonaOverrideConfig is used for temporary overrides and doesn't have access restrictions
    persona_model = persona if isinstance(persona, Persona) else None

    # Fetch user group IDs for access control check
    user_group_ids = fetch_user_group_ids(db_session, user)

    if not can_user_access_llm_provider(
        provider_model,
        user_group_ids,
        persona_model,
        persona,
    ):
        logger.warning(
            "User %s with persona %s cannot access provider %s. Falling back to default provider.",
            getattr(user, "id", None),
            getattr(persona_model, "id", None),
            getattr(persona, "id", None),
            provider_model.name,
        )
        return get_default_llms(
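The warn-then-fall-back shape above is a common access-control pattern. A generic, self-contained sketch of it (all names here are illustrative stand-ins, not the codebase's):

```python
# Generic sketch of "check access, else degrade to the default provider".
def resolve_provider(requested, user_group_ids, persona, can_access, get_default):
    if can_access(requested, user_group_ids, persona):
        return requested
    # Degrade gracefully instead of failing the chat request outright.
    return get_default()
```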
@@ -113,9 +113,6 @@ from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
    get_full_openai_assistants_api_router,
)
from onyx.server.pat.api import router as pat_router
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.chat_backend_v0 import router as chat_v0_router
@@ -411,9 +408,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
    include_router_with_global_prefix_prepended(
        application, token_rate_limit_settings_router
    )
    include_router_with_global_prefix_prepended(
        application, get_full_openai_assistants_api_router()
    )
    include_router_with_global_prefix_prepended(application, long_term_logs_router)
    include_router_with_global_prefix_prepended(application, api_key_router)
    include_router_with_global_prefix_prepended(application, standard_oauth_router)

@@ -19,9 +19,7 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER
from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
@@ -210,31 +208,13 @@ def handle_regular_answer(
        time_cutoff=None,
    )

    # Default True because no other ways to apply filters in Slack (no nice UI)
    # Commenting this out because this is only available to the slackbot for now
    # later we plan to implement this at the persona level where this will get
    # commented back in
    # auto_detect_filters = (
    #     persona.llm_filter_extraction if persona is not None else True
    # )
    auto_detect_filters = slack_channel_config.enable_auto_filters
    retrieval_details = RetrievalDetails(
        run_search=OptionalSearchSetting.ALWAYS,
        real_time=False,
        filters=filters,
        enable_auto_detect_filters=auto_detect_filters,
    )

    with get_session_with_current_tenant() as db_session:
        answer_request = prepare_chat_message_request(
            message_text=user_message.message,
            user=user,
            filters=filters,
            persona_id=persona.id,
            # This is not used in the Slack flow, only in the answer API
            persona_override_config=None,
            message_ts_to_respond_to=message_ts_to_respond_to,
            retrieval_details=retrieval_details,
            rerank_settings=None,  # Rerank customization supported in Slack flow
            db_session=db_session,
        )

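For the Slack flow, filters always come from the channel config rather than a UI. A consolidated sketch of the retrieval details built above (only `time_cutoff` on `BaseFilters` is shown in this diff; anything beyond that would be an assumption):

```python
# Sketch of the Slack-side retrieval configuration, as assembled above.
filters = BaseFilters(time_cutoff=None)  # built earlier in the handler
retrieval_details = RetrievalDetails(
    run_search=OptionalSearchSetting.ALWAYS,
    real_time=False,
    filters=filters,
    # Defaults to True in Slack since there is no UI to set filters manually;
    # currently sourced from the channel config rather than the persona.
    enable_auto_detect_filters=slack_channel_config.enable_auto_filters,
)
```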
@@ -15,6 +15,12 @@ from onyx.server.features.web_search.models import WebSearchToolResponse
from onyx.server.features.web_search.models import WebSearchWithContentResponse
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.tools.models import (
    LlmOpenUrlResult,
)
from onyx.tools.models import (
    LlmWebSearchResult,
)
from onyx.tools.tool_implementations.open_url.models import WebContentProvider
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
    OnyxWebCrawler,
@@ -30,12 +36,6 @@ from onyx.tools.tool_implementations.web_search.providers import (
from onyx.tools.tool_implementations.web_search.utils import (
    truncate_search_result_content,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
    LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
    LlmWebSearchResult,
)
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
@@ -2,10 +2,10 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator

from onyx.tools.tool_implementations_v2.tool_result_models import (
from onyx.tools.models import (
    LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
from onyx.tools.models import (
    LlmWebSearchResult,
)
from shared_configs.enums import WebContentProviderType
@@ -1,253 +0,0 @@
from typing import Any
from typing import Optional
from uuid import uuid4

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_raw_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.tools import get_tool_by_name
from onyx.utils.logger import setup_logger

logger = setup_logger()


router = APIRouter(prefix="/assistants")


# Base models
class AssistantObject(BaseModel):
    id: int
    object: str = "assistant"
    created_at: int
    name: Optional[str] = None
    description: Optional[str] = None
    model: str
    instructions: Optional[str] = None
    tools: list[dict[str, Any]]
    file_ids: list[str]
    metadata: Optional[dict[str, Any]] = None


class CreateAssistantRequest(BaseModel):
    model: str
    name: Optional[str] = None
    description: Optional[str] = None
    instructions: Optional[str] = None
    tools: Optional[list[dict[str, Any]]] = None
    icon_name: Optional[str] = None
    file_ids: Optional[list[str]] = None
    metadata: Optional[dict[str, Any]] = None


class ModifyAssistantRequest(BaseModel):
    model: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None
    instructions: Optional[str] = None
    tools: Optional[list[dict[str, Any]]] = None
    file_ids: Optional[list[str]] = None
    metadata: Optional[dict[str, Any]] = None


class DeleteAssistantResponse(BaseModel):
    id: int
    object: str = "assistant.deleted"
    deleted: bool


class ListAssistantsResponse(BaseModel):
    object: str = "list"
    data: list[AssistantObject]
    first_id: Optional[int] = None
    last_id: Optional[int] = None
    has_more: bool


def persona_to_assistant(persona: Persona) -> AssistantObject:
    return AssistantObject(
        id=persona.id,
        created_at=0,
        name=persona.name,
        description=persona.description,
        model=persona.llm_model_version_override or "gpt-3.5-turbo",
        instructions=persona.system_prompt,
        tools=[
            {
                "type": tool.display_name,
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "schema": tool.openapi_schema,
                },
            }
            for tool in persona.tools
        ],
        file_ids=[],  # Assuming no file support for now
        metadata={},  # Assuming no metadata for now
    )
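A hypothetical round-trip through `persona_to_assistant`, to show the OpenAI-style shape it produces (field values invented; `model_dump` is pydantic v2, which this file already uses elsewhere):

```python
assistant = persona_to_assistant(persona)
# assistant.model_dump() would look roughly like:
# {
#     "id": 7,
#     "object": "assistant",
#     "created_at": 0,                 # persona creation time is not tracked here
#     "name": "Support Bot",
#     "model": "gpt-3.5-turbo",        # fallback when no version override is set
#     "instructions": "You are a helpful support assistant.",
#     "tools": [{"type": "Search", "function": {...}}],
#     "file_ids": [],
#     "metadata": {},
# }
```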
# API endpoints
@router.post("")
def create_assistant(
    request: CreateAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    # No separate Prompt entity; instructions map to persona.system_prompt

    tool_ids = []
    for tool in request.tools or []:
        tool_type = tool.get("type")
        if not tool_type:
            continue

        try:
            tool_db = get_tool_by_name(tool_type, db_session)
            tool_ids.append(tool_db.id)
        except ValueError:
            # Skip tools that don't exist in the database
            logger.error(f"Tool {tool_type} not found in database")
            raise HTTPException(
                status_code=404, detail=f"Tool {tool_type} not found in database"
            )

    persona = upsert_persona(
        user=user,
        name=request.name or f"Assistant-{uuid4()}",
        description=request.description or "",
        num_chunks=25,
        llm_relevance_filter=True,
        llm_filter_extraction=True,
        recency_bias=RecencyBiasSetting.AUTO,
        llm_model_provider_override=None,
        llm_model_version_override=request.model,
        starter_messages=None,
        is_public=False,
        db_session=db_session,
        document_set_ids=[],
        tool_ids=tool_ids,
        icon_name=request.icon_name,
        is_visible=True,
        system_prompt=request.instructions or "",
        task_prompt="",
        datetime_aware=True,
    )
    return persona_to_assistant(persona)


@router.get("/{assistant_id}")
def retrieve_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    try:
        persona = get_persona_by_id(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )
    except ValueError:
        persona = None

    if not persona:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return persona_to_assistant(persona)


@router.post("/{assistant_id}")
def modify_assistant(
    assistant_id: int,
    request: ModifyAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    persona = get_persona_by_id(
        persona_id=assistant_id,
        user=user,
        db_session=db_session,
        is_for_edit=True,
    )
    if not persona:
        raise HTTPException(status_code=404, detail="Assistant not found")

    update_data = request.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(persona, key, value)

    if "instructions" in update_data:
        persona.system_prompt = update_data["instructions"]

    db_session.commit()
    return persona_to_assistant(persona)


@router.delete("/{assistant_id}")
def delete_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
    try:
        mark_persona_as_deleted(
            persona_id=int(assistant_id),
            user=user,
            db_session=db_session,
        )
        return DeleteAssistantResponse(id=assistant_id, deleted=True)
    except ValueError:
        raise HTTPException(status_code=404, detail="Assistant not found")


@router.get("")
def list_assistants(
    limit: int = Query(20, le=100),
    order: str = Query("desc", regex="^(asc|desc)$"),
    after: Optional[int] = None,
    before: Optional[int] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
    persona_snapshots = list(
        get_raw_personas_for_user(
            user=user,
            db_session=db_session,
            get_editable=False,
        )
    )

    # Apply filtering based on after and before
    if after:
        persona_snapshots = [p for p in persona_snapshots if p.id > int(after)]
    if before:
        persona_snapshots = [p for p in persona_snapshots if p.id < int(before)]

    # Apply ordering
    persona_snapshots.sort(key=lambda p: p.id, reverse=(order == "desc"))

    # Apply limit
    persona_snapshots = persona_snapshots[:limit]

    assistants = [persona_to_assistant(p) for p in persona_snapshots]

    return ListAssistantsResponse(
        data=assistants,
        first_id=assistants[0].id if assistants else None,
        last_id=assistants[-1].id if assistants else None,
        has_more=len(persona_snapshots) == limit,
    )
@@ -1,19 +0,0 @@
from fastapi import APIRouter

from onyx.server.openai_assistants_api.assistants_api import (
    router as assistants_router,
)
from onyx.server.openai_assistants_api.messages_api import router as messages_router
from onyx.server.openai_assistants_api.runs_api import router as runs_router
from onyx.server.openai_assistants_api.threads_api import router as threads_router


def get_full_openai_assistants_api_router() -> APIRouter:
    router = APIRouter(prefix="/openai-assistants")

    router.include_router(assistants_router)
    router.include_router(runs_router)
    router.include_router(threads_router)
    router.include_router(messages_router)

    return router
@@ -1,234 +0,0 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.configs.constants import MessageType
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.utils import check_number_of_tokens

router = APIRouter(prefix="")


Role = Literal["user", "assistant"]


class MessageContent(BaseModel):
    type: Literal["text"]
    text: str


class Message(BaseModel):
    id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
    object: Literal["thread.message"] = "thread.message"
    created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    thread_id: str
    role: Role
    content: list[MessageContent]
    file_ids: list[str] = []
    assistant_id: Optional[str] = None
    run_id: Optional[str] = None
    metadata: Optional[dict[str, Any]] = None


class CreateMessageRequest(BaseModel):
    role: Role
    content: str
    file_ids: list[str] = []
    metadata: Optional[dict] = None


class ListMessagesResponse(BaseModel):
    object: Literal["list"] = "list"
    data: list[Message]
    first_id: str
    last_id: str
    has_more: bool
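A hypothetical request against the create-message endpoint that follows (the `/openai-assistants` prefix comes from the router aggregator elsewhere in this diff; the host, port, and ids are assumptions):

```python
import requests  # any HTTP client; illustrative only

thread_id = "00000000-0000-0000-0000-000000000000"  # a real chat session UUID in practice
resp = requests.post(
    f"http://localhost:8080/openai-assistants/threads/{thread_id}/messages",
    json={
        "role": "user",
        "content": "What does the billing connector sync?",
    },
)
message = resp.json()  # shaped like the Message model above
```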
@router.post("/threads/{thread_id}/messages")
|
||||
def create_message(
|
||||
thread_id: str,
|
||||
message: CreateMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message=message.content,
|
||||
token_count=check_number_of_tokens(message.content),
|
||||
message_type=(
|
||||
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return Message(
|
||||
id=str(new_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user",
|
||||
content=[MessageContent(type="text", text=message.content)],
|
||||
file_ids=message.file_ids,
|
||||
metadata=message.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages")
|
||||
def list_messages(
|
||||
thread_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListMessagesResponse:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
messages = [m for m in messages if str(m.id) >= after]
|
||||
if before:
|
||||
messages = [m for m in messages if str(m.id) <= before]
|
||||
|
||||
# Apply ordering
|
||||
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
messages = messages[:limit]
|
||||
|
||||
data = [
|
||||
Message(
|
||||
id=str(m.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if m.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=m.message)],
|
||||
created_at=int(m.time_sent.timestamp()),
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
return ListMessagesResponse(
|
||||
data=data,
|
||||
first_id=str(data[0].id) if data else "",
|
||||
last_id=str(data[-1].id) if data else "",
|
||||
has_more=len(messages) == limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}")
|
||||
def retrieve_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
|
||||
metadata: dict
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/{message_id}")
|
||||
def modify_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
request: ModifyMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Update metadata
|
||||
# TODO: Uncomment this once we have metadata in the chat message
|
||||
# chat_message.metadata = request.metadata
|
||||
# db_session.commit()
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
@@ -1,344 +0,0 @@
from typing import Literal
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.constants import MessageType
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ChatMessage
from onyx.db.models import User
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger


logger = setup_logger()


router = APIRouter()


class RunRequest(BaseModel):
    assistant_id: int
    model: Optional[str] = None
    instructions: Optional[str] = None
    additional_instructions: Optional[str] = None
    tools: Optional[list[dict]] = None
    metadata: Optional[dict] = None


RunStatus = Literal[
    "queued",
    "in_progress",
    "requires_action",
    "cancelling",
    "cancelled",
    "failed",
    "completed",
    "expired",
]


class RunResponse(BaseModel):
    id: str
    object: Literal["thread.run"]
    created_at: int
    assistant_id: int
    thread_id: UUID
    status: RunStatus
    started_at: Optional[int] = None
    expires_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    failed_at: Optional[int] = None
    completed_at: Optional[int] = None
    last_error: Optional[dict] = None
    model: str
    instructions: str
    tools: list[dict]
    file_ids: list[str]
    metadata: Optional[dict] = None


def process_run_in_background(
    message_id: int,
    parent_message_id: int,
    chat_session_id: UUID,
    assistant_id: int,
    instructions: str,
    tools: list[dict],
    user: User | None,
    db_session: Session,
) -> None:
    # Get the latest message in the chat session
    _ = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    search_tool_retrieval_details = RetrievalDetails()
    for tool in tools:
        if tool["type"] == SearchTool.__name__ and (
            retrieval_details := tool.get("retrieval_details")
        ):
            search_tool_retrieval_details = RetrievalDetails.model_validate(
                retrieval_details
            )
            break

    new_msg_req = CreateChatMessageRequest(
        chat_session_id=chat_session_id,
        parent_message_id=int(parent_message_id) if parent_message_id else None,
        message=instructions,
        file_descriptors=[],
        search_doc_ids=None,
        retrieval_options=search_tool_retrieval_details,  # Adjust as needed
        rerank_settings=None,
        query_override=None,
        regenerate=None,
        llm_override=None,
        prompt_override=None,
        alternate_assistant_id=assistant_id,
        use_existing_user_message=True,
        existing_assistant_message_id=message_id,
    )

    run_message = get_chat_message(message_id, user.id if user else None, db_session)
    try:
        for packet in stream_chat_message_objects(
            new_msg_req=new_msg_req,
            user=user,
            db_session=db_session,
        ):
            if isinstance(packet, ChatMessageDetail):
                # Update the run status and message content
                run_message = get_chat_message(
                    message_id, user.id if user else None, db_session
                )
                if run_message:
                    # this handles cancelling
                    if run_message.error:
                        return

                    run_message.message = packet.message
                    run_message.message_type = MessageType.ASSISTANT
                    db_session.commit()
    except Exception as e:
        logger.exception("Error processing run in background")
        run_message.error = str(e)
        db_session.commit()
        return

    db_session.refresh(run_message)
    if run_message.token_count == 0:
        run_message.error = "No tokens generated"
        db_session.commit()
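`process_run_in_background` above only honors retrieval settings attached to a `SearchTool` entry in the run's tools list. A hypothetical `RunRequest` body showing that shape (ids and values invented):

```python
# The "type" key is matched against SearchTool.__name__ in the loop above,
# and "retrieval_details" is parsed via RetrievalDetails.model_validate.
run_request = {
    "assistant_id": 7,  # hypothetical persona id
    "instructions": "Summarize the thread so far",
    "tools": [
        {
            "type": "SearchTool",
            "retrieval_details": {"real_time": False},
        }
    ],
}
```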
@router.post("/threads/{thread_id}/runs")
|
||||
def create_run(
|
||||
thread_id: UUID,
|
||||
run_request: RunRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
# Create a new "run" (chat message) in the session
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message="",
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.flush()
|
||||
latest_message.latest_child_message = new_message
|
||||
db_session.commit()
|
||||
|
||||
# Schedule the background task
|
||||
background_tasks.add_task(
|
||||
process_run_in_background,
|
||||
new_message.id,
|
||||
latest_message.id,
|
||||
chat_session.id,
|
        run_request.assistant_id,
        run_request.instructions or "",
        run_request.tools or [],
        user,
        db_session,
    )

    return RunResponse(
        id=str(new_message.id),
        object="thread.run",
        created_at=int(new_message.time_sent.timestamp()),
        assistant_id=run_request.assistant_id,
        thread_id=chat_session.id,
        status="queued",
        model=run_request.model or "default_model",
        instructions=run_request.instructions or "",
        tools=run_request.tools or [],
        file_ids=[],
        metadata=run_request.metadata,
    )


@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    # Retrieve the chat message (which represents a "run" in DAnswer)
    chat_message = get_chat_message(
        chat_message_id=int(run_id),  # Convert string run_id to int
        user_id=user.id if user else None,
        db_session=db_session,
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_session = chat_message.chat_session

    # Map DAnswer status to OpenAI status
    run_status: RunStatus = "queued"
    if chat_message.message:
        run_status = "in_progress"
    if chat_message.token_count != 0:
        run_status = "completed"
    if chat_message.error:
        run_status = "cancelled"

    return RunResponse(
        id=run_id,
        object="thread.run",
        created_at=int(chat_message.time_sent.timestamp()),
        assistant_id=chat_session.persona_id or 0,
        thread_id=chat_session.id,
        status=run_status,
        started_at=int(chat_message.time_sent.timestamp()),
        completed_at=(
            int(chat_message.time_sent.timestamp()) if chat_message.message else None
        ),
        model=chat_session.current_alternate_model or "default_model",
        instructions="",  # DAnswer doesn't store per-message instructions
        tools=[],  # DAnswer doesn't have a direct equivalent for tools
        file_ids=(
            [file["id"] for file in chat_message.files] if chat_message.files else []
        ),
        metadata=None,  # DAnswer doesn't store metadata for individual messages
    )


@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    # In DAnswer, we don't have a direct equivalent to cancelling a run
    # We'll simulate it by marking the message as "cancelled"
    chat_message = (
        db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_message.error = "Cancelled"
    db_session.commit()

    return retrieve_run(thread_id, run_id, user, db_session)


@router.get("/threads/{thread_id}/runs")
def list_runs(
    thread_id: UUID,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[RunResponse]:
    # In DAnswer, we'll treat each message in a chat session as a "run"
    chat_messages = get_chat_messages_by_session(
        chat_session_id=thread_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    # Apply pagination
    if after:
        chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
    if before:
        chat_messages = [msg for msg in chat_messages if str(msg.id) < before]

    # Apply ordering
    chat_messages = sorted(
        chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
    )

    # Apply limit
    chat_messages = chat_messages[:limit]

    return [
        retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
    ]


@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
    thread_id: UUID,
    run_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[dict]:  # You may want to create a specific model for run steps
    # DAnswer doesn't have an equivalent to run steps
    # We'll return an empty list to maintain API compatibility
    return []


# Additional helper functions can be added here if needed
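
The endpoints above form a create-then-poll run lifecycle: a run starts "queued" and retrieve_run re-derives its status from the underlying chat message. A minimal client-side polling sketch, assuming the router is mounted under /openai-assistants on localhost:8080 (both assumptions, not fixed by this diff):

import time

import requests

BASE = "http://localhost:8080/openai-assistants"  # assumed mount point


def wait_for_run(thread_id: str, run_id: str, timeout: float = 60.0) -> dict:
    """Poll retrieve_run until the status mapping above settles."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = requests.get(f"{BASE}/threads/{thread_id}/runs/{run_id}")
        resp.raise_for_status()
        run = resp.json()
        # Mirrors retrieve_run: message -> in_progress,
        # token_count != 0 -> completed, error -> cancelled.
        if run["status"] in ("completed", "cancelled"):
            return run
        time.sleep(1.0)
    raise TimeoutError(f"run {run_id} did not finish within {timeout}s")
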
@@ -1,157 +0,0 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.db.chat import create_chat_session
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import update_chat_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse

router = APIRouter(prefix="/threads")


# Models
class Thread(BaseModel):
    id: UUID
    object: str = "thread"
    created_at: int
    metadata: Optional[dict[str, str]] = None


class CreateThreadRequest(BaseModel):
    messages: Optional[list[dict]] = None
    metadata: Optional[dict[str, str]] = None


class ModifyThreadRequest(BaseModel):
    metadata: Optional[dict[str, str]] = None


# API Endpoints
@router.post("")
def create_thread(
    request: CreateThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    new_chat_session = create_chat_session(
        db_session=db_session,
        description="",  # Leave the naming till later to prevent delay
        user_id=user_id,
        persona_id=0,
    )

    return Thread(
        id=new_chat_session.id,
        created_at=int(new_chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )


@router.get("/{thread_id}")
def retrieve_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=None,  # Assuming we don't store metadata in our current implementation
    )


@router.post("/{thread_id}")
def modify_thread(
    thread_id: UUID,
    request: ModifyThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    try:
        chat_session = update_chat_session(
            db_session=db_session,
            user_id=user_id,
            chat_session_id=thread_id,
            description=None,  # Not updating description
            sharing_status=None,  # Not updating sharing status
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )


@router.delete("/{thread_id}")
def delete_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict:
    user_id = user.id if user else None
    try:
        delete_chat_session(
            user_id=user_id,
            chat_session_id=thread_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}


@router.get("")
def list_threads(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    user_id = user.id if user else None
    chat_sessions = get_chat_sessions_by_user(
        user_id=user_id,
        deleted=False,
        db_session=db_session,
    )

    return ChatSessionsResponse(
        sessions=[
            ChatSessionDetails(
                id=chat.id,
                name=chat.description,
                persona_id=chat.persona_id,
                time_created=chat.time_created.isoformat(),
                time_updated=chat.time_updated.isoformat(),
                shared_status=chat.shared_status,
                current_alternate_model=chat.current_alternate_model,
                current_temperature_override=chat.temperature_override,
            )
            for chat in chat_sessions
        ]
    )
@@ -78,7 +78,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
@@ -636,13 +635,9 @@ class ChatSeedRequest(BaseModel):

    # overrides / seeding
    llm_override: LLMOverride | None = None
    prompt_override: PromptOverride | None = None
    description: str | None = None
    message: str | None = None

    # TODO: support this
    # initial_message_retrieval_options: RetrievalDetails | None = None


class ChatSeedResponse(BaseModel):
    redirect_url: str
@@ -666,7 +661,6 @@ def seed_chat(
            user_id=None,  # this chat session is "unassigned" until a user visits the web UI
            persona_id=chat_seed_request.persona_id,
            llm_override=chat_seed_request.llm_override,
            prompt_override=chat_seed_request.prompt_override,
        )
    except Exception as e:
        logger.exception(e)

@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING
from uuid import UUID

from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator

from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
@@ -28,7 +26,6 @@ from onyx.db.enums import ChatSessionSharedStatus
from onyx.db.models import ChatSession
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import Packet

@@ -81,61 +78,30 @@ class ChatFeedbackRequest(BaseModel):


class CreateChatMessageRequest(ChunkContext):
    """Before creating messages, be sure to create a chat_session and get an id"""
    # NOTE: Double check before adding fields to this class, it has historically gotten really
    # bloated and hard to maintain.

    # Identifying where the message is in the chat session history
    chat_session_id: UUID
    # This is the primary key (unique identifier) for the previous message of the tree
    parent_message_id: int | None

    # New message contents
    message: str
    filters: BaseFilters | None = None
    # Files that we should attach to this message
    file_descriptors: list[FileDescriptor] = []
    # Prompts are embedded in personas, so no separate prompt_id needed
    # If search_doc_ids provided, it should use those docs explicitly
    search_doc_ids: list[int] | None
    retrieval_options: RetrievalDetails | None
    # Usable via the APIs but not recommended for most flows
    rerank_settings: RerankingDetails | None = None
    # allows the caller to specify the exact search query they want to use
    # will disable Query Rewording if specified
    query_override: str | None = None
    # TODO: this is for the document-selection functionality
    search_doc_ids: list[int] | None = None

    # enables additional handling to ensure that we regenerate with a given user message ID
    regenerate: bool | None = None

    # allows the caller to override the Persona / Prompt
    # these do not persist in the chat thread details
    # Lets the message be processed with a different LLM than the usual one
    llm_override: LLMOverride | None = None
    prompt_override: PromptOverride | None = None

    # Allows the caller to override the temperature for the chat session
    # this does persist in the chat thread details
    temperature_override: float | None = None

    # allow user to specify an alternate assistant
    alternate_assistant_id: int | None = None

    # This takes priority over the prompt_override
    # This won't be a type that's passed in directly from the API
    persona_override_config: PersonaOverrideConfig | None = None

    # used for seeded chats to kick off the generation of an AI answer
    use_existing_user_message: bool = False

    # used for "OpenAI Assistants API"
    existing_assistant_message_id: int | None = None

    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    # If true, ignores most of the search options and uses pro search instead.
    # TODO: decide how many of the above options we want to pass through to pro search
    use_agentic_search: bool = False

    skip_gen_ai_answer_generation: bool = False

    # List of allowed tool IDs to restrict tool usage. If not provided, all tools available to the persona will be used.
    allowed_tool_ids: list[int] | None = None

@@ -143,13 +109,13 @@ class CreateChatMessageRequest(ChunkContext):
    # TODO: make this a single one since unclear how to force this for multiple at a time.
    forced_tool_ids: list[int] | None = None

    @model_validator(mode="after")
    def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
        if self.search_doc_ids is None and self.retrieval_options is None:
            raise ValueError(
                "Either search_doc_ids or retrieval_options must be provided, but not both or neither."
            )
        return self
    # NOTE: the fields below are less used and typically should not be set in normal flows.
    # used for seeded chats to kick off the generation of an AI answer
    use_existing_user_message: bool = False

    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        data = super().model_dump(*args, **kwargs)
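
With the check_search_doc_ids_or_retrieval_options validator removed in the new version above, a minimal request needs only the session, the parent message, and the text. A sketch of constructing one (the UUID is a placeholder, and it is an assumption that search_doc_ids and retrieval_options default to None in the post-change model; the flattened diff interleaves old and new field declarations):

from uuid import UUID

from onyx.server.query_and_chat.models import CreateChatMessageRequest

req = CreateChatMessageRequest(
    chat_session_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder
    parent_message_id=None,  # None branches from the root message
    message="Summarize the onboarding docs.",
)
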
@@ -332,33 +298,9 @@ class DocumentSearchRequest(ChunkContext):
class OneShotQARequest(ChunkContext):
    # Supports simpler APIs that don't deal with chat histories or message edits
    # Easier APIs to work with for developers
    persona_override_config: PersonaOverrideConfig | None = None
    persona_id: int | None = None

    persona_id: int
    messages: list[ThreadMessage]
    retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
    rerank_settings: RerankingDetails | None = None

    # allows the caller to specify the exact search query they want to use
    # can be used if the message sent to the LLM / query should not be the same
    # will also disable Thread-based Rewording if specified
    query_override: str | None = None

    # If True, skips generating an AI response to the search query
    skip_gen_ai_answer_generation: bool = False

    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False

    @model_validator(mode="after")
    def check_persona_fields(self) -> "OneShotQARequest":
        if self.persona_override_config is None and self.persona_id is None:
            raise ValueError("Exactly one of persona_config or persona_id must be set")
        elif self.persona_override_config is not None and self.persona_id is not None:
            raise ValueError(
                "If persona_override_config is set, persona_id cannot be set"
            )
        return self
    filters: BaseFilters | None = None


class OneShotQAResponse(BaseModel):

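
The check_persona_fields validator above enforces exactly-one-of semantics between persona_id and persona_override_config. A standalone pydantic sketch of the same pattern (a simplified illustration, not the actual class):

from pydantic import BaseModel, model_validator


class PersonaChoice(BaseModel):
    # Simplified stand-ins for persona_id / persona_override_config.
    persona_id: int | None = None
    persona_override_config: dict | None = None

    @model_validator(mode="after")
    def check_exactly_one(self) -> "PersonaChoice":
        # True when both are set or both are unset -- either way, invalid.
        if (self.persona_id is None) == (self.persona_override_config is None):
            raise ValueError("Exactly one of the two fields must be set")
        return self


PersonaChoice(persona_id=0)  # ok
try:
    PersonaChoice()  # neither field set -> raises at construction time
except ValueError as err:
    print(err)
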
@@ -13,7 +13,6 @@ from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import AnswerStream
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -41,7 +40,6 @@ from onyx.db.chat import get_valid_messages_from_query_sessions
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.search_settings import get_current_search_settings
@@ -212,22 +210,12 @@ def get_answer_stream(
    query = query_request.messages[0].message
    logger.notice(f"Received query for Answer API: {query}")

    if (
        query_request.persona_override_config is None
        and query_request.persona_id is None
    ):
        raise KeyError("Must provide persona ID or Persona Config")

    persona_info: Persona | PersonaOverrideConfig | None = None
    if query_request.persona_override_config is not None:
        persona_info = query_request.persona_override_config
    elif query_request.persona_id is not None:
        persona_info = get_persona_by_id(
            persona_id=query_request.persona_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )
    persona_info = get_persona_by_id(
        persona_id=query_request.persona_id,
        user=user,
        db_session=db_session,
        is_for_edit=False,
    )

    llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona_info, user=user))

@@ -249,15 +237,11 @@ def get_answer_stream(
    # Also creates a new chat session
    request = prepare_chat_message_request(
        message_text=combined_message,
        filters=query_request.filters,
        user=user,
        persona_id=query_request.persona_id,
        persona_override_config=query_request.persona_override_config,
        message_ts_to_respond_to=None,
        retrieval_details=query_request.retrieval_options,
        rerank_settings=query_request.rerank_settings,
        db_session=db_session,
        use_agentic_search=query_request.use_agentic_search,
        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
    )

    packets = stream_chat_message_objects(

@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any
from typing import Literal
from uuid import UUID

from pydantic import BaseModel
@@ -166,3 +167,56 @@ class ToolCallInfo(BaseModel):

CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"


class BaseCiteableToolResult(BaseModel):
    """Base class for tool results that can be cited."""

    document_citation_number: int
    unique_identifier_to_strip_away: str | None = None
    type: str


class LlmInternalSearchResult(BaseCiteableToolResult):
    """Result from an internal search query"""

    type: Literal["internal_search"] = "internal_search"
    title: str
    excerpt: str
    metadata: dict[str, Any]


class LlmWebSearchResult(BaseCiteableToolResult):
    """Result from a web search query"""

    type: Literal["web_search"] = "web_search"
    url: str
    title: str
    snippet: str


class LlmOpenUrlResult(BaseCiteableToolResult):
    """Result from opening/fetching a URL"""

    type: Literal["open_url"] = "open_url"
    content: str


class PythonExecutionFile(BaseModel):
    """File generated during Python execution"""

    filename: str
    file_link: str


class LlmPythonExecutionResult(BaseModel):
    """Result from Python code execution"""

    type: Literal["python_execution"] = "python_execution"

    stdout: str
    stderr: str
    exit_code: int | None
    timed_out: bool
    generated_files: list[PythonExecutionFile]
    error: str | None = None

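Each citeable result class above carries a Literal `type` tag, which doubles as a discriminator when deserializing mixed tool results. A sketch of that pattern, assuming these models live in onyx.tools.models as the import changes elsewhere in this diff suggest:

from typing import Annotated, Union

from pydantic import Field, TypeAdapter

from onyx.tools.models import (
    LlmInternalSearchResult,
    LlmOpenUrlResult,
    LlmWebSearchResult,
)

# The `type` Literal on each subclass selects the right model at parse time.
CiteableResult = Annotated[
    Union[LlmInternalSearchResult, LlmWebSearchResult, LlmOpenUrlResult],
    Field(discriminator="type"),
]

adapter = TypeAdapter(CiteableResult)
parsed = adapter.validate_python(
    {
        "type": "web_search",
        "document_citation_number": 1,
        "url": "https://example.com",
        "title": "Example",
        "snippet": "a snippet",
    }
)
assert isinstance(parsed, LlmWebSearchResult)
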
@@ -17,17 +17,17 @@ from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.tools.models import (
    LlmPythonExecutionResult,
)
from onyx.tools.models import PythonExecutionFile
from onyx.tools.models import PythonToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations_v2.code_interpreter_client import (
from onyx.tools.tool_implementations.python.code_interpreter_client import (
    CodeInterpreterClient,
)
from onyx.tools.tool_implementations_v2.code_interpreter_client import FileInput
from onyx.tools.tool_implementations_v2.tool_result_models import (
    LlmPythonExecutionResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import PythonExecutionFile
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.utils.logger import setup_logger



@@ -1,59 +0,0 @@
"""Base models for tool results with citation support."""

from typing import Any
from typing import Literal

from pydantic import BaseModel


class BaseCiteableToolResult(BaseModel):
    """Base class for tool results that can be cited."""

    document_citation_number: int
    unique_identifier_to_strip_away: str | None = None
    type: str


class LlmInternalSearchResult(BaseCiteableToolResult):
    """Result from an internal search query"""

    type: Literal["internal_search"] = "internal_search"
    title: str
    excerpt: str
    metadata: dict[str, Any]


class LlmWebSearchResult(BaseCiteableToolResult):
    """Result from a web search query"""

    type: Literal["web_search"] = "web_search"
    url: str
    title: str
    snippet: str


class LlmOpenUrlResult(BaseCiteableToolResult):
    """Result from opening/fetching a URL"""

    type: Literal["open_url"] = "open_url"
    content: str


class PythonExecutionFile(BaseModel):
    """File generated during Python execution"""

    filename: str
    file_link: str


class LlmPythonExecutionResult(BaseModel):
    """Result from Python code execution"""

    type: Literal["python_execution"] = "python_execution"

    stdout: str
    stderr: str
    exit_code: int | None
    timed_out: bool
    generated_files: list[PythonExecutionFile]
    error: str | None = None
@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
@@ -73,9 +72,6 @@ def test_answer_with_only_anthropic_provider(
        chat_session_id=chat_session.id,
        parent_message_id=None,
        message="hello",
        file_descriptors=[],
        search_doc_ids=None,
        retrieval_options=RetrievalDetails(),
    )

    response_stream: list[AnswerStreamPart] = []

@@ -7,7 +7,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -45,10 +44,7 @@ def test_stream_chat_current_date_response(
        parent_message_id=None,
        message="Please respond only with the current date in the format 'Weekday Month DD, YYYY'.",
        file_descriptors=[],
        prompt_override=None,
        search_doc_ids=None,
        retrieval_options=RetrievalDetails(),
        query_override=None,
        filters=None,
    )

    gen = stream_chat_message_objects(

@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.models import RecencyBiasSetting
from onyx.db.models import User
@@ -104,11 +103,6 @@ def test_stream_chat_message_objects_without_web_search(
        chat_session_id=chat_session.id,
        parent_message_id=None,
        message="run a web search for 'Onyx'",
        file_descriptors=[],
        prompt_override=None,
        search_doc_ids=None,
        retrieval_options=RetrievalDetails(),
        query_override=None,
    )
    # Call stream_chat_message_objects
    response_generator = stream_chat_message_objects(

File diff suppressed because it is too large
@@ -8,12 +8,10 @@ from uuid import UUID
import requests
from requests.models import Response

from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import StreamingType
@@ -92,14 +90,8 @@ class ChatSessionManager:
        user_performing_action: DATestUser | None = None,
        file_descriptors: list[FileDescriptor] | None = None,
        search_doc_ids: list[int] | None = None,
        retrieval_options: RetrievalDetails | None = None,
        query_override: str | None = None,
        regenerate: bool | None = None,
        llm_override: LLMOverride | None = None,
        prompt_override: PromptOverride | None = None,
        alternate_assistant_id: int | None = None,
        use_existing_user_message: bool = False,
        use_agentic_search: bool = False,
        forced_tool_ids: list[int] | None = None,
        chat_session: DATestChatSession | None = None,
    ) -> StreamedResponse:
@@ -109,15 +101,8 @@ class ChatSessionManager:
            message=message,
            file_descriptors=file_descriptors or [],
            search_doc_ids=search_doc_ids or [],
            retrieval_options=retrieval_options,
            rerank_settings=None,  # Can be added if needed
            query_override=query_override,
            regenerate=regenerate,
            llm_override=llm_override,
            prompt_override=prompt_override,
            alternate_assistant_id=alternate_assistant_id,
            use_existing_user_message=use_existing_user_message,
            use_agentic_search=use_agentic_search,
            forced_tool_ids=forced_tool_ids,
        )

@@ -1,23 +0,0 @@
from typing import Optional
from uuid import UUID

import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants"


@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
    # Create a thread to use in the tests
    response = requests.post(
        f"{BASE_URL}/threads",  # Updated endpoint path
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    return UUID(response.json()["id"])
@@ -1,151 +0,0 @@
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"


def test_create_assistant(admin_user: DATestUser | None) -> None:
    response = requests.post(
        ASSISTANTS_URL,
        json={
            "model": "gpt-3.5-turbo",
            "name": "Test Assistant",
            "description": "A test assistant",
            "instructions": "You are a helpful assistant.",
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == "Test Assistant"
    assert data["description"] == "A test assistant"
    assert data["model"] == "gpt-3.5-turbo"
    assert data["instructions"] == "You are a helpful assistant."


def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
    # First, create an assistant
    create_response = requests.post(
        ASSISTANTS_URL,
        json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    assistant_id = create_response.json()["id"]

    # Now, retrieve the assistant
    response = requests.get(
        f"{ASSISTANTS_URL}/{assistant_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == assistant_id
    assert data["name"] == "Retrieve Test"


def test_modify_assistant(admin_user: DATestUser | None) -> None:
    # First, create an assistant
    create_response = requests.post(
        ASSISTANTS_URL,
        json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    assistant_id = create_response.json()["id"]

    # Now, modify the assistant
    response = requests.post(
        f"{ASSISTANTS_URL}/{assistant_id}",
        json={"name": "Modified Assistant", "instructions": "New instructions"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == assistant_id
    assert data["name"] == "Modified Assistant"
    assert data["instructions"] == "New instructions"


def test_delete_assistant(admin_user: DATestUser | None) -> None:
    # First, create an assistant
    create_response = requests.post(
        ASSISTANTS_URL,
        json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    assistant_id = create_response.json()["id"]

    # Now, delete the assistant
    response = requests.delete(
        f"{ASSISTANTS_URL}/{assistant_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == assistant_id
    assert data["deleted"] is True


def test_list_assistants(admin_user: DATestUser | None) -> None:
    # Create multiple assistants
    for i in range(3):
        requests.post(
            ASSISTANTS_URL,
            json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # Now, list the assistants
    response = requests.get(
        ASSISTANTS_URL,
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) >= 3  # At least the 3 we just created
    assert all(assistant["object"] == "assistant" for assistant in data["data"])


def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
    # Create 5 assistants
    for i in range(5):
        requests.post(
            ASSISTANTS_URL,
            json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # List assistants with limit
    response = requests.get(
        f"{ASSISTANTS_URL}?limit=2",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert len(data["data"]) == 2
    assert data["has_more"] is True

    # Get next page
    before = data["last_id"]
    response = requests.get(
        f"{ASSISTANTS_URL}?limit=2&before={before}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert len(data["data"]) == 2


def test_assistant_not_found(admin_user: DATestUser | None) -> None:
    non_existent_id = -99
    response = requests.get(
        f"{ASSISTANTS_URL}/{non_existent_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 404
@@ -1,133 +0,0 @@
import uuid
from typing import Optional

import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"


@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> str:
    response = requests.post(
        BASE_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    return response.json()["id"]


def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
    response = requests.post(
        f"{BASE_URL}/{thread_id}/messages",  # URL structure matches API
        json={
            "role": "user",
            "content": "Hello, world!",
            "file_ids": [],
            "metadata": {"key": "value"},
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert "id" in response_json
    assert response_json["thread_id"] == thread_id
    assert response_json["role"] == "user"
    assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
    assert response_json["metadata"] == {"key": "value"}


def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
    # Create a message first
    requests.post(
        f"{BASE_URL}/{thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )

    # Now, list the messages
    response = requests.get(
        f"{BASE_URL}/{thread_id}/messages",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["object"] == "list"
    assert isinstance(response_json["data"], list)
    assert len(response_json["data"]) > 0
    assert "first_id" in response_json
    assert "last_id" in response_json
    assert "has_more" in response_json


def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
    # Create a message first
    create_response = requests.post(
        f"{BASE_URL}/{thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    message_id = create_response.json()["id"]

    # Now, retrieve the message
    response = requests.get(
        f"{BASE_URL}/{thread_id}/messages/{message_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["id"] == message_id
    assert response_json["thread_id"] == thread_id
    assert response_json["role"] == "user"
    assert response_json["content"] == [{"type": "text", "text": "Test message"}]


def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
    # Create a message first
    create_response = requests.post(
        f"{BASE_URL}/{thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    message_id = create_response.json()["id"]

    # Now, modify the message
    response = requests.post(
        f"{BASE_URL}/{thread_id}/messages/{message_id}",
        json={"metadata": {"new_key": "new_value"}},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["id"] == message_id
    assert response_json["thread_id"] == thread_id
    assert response_json["metadata"] == {"new_key": "new_value"}


def test_error_handling(admin_user: Optional[DATestUser]) -> None:
    non_existent_thread_id = str(uuid.uuid4())
    non_existent_message_id = -99

    # Test with non-existent thread
    response = requests.post(
        f"{BASE_URL}/{non_existent_thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 404

    # Test with non-existent message
    response = requests.get(
        f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 404
@@ -1,137 +0,0 @@
from uuid import UUID

import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants"


@pytest.fixture
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
    """Create a run and return its ID."""
    response = requests.post(
        f"{BASE_URL}/threads/{thread_id}/runs",
        json={
            "assistant_id": 0,
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    return response.json()["id"]


def test_create_run(
    admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
    response = requests.post(
        f"{BASE_URL}/threads/{thread_id}/runs",
        json={
            "assistant_id": 0,
            "model": "gpt-3.5-turbo",
            "instructions": "Test instructions",
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert "id" in response_json
    assert response_json["object"] == "thread.run"
    assert "created_at" in response_json
    assert response_json["assistant_id"] == 0
    assert UUID(response_json["thread_id"]) == thread_id
    assert response_json["status"] == "queued"
    assert response_json["model"] == "gpt-3.5-turbo"
    assert response_json["instructions"] == "Test instructions"


def test_retrieve_run(
    admin_user: DATestUser | None,
    thread_id: UUID,
    run_id: str,
    llm_provider: DATestLLMProvider,
) -> None:
    retrieve_response = requests.get(
        f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert retrieve_response.status_code == 200

    response_json = retrieve_response.json()
    assert response_json["id"] == run_id
    assert response_json["object"] == "thread.run"
    assert "created_at" in response_json
    assert UUID(response_json["thread_id"]) == thread_id


def test_cancel_run(
    admin_user: DATestUser | None,
    thread_id: UUID,
    run_id: str,
    llm_provider: DATestLLMProvider,
) -> None:
    cancel_response = requests.post(
        f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert cancel_response.status_code == 200

    response_json = cancel_response.json()
    assert response_json["id"] == run_id
    assert response_json["status"] == "cancelled"


def test_list_runs(
    admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
    # Create a few runs
    for _ in range(3):
        requests.post(
            f"{BASE_URL}/threads/{thread_id}/runs",
            json={
                "assistant_id": 0,
            },
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # Now, list the runs
    list_response = requests.get(
        f"{BASE_URL}/threads/{thread_id}/runs",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert list_response.status_code == 200

    response_json = list_response.json()
    assert isinstance(response_json, list)
    assert len(response_json) >= 3

    for run in response_json:
        assert "id" in run
        assert run["object"] == "thread.run"
        assert "created_at" in run
        assert UUID(run["thread_id"]) == thread_id
        assert "status" in run
        assert "model" in run


def test_list_run_steps(
    admin_user: DATestUser | None,
    thread_id: UUID,
    run_id: str,
    llm_provider: DATestLLMProvider,
) -> None:
    steps_response = requests.get(
        f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert steps_response.status_code == 200

    response_json = steps_response.json()
    assert isinstance(response_json, list)
    # Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
    assert len(response_json) == 0
@@ -1,131 +0,0 @@
from uuid import UUID

import requests

from onyx.db.models import ChatSessionSharedStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"


def test_create_thread(admin_user: DATestUser | None) -> None:
    response = requests.post(
        THREADS_URL,
        json={"messages": None, "metadata": {"key": "value"}},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert "id" in response_json
    assert response_json["object"] == "thread"
    assert "created_at" in response_json
    assert response_json["metadata"] == {"key": "value"}


def test_retrieve_thread(admin_user: DATestUser | None) -> None:
    # First, create a thread
    create_response = requests.post(
        THREADS_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    thread_id = create_response.json()["id"]

    # Now, retrieve the thread
    retrieve_response = requests.get(
        f"{THREADS_URL}/{thread_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert retrieve_response.status_code == 200

    response_json = retrieve_response.json()
    assert response_json["id"] == thread_id
    assert response_json["object"] == "thread"
    assert "created_at" in response_json


def test_modify_thread(admin_user: DATestUser | None) -> None:
    # First, create a thread
    create_response = requests.post(
        THREADS_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    thread_id = create_response.json()["id"]

    # Now, modify the thread
    modify_response = requests.post(
        f"{THREADS_URL}/{thread_id}",
        json={"metadata": {"new_key": "new_value"}},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert modify_response.status_code == 200

    response_json = modify_response.json()
    assert response_json["id"] == thread_id
    assert response_json["metadata"] == {"new_key": "new_value"}


def test_delete_thread(admin_user: DATestUser | None) -> None:
    # First, create a thread
    create_response = requests.post(
        THREADS_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    thread_id = create_response.json()["id"]

    # Now, delete the thread
    delete_response = requests.delete(
        f"{THREADS_URL}/{thread_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert delete_response.status_code == 200

    response_json = delete_response.json()
    assert response_json["id"] == thread_id
    assert response_json["object"] == "thread.deleted"
    assert response_json["deleted"] is True


def test_list_threads(admin_user: DATestUser | None) -> None:
    # Create a few threads
    for _ in range(3):
        requests.post(
            THREADS_URL,
            json={},
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # Now, list the threads
    list_response = requests.get(
        THREADS_URL,
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert list_response.status_code == 200

    response_json = list_response.json()
    assert "sessions" in response_json
    assert len(response_json["sessions"]) >= 3

    for session in response_json["sessions"]:
        assert "id" in session
        assert "name" in session
        assert "persona_id" in session
        assert "time_created" in session
        assert "shared_status" in session
        assert "current_alternate_model" in session

        # Validate UUID
        UUID(session["id"])

        # Validate shared_status
        assert session["shared_status"] in [
            status.value for status in ChatSessionSharedStatus
        ]
@@ -155,7 +155,6 @@ def test_soft_delete_with_agentic_search(
        chat_session_id=test_chat_session.id,
        message="What are the key principles of software engineering?",
        user_performing_action=basic_user,
        use_agentic_search=True,
    )

    # Verify that the message was processed successfully
@@ -206,7 +205,6 @@ def test_hard_delete_with_agentic_search(
        chat_session_id=test_chat_session.id,
        message="What are the key principles of software engineering?",
        user_performing_action=basic_user,
        use_agentic_search=True,
    )

    # Verify that the message was processed successfully

@@ -5,9 +5,7 @@ from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.connectors.models import InputType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.enums import IndexingStatus
from onyx.server.documents.models import ConnectorBase
from onyx.server.query_and_chat.models import OneShotQARequest
@@ -24,9 +22,7 @@ def _api_url_builder(env_name: str, api_path: str) -> str:


@retry(tries=5, delay=5)
def get_answer_from_query(
    query: str, only_retrieve_docs: bool, env_name: str
) -> tuple[list[str], str]:
def get_answer_from_query(query: str, env_name: str) -> tuple[list[str], str]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
@@ -40,13 +36,7 @@ def get_answer_from_query(
    new_message_request = OneShotQARequest(
        messages=messages,
        persona_id=0,
        retrieval_options=RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=True,
            filters=filters,
            enable_auto_detect_filters=False,
        ),
        skip_gen_ai_answer_generation=only_retrieve_docs,
        filters=filters,
    )

    url = _api_url_builder(env_name, "/query/answer-with-citation/")

@@ -116,7 +116,6 @@ def _process_question(question_data: dict, config: dict, question_number: int) -
    query = question_data["question"]
    context_data_list, answer = get_answer_from_query(
        query=query,
        only_retrieve_docs=config["only_retrieve_docs"],
        env_name=config["env_name"],
    )
    print(f"On question number {question_number}")

@@ -45,9 +45,7 @@ from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.configs.constants import MessageType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -432,14 +430,7 @@ class SearchAnswerAnalyzer:
        qa_request = OneShotQARequest(
            messages=messages,
            persona_id=0,  # default persona
            retrieval_options=RetrievalDetails(
                run_search=OptionalSearchSetting.ALWAYS,
                real_time=True,
                filters=filters,
                enable_auto_detect_filters=False,
                limit=self.config.max_search_results,
            ),
            skip_gen_ai_answer_generation=self.config.search_only,
            filters=filters,
        )

        # send the request

||||
@@ -130,16 +130,9 @@ async function* sendMessage({
|
||||
chat_session_id: chatSessionId,
|
||||
parent_message_id: parentMessageId || null,
|
||||
message: message,
|
||||
prompt_id: null,
|
||||
search_doc_ids: null,
|
||||
file_descriptors: [],
|
||||
// checkout https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/search/models.py#L105 for
|
||||
// all available options
|
||||
retrieval_options: {
|
||||
run_search: "always",
|
||||
filters: null,
|
||||
},
|
||||
query_override: null,
|
||||
filters: null,
|
||||
}),
|
||||
});
|
||||
if (!sendMessageResponse.ok) {
|
||||
|
||||
@@ -174,8 +174,6 @@ export default function ChatPage({
|
||||
if (message) {
|
||||
onSubmit({
|
||||
message,
|
||||
currentMessageFiles,
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -603,8 +601,6 @@ export default function ChatPage({
|
||||
// We call onSubmit, passing a `messageOverride`
|
||||
onSubmit({
|
||||
message: lastUserMsg.message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
});
|
||||
}
|
||||
@@ -632,11 +628,9 @@ export default function ChatPage({
|
||||
const handleChatInputSubmit = useCallback(() => {
|
||||
onSubmit({
|
||||
message: message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
});
|
||||
setShowOnboarding(false);
|
||||
}, [message, onSubmit, currentMessageFiles, deepResearchEnabled]);
|
||||
}, [message, onSubmit]);
|
||||
|
||||
// Memoized callbacks for DocumentResults
|
||||
const handleMobileDocumentSidebarClose = useCallback(() => {
|
||||
|
||||
@@ -22,16 +22,11 @@ interface MessagesDisplayProps {
|
||||
onSubmit: (args: {
|
||||
message: string;
|
||||
messageIdToResend?: number;
|
||||
currentMessageFiles: ProjectFile[];
|
||||
useAgentSearch: boolean;
|
||||
modelOverride?: LlmDescriptor;
|
||||
regenerationRequest?: {
|
||||
messageId: number;
|
||||
parentMessage: Message;
|
||||
forceSearch?: boolean;
|
||||
};
|
||||
forceSearch?: boolean;
|
||||
queryOverride?: string;
|
||||
isSeededChat?: boolean;
|
||||
overrideFileDescriptors?: FileDescriptor[];
|
||||
}) => Promise<void>;
|
||||
@@ -77,24 +72,17 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
|
||||
const emptyDocs = useMemo<OnyxDocument[]>(() => [], []);
|
||||
const emptyChildrenIds = useMemo<number[]>(() => [], []);
|
||||
const createRegenerator = useCallback(
|
||||
(regenerationRequest: {
|
||||
messageId: number;
|
||||
parentMessage: Message;
|
||||
forceSearch?: boolean;
|
||||
}) => {
|
||||
(regenerationRequest: { messageId: number; parentMessage: Message }) => {
|
||||
return async function (modelOverride: LlmDescriptor) {
|
||||
return await onSubmit({
|
||||
message: regenerationRequest.parentMessage.message,
|
||||
currentMessageFiles,
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
modelOverride,
|
||||
messageIdToResend: regenerationRequest.parentMessage.messageId,
|
||||
regenerationRequest,
|
||||
forceSearch: regenerationRequest.forceSearch,
|
||||
});
|
||||
};
|
||||
},
|
||||
[onSubmit, deepResearchEnabled, currentMessageFiles]
|
||||
[onSubmit]
|
||||
);
|
||||
|
||||
const handleEditWithMessageId = useCallback(
|
||||
@@ -102,11 +90,9 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
|
||||
onSubmit({
|
||||
message: editedContent,
|
||||
messageIdToResend: msgId,
|
||||
currentMessageFiles: [],
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
});
|
||||
},
|
||||
[onSubmit, deepResearchEnabled]
|
||||
[onSubmit]
|
||||
);
|
||||
|
||||
// require assistant to be present before rendering
|
||||
|
||||
@@ -81,16 +81,8 @@ const SYSTEM_MESSAGE_ID = -3;

export interface OnSubmitProps {
  message: string;
  // from chat input bar
  currentMessageFiles: ProjectFile[];
  // from the chat bar???

  useAgentSearch: boolean;

  // optional params
  messageIdToResend?: number;
  queryOverride?: string;
  forceSearch?: boolean;
  isSeededChat?: boolean;
  modelOverride?: LlmDescriptor;
  regenerationRequest?: RegenerationRequest | null;
@@ -100,7 +92,6 @@ export interface OnSubmitProps {
interface RegenerationRequest {
  messageId: number;
  parentMessage: Message;
  forceSearch?: boolean;
}

interface UseChatControllerProps {
@@ -142,8 +133,13 @@ export function useChatController({
  const { refreshChatSessions } = useChatSessions();
  const { assistantPreferences } = useAssistantPreferences();
  const { forcedToolIds } = useForcedTools();
  const { fetchProjects, uploadFiles, setCurrentMessageFiles, beginUpload } =
    useProjectsContext();
  const {
    fetchProjects,
    uploadFiles,
    setCurrentMessageFiles,
    beginUpload,
    currentMessageFiles,
  } = useProjectsContext();
  const posthog = usePostHog();

  // Use selectors to access only the specific fields we need
@@ -387,11 +383,7 @@ export function useChatController({
  const onSubmit = useCallback(
    async ({
      message,
      currentMessageFiles,
      useAgentSearch,
      messageIdToResend,
      queryOverride,
      forceSearch,
      isSeededChat,
      modelOverride,
      regenerationRequest,
@@ -663,7 +655,6 @@ export function useChatController({
      updateCurrentMessageFIFO(stack, {
        signal: controller.signal,
        message: currMessage,
        alternateAssistantId: liveAssistant?.id,
        fileDescriptors: overrideFileDescriptors,
        parentMessageId: (() => {
          const parentId =
@@ -687,15 +678,6 @@ export function useChatController({
            document.db_doc_id !== undefined && document.db_doc_id !== null
          )
          .map((document) => document.db_doc_id as number),
        queryOverride,
        forceSearch,
        currentMessageFiles: currentMessageFiles.map((file) => ({
          id: file.file_id,
          type: file.chat_file_type,
          name: file.name,
          user_file_id: file.id,
        })),
        regenerate: regenerationRequest !== undefined,
        modelProvider:
          modelOverride?.name || llmManager.currentLlm.name || undefined,
        modelVersion:
@@ -704,10 +686,7 @@ export function useChatController({
          searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
          undefined,
        temperature: llmManager.temperature || undefined,
        systemPromptOverride:
          searchParams?.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
        useExistingUserMessage: isSeededChat,
        useAgentSearch,
        enabledToolIds:
          disabledToolIds && liveAssistant
            ? liveAssistant.tools

@@ -57,8 +57,6 @@ interface UseChatSessionControllerProps {
  refreshChatSessions: () => void;
  onSubmit: (params: {
    message: string;
    currentMessageFiles: ProjectFile[];
    useAgentSearch: boolean;
    isSeededChat?: boolean;
  }) => Promise<void>;
}
@@ -170,8 +168,6 @@ export function useChatSessionController({
      submitOnLoadPerformed.current = true;
      await onSubmit({
        message: firstMessage || "",
        currentMessageFiles: [],
        useAgentSearch: false,
      });
    }
    return;
@@ -293,8 +289,6 @@ export function useChatSessionController({
      await onSubmit({
        message: seededMessage,
        isSeededChat: true,
        currentMessageFiles: [],
        useAgentSearch: false,
      });
      // Force re-name if the chat session doesn't have one
      if (!chatSession.description) {

@@ -160,78 +160,45 @@ export type PacketType =
  | Packet;

export interface SendMessageParams {
  regenerate: boolean;
  message: string;
  fileDescriptors?: FileDescriptor[];
  parentMessageId: number | null;
  chatSessionId: string;
  filters: Filters | null;
  selectedDocumentIds: number[] | null;
  queryOverride?: string;
  forceSearch?: boolean;
  modelProvider?: string;
  modelVersion?: string;
  temperature?: number;
  systemPromptOverride?: string;
  useExistingUserMessage?: boolean;
  alternateAssistantId?: number;
  signal?: AbortSignal;
  currentMessageFiles?: FileDescriptor[];
  useAgentSearch?: boolean;
  enabledToolIds?: number[];
  forcedToolIds?: number[];
}

export async function* sendMessage({
  regenerate,
  message,
  fileDescriptors,
  currentMessageFiles,
  parentMessageId,
  chatSessionId,
  filters,
  selectedDocumentIds,
  queryOverride,
  forceSearch,
  modelProvider,
  modelVersion,
  temperature,
  systemPromptOverride,
  useExistingUserMessage,
  alternateAssistantId,
  signal,
  useAgentSearch,
  enabledToolIds,
  forcedToolIds,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
  const documentsAreSelected =
    selectedDocumentIds && selectedDocumentIds.length > 0;
  const payload = {
    alternate_assistant_id: alternateAssistantId,
    chat_session_id: chatSessionId,
    parent_message_id: parentMessageId,
    message: message,
    // just use the default prompt for the assistant.
    // should remove this in the future, as we don't support multiple prompts for a
    // single assistant anyways
    prompt_id: null,
    search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
    file_descriptors: fileDescriptors,
    current_message_files: currentMessageFiles,
    regenerate,
    retrieval_options: !documentsAreSelected
      ? {
          run_search: queryOverride || forceSearch ? "always" : "auto",
          real_time: true,
          filters: filters,
        }
      : null,
    query_override: queryOverride,
    prompt_override: systemPromptOverride
      ? {
          system_prompt: systemPromptOverride,
        }
      : null,
    filters: !documentsAreSelected ? filters : null,
    llm_override:
      temperature || modelVersion
        ? {
@@ -241,7 +208,6 @@ export async function* sendMessage({
          }
        : null,
    use_existing_user_message: useExistingUserMessage,
    use_agentic_search: useAgentSearch ?? false,
    allowed_tool_ids: enabledToolIds,
    forced_tool_ids: forcedToolIds,
  };

@@ -22,8 +22,6 @@ export function Suggestions({ onSubmit }: SuggestionsProps) {
  const handleSuggestionClick = (suggestion: string) => {
    onSubmit({
      message: suggestion,
      currentMessageFiles: [],
      useAgentSearch: false,
    });
  };