Compare commits

...

10 Commits

Author SHA1 Message Date
joachim-danswer
1c23cf574c Nits 2025-02-10 17:13:16 -08:00
joachim-danswer
0ccf78ac52 reused error strings & BaseMessage_Content 2025-02-10 16:57:25 -08:00
joachim-danswer
02b4b4bf0d remove execs 2025-02-10 16:21:59 -08:00
joachim-danswer
dade11a2e6 EL - OVERRIDE 2025-02-10 14:41:55 -08:00
joachim-danswer
188a5f0d62 EL comments
- overwrite -> override
- enums for error types
- some nits
2025-02-10 14:33:58 -08:00
joachim-danswer
89c0b1ad37 YS comments 2025-02-10 13:29:12 -08:00
pablodanswer
8b20fd31b6 quick update 2025-02-08 17:14:01 -08:00
pablodanswer
6a73245986 quick ux update 2025-02-07 23:40:26 -08:00
joachim-danswer
dd73fdcd08 timeout prep backend 2025-02-07 18:21:35 -08:00
joachim-danswer
768456609a Removal of defaults from various input states + removal of bas 2025-02-07 18:19:18 -08:00
30 changed files with 977 additions and 216 deletions

View File

@@ -9,7 +9,6 @@ class CoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
@@ -18,4 +17,4 @@ class SubgraphCoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add]
log_messages: Annotated[list[str], add] = []

View File

@@ -1,8 +1,8 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
@@ -12,12 +12,39 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)
def check_sub_answer(
@@ -53,14 +80,46 @@ def check_sub_answer(
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
response = list(
fast_llm.stream(
agent_error: AgentErrorLoggingFormat | None = None
response: BaseMessage | None = None
try:
response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)
)
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - check sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - check sub answer")
if agent_error:
answer_quality = True
log_result = agent_error.error_result
else:
if response:
quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
else:
answer_quality = True
quality_str = "yes - because LLM error"
log_result = f"Answer quality: {quality_str}"
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
@@ -69,7 +128,7 @@ def check_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=f"Answer quality: {quality_str}",
result=log_result,
)
],
)

View File

@@ -16,6 +16,20 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -30,11 +44,20 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)
def generate_sub_answer(
state: AnswerQuestionState,
@@ -57,6 +80,8 @@ def generate_sub_answer(
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
@@ -79,41 +104,67 @@ def generate_sub_answer(
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate sub answer")
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -131,7 +182,7 @@ def generate_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result="",
result=log_results or "",
)
],
)

View File

@@ -42,10 +42,8 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
question: str
question_id: str
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.

View File

@@ -26,7 +26,18 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
@@ -42,12 +53,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.context.search.models import InferenceSection
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
@@ -57,6 +72,12 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The initial answer could not be generated.",
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
general_error="General LLM Error. The initial answer could not be generated.",
)
def generate_initial_answer(
state: SubQuestionRetrievalState,
@@ -224,30 +245,82 @@ def generate_initial_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate initial answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate initial answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
return InitialAnswerUpdate(
initial_answer=None,
error=AgentErrorLoggingFormat(
error_message=agent_error.error_message or "An LLM error occurred",
error_type=agent_error.error_type,
error_result=agent_error.error_result,
),
initial_agent_stats=None,
generated_sub_questions=sub_questions,
agent_base_end_time=None,
agent_base_metrics=None,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"

View File

@@ -25,7 +25,7 @@ def validate_initial_answer(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True
verdict = True # not actually required, as the answer has already been streamed out; refinement will do something similar
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,

View File

@@ -23,6 +23,18 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -33,6 +45,11 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
@@ -43,6 +60,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. Sub-questions could not be generated.",
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
general_error="General LLM Error. Sub-questions could not be generated.",
)
def decompose_orig_question(
state: SubQuestionRetrievalState,
@@ -112,11 +135,35 @@ def decompose_orig_question(
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
agent_error: AgentErrorLoggingFormat | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
except LLMTimeoutError as e:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - decompose orig question")
raise e
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -125,19 +172,19 @@ def decompose_orig_question(
)
write_custom_event("stream_finished", stop_event, writer)
deomposition_response = merge_content(*streamed_tokens)
if agent_error:
initial_sub_questions: list[str] = []
log_result = agent_error.error_result
else:
deomposition_response = merge_content(*streamed_tokens)
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
list_of_subqs = cast(str, deomposition_response).split("\n")
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
return InitialQuestionDecompositionUpdate(
initial_sub_questions=decomp_list,
initial_sub_questions=initial_sub_questions,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
@@ -151,7 +198,7 @@ def decompose_orig_question(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=f"decomposed original question into {len(decomp_list)} subquestions",
result=log_result,
)
],
)

View File

@@ -252,9 +252,7 @@ if __name__ == "__main__":
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
inputs = MainInput(log_messages=[])
for thing in compiled_graph.stream(
input=inputs,

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
@@ -10,14 +11,37 @@ from onyx.agents.agent_search.deep_search.main.states import (
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out, and the answers could not be compared.",
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
general_error="The LLM encountered an error, and the answers could not be compared.",
)
def compare_answers(
@@ -40,15 +64,46 @@ def compare_answers(
msg = [HumanMessage(content=compare_answers_prompt)]
agent_error: AgentErrorLoggingFormat | None = None
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
resp: BaseMessage | None = None
refined_answer_improvement: bool | None = None
# no need to stream this
resp = model.invoke(msg)
try:
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - compare answers")
# continue as True in this support step
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - compare answers")
# continue as True in this support step
if agent_error or resp is None:
refined_answer_improvement = True
if agent_error:
log_result = agent_error.error_result
else:
log_result = "An answer could not be generated."
else:
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
log_result = f"Answer comparison: {refined_answer_improvement}"
write_custom_event(
"refined_answer_improvement",
@@ -65,7 +120,7 @@ def compare_answers(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=f"Answer comparison: {refined_answer_improvement}",
result=log_result,
)
],
)

View File

@@ -21,6 +21,18 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
@@ -30,10 +42,25 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The sub-questions could not be generated.",
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
general_error="The LLM encountered an error. The sub-questions could not be generated.",
)
def create_refined_sub_questions(
@@ -96,29 +123,65 @@ def create_refined_sub_questions(
# Grader
model = graph_config.tooling.fast_llm
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
response = merge_content(*streamed_tokens)
agent_error: AgentErrorLoggingFormat | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - create refined sub questions")
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - create refined sub questions")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
if agent_error:
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
log_result = agent_error.error_result
write_custom_event(
"refined_sub_question_creation_error",
StreamingError(
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
),
writer,
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
else:
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
@@ -128,7 +191,7 @@ def create_refined_sub_questions(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
result=log_result,
)
],
)

View File

@@ -26,6 +26,19 @@ def decide_refinement_need(
decision = True # TODO: just for current testing purposes
if state.error:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result="Timeout Error",
)
],
)
log_messages = [
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -21,6 +21,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
@@ -81,6 +84,7 @@ def extract_entities_terms(
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -11,7 +11,6 @@ from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
@@ -23,7 +22,18 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
@@ -43,8 +53,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
@@ -56,6 +72,15 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The refined answer could not be generated.",
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
general_error="The LLM encountered an error. The refined answer could not be generated.",
)
def generate_refined_answer(
@@ -231,28 +256,80 @@ def generate_refined_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
agent_error: AgentErrorLoggingFormat | None = None
start_stream_token = datetime.now()
try:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate refined answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate refined answer")
if agent_error:
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
return RefinedAnswerUpdate(
refined_answer=None,
refined_answer_quality=False, # TODO: replace this with the actual check value
refined_agent_stats=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=0.0,
refined_question_boost_factor=0.0,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
@@ -266,49 +343,6 @@ def generate_refined_answer(
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (

View File

@@ -17,6 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -76,6 +77,7 @@ class InitialAnswerUpdate(LoggerUpdate):
"""
initial_answer: str | None = None
error: AgentErrorLoggingFormat | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
@@ -88,6 +90,7 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
refined_answer: str | None = None
error: AgentErrorLoggingFormat | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False

View File

@@ -16,14 +16,40 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
general_error="Query rewriting failed due to LLM error - the original question will be used.",
)
def expand_queries(
@@ -54,13 +80,43 @@ def expand_queries(
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
)
agent_error: AgentErrorLoggingFormat | None = None
llm_response_list: list[BaseMessage_Content] = []
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
try:
llm_response_list = dispatch_separated(
llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - expand queries")
rewritten_queries = llm_response.split("\n")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - expand queries")
# use subquestion as query if query generation fails
if agent_error:
llm_response = ""
rewritten_queries = [question]
log_result = agent_error.error_result
else:
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
0
].content
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
@@ -69,7 +125,7 @@ def expand_queries(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=f"Number of expanded queries: {len(rewritten_queries)}",
result=log_result,
)
],
)

View File

@@ -1,5 +1,6 @@
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
@@ -10,12 +11,41 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
)
def verify_documents(
@@ -26,7 +56,7 @@ def verify_documents(
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing ProSearchConfig
config (RunnableConfig): Configuration containing AgentSearchConfig
Updates:
verified_documents: list[InferenceSection]
@@ -51,11 +81,42 @@ def verify_documents(
)
]
response = fast_llm.invoke(msg)
agent_error: AgentErrorLoggingFormat | None = None
response: BaseMessage | None = None
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(retrieved_document_to_verify)
try:
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)
except LLMTimeoutError:
# In this case, we decide to continue and don't raise an error, as there is
# little harm in letting through some docs that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - verify documents")
except LLMRateLimitError:
# In this case, we decide to continue and don't raise an error, as there is
# little harm in letting through some docs that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - verify documents")
if agent_error or response is None:
verified_documents = [retrieved_document_to_verify]
else:
verified_documents = []
if isinstance(response.content, str) and binary_string_test(
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
):
verified_documents.append(retrieved_document_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,

View File

@@ -21,9 +21,13 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalInput(SubgraphCoreState):
question: str = ""
base_search: bool = False
# exception from 'no default value' for LangGraph input states
# Here, a sub_question_id default of None implies usage for the
# original question. This is sometimes needed for nested sub-graphs
sub_question_id: str | None = None
question: str
base_search: bool
## Update/Return States
@@ -88,4 +92,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str = ""
query_to_retrieve: str

View File

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput_a,
MainInput as MainInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -21,6 +21,7 @@ from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
@@ -33,6 +34,7 @@ from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -72,13 +74,15 @@ def _parse_agent_event(
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
elif event["name"] == "refined_sub_question_creation_error":
return cast(StreamingError, event["data"])
return None
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput_a,
graph_input: BasicInput | MainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -92,7 +96,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput_a,
input: BasicInput | MainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -123,9 +127,7 @@ def run_main_graph(
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
input = MainInput(log_messages=[])
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@@ -172,9 +174,7 @@ if __name__ == "__main__":
# search_request.persona = get_persona_by_id(1, None, db_session)
# config.perform_initial_search_path_decision = False
config.behavior.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
input = MainInput(log_messages=[])
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):

View File

@@ -150,3 +150,17 @@ def get_prompt_enrichment_components(
history=history,
date_str=date_str,
)
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
return positive_value.lower() in text.lower()
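For orientation, a minimal usage sketch of the helper above, paired with the AGENT_POSITIVE_VALUE_STR constant introduced elsewhere in this diff (editor illustration, not part of the commits):

from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_POSITIVE_VALUE_STR,
)

# Case-insensitive substring check, as documented in the docstring above.
assert binary_string_test(
    "Yes - the sub-answer addresses the question.", AGENT_POSITIVE_VALUE_STR
)
assert not binary_string_test("No.", positive_value=AGENT_POSITIVE_VALUE_STR)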

View File

@@ -0,0 +1,17 @@
from enum import Enum
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
"The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"
GENERAL_ERROR = "general_error"

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.models import (
@@ -56,6 +58,12 @@ class InitialAgentResultStats(BaseModel):
agent_effectiveness: dict[str, float | int | None]
class AgentErrorLoggingFormat(BaseModel):
error_message: str
error_type: str
error_result: str | None = None
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
@@ -126,3 +134,12 @@ class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str
class LLMNodeErrorStrings(BaseModel):
timeout: str = "LLM Timeout Error"
rate_limit: str = "LLM Rate Limit Error"
general_error: str = "General LLM Error"
BaseMessage_Content = str | list[str | dict[str, Any]]
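The constants and models above back the per-node error-handling pattern that recurs in the node hunks earlier in this diff: each node defines its own LLMNodeErrorStrings and, on LLM failure, records an AgentErrorLoggingFormat. A condensed sketch of that pairing, using only names shown in the hunks above (editor illustration, not part of the commits):

from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings

# Module-level, node-specific consequence strings (all fields have defaults).
_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. This (hypothetical) node's step will be skipped.",
)

# Populated inside a node's `except LLMTimeoutError:` branch.
agent_error = AgentErrorLoggingFormat(
    error_type=AgentLLMErrorType.TIMEOUT,
    error_message=AGENT_LLM_TIMEOUT_MESSAGE,
    error_result=_llm_node_error_strings.timeout,
)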

View File

@@ -20,6 +20,7 @@ from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -34,6 +35,9 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -46,6 +50,8 @@ from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
@@ -65,8 +71,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
BaseMessage_Content = str | list[str | dict[str, Any]]
logger = setup_logger()
# Post-processing
@@ -372,8 +379,24 @@ def summarize_history(
)
)
history_response = llm.invoke(history_context_prompt)
try:
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
assert isinstance(history_response.content, str)
return history_response.content

View File

@@ -13,6 +13,21 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
#####
# Agent Configs
#####
@@ -77,4 +92,76 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
) # 2000
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 4
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 10
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
GRAPH_VERSION_NAME: str = "a"

View File

@@ -50,6 +50,18 @@ litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
class LLMTimeoutError(Exception):
"""
Exception raised when an LLM call times out.
"""
class LLMRateLimitError(Exception):
"""
Exception raised when an LLM call is rate limited.
"""
def _base_msg_to_role(msg: BaseMessage) -> str:
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
return "user"
@@ -380,6 +392,7 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
@@ -405,7 +418,7 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=0,
timeout=self._timeout,
timeout=timeout_override or self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -424,6 +437,12 @@ class DefaultMultiLLM(LLM):
except Exception as e:
self._record_error(processed_prompt, e)
# for breakpointing
if isinstance(e, litellm.Timeout):
raise LLMTimeoutError(e)
elif isinstance(e, litellm.RateLimitError):
raise LLMRateLimitError(e)
raise e
@property
@@ -444,6 +463,7 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -451,7 +471,12 @@ class DefaultMultiLLM(LLM):
response = cast(
litellm.ModelResponse,
self._completion(
prompt, tools, tool_choice, False, structured_response_format
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
),
)
choice = response.choices[0]
@@ -469,19 +494,31 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
yield self.invoke(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(
prompt, tools, tool_choice, True, structured_response_format
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
),
)
try:
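With litellm.Timeout and litellm.RateLimitError now re-raised as LLMTimeoutError and LLMRateLimitError, and with timeout_override threaded through invoke and stream, callers can tighten the timeout per call and degrade gracefully instead of surfacing a raw provider error. A minimal sketch, assuming an already-constructed llm implementing this interface; the checked_invoke helper and the None fallback are illustrative, not part of the change:

from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError

def checked_invoke(llm, prompt):
    # Tighten the timeout for this call only; the instance-level default
    # still applies at every other call site.
    try:
        return llm.invoke(
            prompt,
            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
        )
    except (LLMTimeoutError, LLMRateLimitError):
        # The caller decides what "degrade gracefully" means here, e.g.
        # accept a default answer or retry with a cheaper model.
        return None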

View File

@@ -81,6 +81,7 @@ class CustomModelServer(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -90,5 +91,6 @@ class CustomModelServer(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -90,12 +90,13 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format
prompt, tools, tool_choice, structured_response_format, timeout_override
)
@abc.abstractmethod
@@ -105,6 +106,7 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -114,12 +116,13 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
prompt, tools, tool_choice, structured_response_format, timeout_override
)
tokens = []
@@ -138,5 +141,6 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
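The abstract interface forwards timeout_override from invoke and stream into _invoke_implementation and _stream_implementation, so a concrete subclass only has to honor the argument at the point where it actually calls the model. A minimal sketch of that plumbing, assuming hypothetical _call_model and _default_timeout members and omitting the other abstract methods a real subclass must provide:

from collections.abc import Iterator

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage

class SketchLLM(LLM):  # LLM is the abstract class shown above
    def _invoke_implementation(
        self,
        prompt,
        tools=None,
        tool_choice=None,
        structured_response_format=None,
        timeout_override=None,
    ) -> BaseMessage:
        # A per-call override wins; otherwise fall back to the instance default.
        timeout = timeout_override or self._default_timeout
        return AIMessage(content=self._call_model(prompt, timeout=timeout))

    def _stream_implementation(
        self,
        prompt,
        tools=None,
        tool_choice=None,
        structured_response_format=None,
        timeout_override=None,
    ) -> Iterator[BaseMessage]:
        # Degenerate streaming: yield the single blocking response.
        yield self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format, timeout_override
        )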

View File

@@ -5,8 +5,6 @@ UNKNOWN_ANSWER = "I do not have enough information to answer this question."
NO_RECOVERED_DOCS = "No relevant information recovered"
YES = "yes"
NO = "no"
# Framing/Support/Template Prompts
HISTORY_FRAMING_PROMPT = f"""
For more context, here is the history of the conversation so far that preceded this question:

View File

@@ -1121,6 +1121,7 @@ export function ChatPage({
"Continue Generating (pick up exactly where you left off)",
});
};
const [uncaughtError, setUncaughtError] = useState<string | null>(null);
const onSubmit = async ({
messageIdToResend,
@@ -1549,8 +1550,23 @@ export function ChatPage({
}
);
} else if (Object.hasOwn(packet, "error")) {
error = (packet as StreamingError).error;
stackTrace = (packet as StreamingError).stack_trace;
if (
sub_questions.length > 0 &&
sub_questions
.filter((q) => q.level === 0)
.every((q) => q.is_stopped === true)
) {
setUncaughtError((packet as StreamingError).error);
updateChatState("input");
setAgenticGenerating(false);
setAlternativeGeneratingAssistant(null);
setSubmittedMessage("");
return;
// throw new Error((packet as StreamingError).error);
} else {
error = (packet as StreamingError).error;
stackTrace = (packet as StreamingError).stack_trace;
}
} else if (Object.hasOwn(packet, "message_id")) {
finalMessage = packet as BackendMessage;
} else if (Object.hasOwn(packet, "stop_reason")) {
@@ -2039,6 +2055,7 @@ export function ChatPage({
}
const data = await response.json();
router.push(data.redirect_url);
} catch (error) {
console.error("Error seeding chat from Slack:", error);
@@ -2633,6 +2650,7 @@ export function ChatPage({
{message.sub_questions &&
message.sub_questions.length > 0 ? (
<AgenticMessage
error={uncaughtError}
docSidebarToggled={
documentSidebarToggled &&
(selectedMessageForDocDisplay ==

View File

@@ -80,6 +80,7 @@ export const AgenticMessage = ({
agenticDocs,
secondLevelSubquestions,
toggleDocDisplay,
error,
}: {
docSidebarToggled?: boolean;
isImprovement?: boolean | null;
@@ -110,6 +111,7 @@ export const AgenticMessage = ({
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPresentingDocument?: (document: OnyxDocument) => void;
toggleDocDisplay?: (agentic: boolean) => void;
error?: string | null;
}) => {
const [noShowingMessage, setNoShowingMessage] = useState(isComplete);
@@ -483,11 +485,28 @@ export const AgenticMessage = ({
) : (
content
)}
{error && (
<p className="mt-2 text-red-700 text-sm my-auto">
{error}
</p>
)}
</div>
</div>
</>
) : isComplete ? null : (
<></>
) : isComplete ? (
error && (
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
{error}
</p>
)
) : (
<>
{error && (
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
{error}
</p>
)}
</>
)}
{handleFeedback &&
(isActive ? (

View File

@@ -185,6 +185,7 @@ export const AIMessage = ({
setPresentingDocument,
index,
toggledDocumentSidebar,
removePadding,
}: {
index?: number;
shared?: boolean;
@@ -213,6 +214,7 @@ export const AIMessage = ({
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPresentingDocument?: (document: OnyxDocument) => void;
removePadding?: boolean;
}) => {
const toolCallGenerating = toolCall && !toolCall.tool_result;
@@ -398,7 +400,9 @@ export const AIMessage = ({
<div
id={isComplete ? "onyx-ai-message" : undefined}
ref={trackedElementRef}
className={`py-5 ml-4 lg:px-5 relative flex `}
className={`py-5 ml-4 lg:px-5 relative flex
${removePadding && "!pl-24 -mt-12"}`}
>
<div
className={`mx-auto ${
@@ -407,11 +411,13 @@ export const AIMessage = ({
>
<div className={`lg:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
<div className="flex">
<AssistantIcon
className="mobile:hidden"
size={24}
assistant={alternativeAssistant || currentPersona}
/>
{!removePadding && (
<AssistantIcon
className="mobile:hidden"
size={24}
assistant={alternativeAssistant || currentPersona}
/>
)}
<div className="w-full">
<div className="max-w-message-max break-words">
@@ -588,7 +594,8 @@ export const AIMessage = ({
)}
</div>
{handleFeedback &&
{!removePadding &&
handleFeedback &&
(isActive ? (
<div
className={`