mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-07 16:02:45 +00:00
Compare commits
159 Commits
cli/v0.2.1
...
dr-merge
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16a4149a05 | ||
|
|
fe4e1b75fb | ||
|
|
06f0224622 | ||
|
|
d512912fcd | ||
|
|
30991db439 | ||
|
|
5f07f189bf | ||
|
|
c3cb351231 | ||
|
|
088156f0bf | ||
|
|
86c0bfec1c | ||
|
|
f5d1319ffa | ||
|
|
8e55e1c2ec | ||
|
|
5678ceeae9 | ||
|
|
82f18bec66 | ||
|
|
55b314fba6 | ||
|
|
ac062f43a5 | ||
|
|
9f783e2762 | ||
|
|
55c26f5f34 | ||
|
|
050524d8b0 | ||
|
|
0930712fdf | ||
|
|
486c820c46 | ||
|
|
9a568a10e1 | ||
|
|
91b91311d5 | ||
|
|
5dc5bf4ff5 | ||
|
|
f7fc1c827d | ||
|
|
e721618e95 | ||
|
|
04897e222e | ||
|
|
09ce320f7f | ||
|
|
99f79804bb | ||
|
|
062c9a4b73 | ||
|
|
fdd83eeaa9 | ||
|
|
f4828cfc18 | ||
|
|
f526171ca4 | ||
|
|
8b8802f9be | ||
|
|
dc3ca66f6d | ||
|
|
3a9f1accaf | ||
|
|
d264e880f2 | ||
|
|
9e0a22d866 | ||
|
|
bcbb075b96 | ||
|
|
5526e2b34b | ||
|
|
ee1541061c | ||
|
|
c1446d1508 | ||
|
|
acf9f615b1 | ||
|
|
ccde845e47 | ||
|
|
cad3517f85 | ||
|
|
e71489b2ff | ||
|
|
bfb6d632d2 | ||
|
|
191577fa19 | ||
|
|
a190934193 | ||
|
|
a7d140cb5d | ||
|
|
5e743515e9 | ||
|
|
4ef7e44c95 | ||
|
|
91bc1e93ba | ||
|
|
e7bd58cc85 | ||
|
|
dd18291d51 | ||
|
|
9a5ea03cd1 | ||
|
|
7b37e72b9d | ||
|
|
09d672ff22 | ||
|
|
b028b25737 | ||
|
|
07768d5484 | ||
|
|
5ca8ca2b1e | ||
|
|
62872e58ae | ||
|
|
eee3054b45 | ||
|
|
5f66a27c67 | ||
|
|
c21fa21958 | ||
|
|
cd6577c3ca | ||
|
|
16406f0ebd | ||
|
|
4ae5bb1e6b | ||
|
|
b0c95ec876 | ||
|
|
397d30c802 | ||
|
|
f13b08b461 | ||
|
|
e66245ec13 | ||
|
|
c64c6368c1 | ||
|
|
b2fe55c8f8 | ||
|
|
2b661441d7 | ||
|
|
f83f06228b | ||
|
|
fabfa8d166 | ||
|
|
994e7f7666 | ||
|
|
c81a7e1ef2 | ||
|
|
1d7d2f06d8 | ||
|
|
916d6cb119 | ||
|
|
6d3542ded1 | ||
|
|
e5dbfc34c3 | ||
|
|
1aad7f44d2 | ||
|
|
a0d6d0b922 | ||
|
|
588023a1f6 | ||
|
|
e4c2427728 | ||
|
|
bf77da26fc | ||
|
|
abfecde097 | ||
|
|
3f4936ad0a | ||
|
|
3b8d16a136 | ||
|
|
322e8668da | ||
|
|
d1dcad60d6 | ||
|
|
7b3bdbdf83 | ||
|
|
8b09fb0cef | ||
|
|
a2dd1bbf4f | ||
|
|
828231815a | ||
|
|
d48cbc2b79 | ||
|
|
991bd4f8bf | ||
|
|
74418b84a2 | ||
|
|
df1c40c791 | ||
|
|
c253844500 | ||
|
|
e972fb3e07 | ||
|
|
726211c27d | ||
|
|
c0435ddfd6 | ||
|
|
48dc934c35 | ||
|
|
3a575a92d5 | ||
|
|
de4a9e4687 | ||
|
|
c330152417 | ||
|
|
dca39f27a6 | ||
|
|
d3cc27846a | ||
|
|
fedc665b88 | ||
|
|
614672f357 | ||
|
|
6aca9ee005 | ||
|
|
f9f64fb1a5 | ||
|
|
4a63e631cd | ||
|
|
3d5586d623 | ||
|
|
6c4eb17b5d | ||
|
|
0917d9acd3 | ||
|
|
89ea0f8d48 | ||
|
|
31ae6f1eb1 | ||
|
|
1b8d246afb | ||
|
|
05e55559d8 | ||
|
|
241b8d062c | ||
|
|
6359d2f2d6 | ||
|
|
83325f9012 | ||
|
|
0b26ed602d | ||
|
|
2b69d1ba52 | ||
|
|
27cd1d44dc | ||
|
|
b5ddf31742 | ||
|
|
ce1c80148b | ||
|
|
bb95c46015 | ||
|
|
e8a593c315 | ||
|
|
bb1b12988c | ||
|
|
72bbcabedf | ||
|
|
2ee98ba795 | ||
|
|
594bbdb167 | ||
|
|
d5c67b6f50 | ||
|
|
9c7638ceba | ||
|
|
b1488ddccc | ||
|
|
d9a9818b9a | ||
|
|
4bd3b8b0bb | ||
|
|
da3979fc41 | ||
|
|
ffed8b4300 | ||
|
|
5eea47cb1c | ||
|
|
c830364c15 | ||
|
|
04f3ba1f3d | ||
|
|
84f76fbee7 | ||
|
|
00aeb3b280 | ||
|
|
8c30085a9e | ||
|
|
419e82f9f4 | ||
|
|
8330e5d8f4 | ||
|
|
e06c60a1a3 | ||
|
|
e7eef67893 | ||
|
|
b5209edffa | ||
|
|
07ad4dc022 | ||
|
|
06e1a2c1a5 | ||
|
|
083c152878 | ||
|
|
06f11a0a06 | ||
|
|
fabfcddadb |
@@ -0,0 +1,91 @@
|
||||
"""add research agent database tables and chat message research fields
|
||||
|
||||
Revision ID: 5ae8240accb3
|
||||
Revises: 62c3a055a141
|
||||
Create Date: 2025-08-06 14:29:24.691388
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5ae8240accb3"
|
||||
down_revision = "62c3a055a141"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add research columns to chat_message and create the research-agent tables.

    Adds nullable ``research_type`` / ``research_plan`` columns to
    ``chat_message``, then creates ``research_agent_iteration`` and
    ``research_agent_iteration_sub_step`` to record each step of a
    deep-research run.
    """
    # Add research_type and research_plan columns to chat_message table
    op.add_column(
        "chat_message",
        sa.Column("research_type", sa.String(), nullable=True),
    )
    op.add_column(
        "chat_message",
        # JSONB so the plan can be queried natively in Postgres
        sa.Column("research_plan", postgresql.JSONB(), nullable=True),
    )

    # Create research_agent_iteration table
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            # deleting the originating chat message removes its iterations too
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )

    # Create research_agent_iteration_sub_step table
    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            # NOTE(review): despite the name, this is a self-reference to
            # another sub-step (parent/child sub-steps), not to a question.
            "parent_question_id",
            sa.Integer(),
            sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
            nullable=True,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column(
            "sub_step_tool_id",
            sa.Integer(),
            # no ondelete: tool rows are presumably not deleted in practice — confirm
            sa.ForeignKey("tool.id"),
            nullable=True,
        ),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert upgrade(): drop the research-agent tables and chat_message columns."""
    # Drop tables in reverse order (sub_step has a FK into nothing dropped here,
    # but dropping children before parents is the safe convention)
    op.drop_table("research_agent_iteration_sub_step")
    op.drop_table("research_agent_iteration")

    # Remove columns from chat_message table
    op.drop_column("chat_message", "research_plan")
    op.drop_column("chat_message", "research_type")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add research_answer_purpose to chat_message
|
||||
|
||||
Revision ID: f8a9b2c3d4e5
|
||||
Revises: 5ae8240accb3
|
||||
Create Date: 2025-01-27 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f8a9b2c3d4e5"
|
||||
down_revision = "5ae8240accb3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable research_answer_purpose column to chat_message."""
    # Add research_answer_purpose column to chat_message table
    op.add_column(
        "chat_message",
        sa.Column("research_answer_purpose", sa.String(), nullable=True),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert upgrade(): drop the research_answer_purpose column."""
    # Remove research_answer_purpose column from chat_message table
    op.drop_column("chat_message", "research_answer_purpose")
|
||||
@@ -29,7 +29,6 @@ from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
@@ -48,6 +47,7 @@ from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -6,10 +6,8 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -19,6 +17,8 @@ from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
|
||||
12
backend/onyx/agents/agent_search/basic/models.py
Normal file
12
backend/onyx/agents/agent_search/basic/models.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class BasicSearchProcessedStreamResults(BaseModel):
    """Accumulated output of processing an LLM response stream for basic search."""

    # Aggregated tool-call chunk; empty content when the LLM answered directly.
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    # Full verbal answer text, if one was streamed.
    full_answer: str | None = None
    # NOTE(review): pydantic copies field defaults per instance, so the
    # shared []/AIMessageChunk defaults here should be safe — confirm for
    # the pydantic version in use.
    cited_references: list[InferenceSection] = []
    retrieved_documents: list[LlmDoc] = []
|
||||
@@ -6,6 +6,9 @@ from pydantic import BaseModel
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
# States contain values that change over the course of graph execution,
|
||||
# Config is for values that are set at the start and never change.
|
||||
@@ -18,11 +21,15 @@ class BasicInput(BaseModel):
|
||||
# Langgraph needs a nonempty input, but we pass in all static
|
||||
# data through a RunnableConfig.
|
||||
unused: bool = True
|
||||
query_override: str | None = None
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class BasicOutput(TypedDict):
    """Final outputs of the basic agent graph run."""

    # Aggregated tool-call message (empty content if no tool was called).
    tool_call_chunk: AIMessageChunk
    full_answer: str | None
    cited_references: list[InferenceSection] | None
    retrieved_documents: list[LlmDoc] | None
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
@@ -5,7 +5,9 @@ from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
@@ -13,6 +15,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -22,9 +27,12 @@ def process_llm_stream(
|
||||
messages: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
writer: StreamWriter,
|
||||
ind: int,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
generate_final_answer: bool = False,
|
||||
chat_message_id: str | None = None,
|
||||
) -> BasicSearchProcessedStreamResults:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
if final_search_results and displayed_search_results:
|
||||
@@ -37,6 +45,7 @@ def process_llm_stream(
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
full_answer = ""
|
||||
start_final_answer_streaming_set = False
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for message in messages:
|
||||
@@ -54,11 +63,53 @@ def process_llm_stream(
|
||||
tool_call_chunk += message # type: ignore
|
||||
elif should_stream_answer:
|
||||
for response_part in answer_handler.handle_response_part(message, []):
|
||||
write_custom_event(
|
||||
"basic_response",
|
||||
response_part,
|
||||
writer,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(response_part, "answer_piece")
|
||||
and generate_final_answer
|
||||
and response_part.answer_piece
|
||||
):
|
||||
if chat_message_id is None:
|
||||
raise ValueError(
|
||||
"chat_message_id is required when generating final answer"
|
||||
)
|
||||
|
||||
if not start_final_answer_streaming_set:
|
||||
# Convert LlmDocs to SavedSearchDocs
|
||||
saved_search_docs = saved_search_docs_from_llm_docs(
|
||||
final_search_results
|
||||
)
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageStart(content="", final_documents=saved_search_docs),
|
||||
writer,
|
||||
)
|
||||
start_final_answer_streaming_set = True
|
||||
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageDelta(
|
||||
content=response_part.answer_piece, type="message_delta"
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
else:
|
||||
write_custom_event(
|
||||
ind,
|
||||
response_part,
|
||||
writer,
|
||||
)
|
||||
|
||||
if generate_final_answer and start_final_answer_streaming_set:
|
||||
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
|
||||
write_custom_event(
|
||||
ind,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
logger.debug(f"Full answer: {full_answer}")
|
||||
return cast(AIMessageChunk, tool_call_chunk)
|
||||
return BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ class CoreState(BaseModel):
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
current_step_nr: int = 1
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
|
||||
54
backend/onyx/agents/agent_search/dr/conditional_edges.py
Normal file
54
backend/onyx/agents/agent_search/dr/conditional_edges.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.graph import END
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
|
||||
|
||||
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    """Route from the clarifier/orchestrator to the next graph node.

    The last entry of ``state.tools_used`` names the next hop: either a
    ``DRPath`` value or a generic tool name (mapped to ``GENERIC_TOOL``).

    Raises:
        IndexError: if no tool has been recorded yet.
        ValueError: if the recorded path is ``CLARIFIER`` mid-iteration.
    """
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")

    # The recorded name is either a DRPath value or a generic tool name.
    latest_tool = state.tools_used[-1]
    try:
        resolved_path = DRPath(latest_tool)
    except ValueError:
        resolved_path = DRPath.GENERIC_TOOL

    # Terminate the graph.
    if resolved_path == DRPath.END:
        return END

    # Clarification only happens before iteration starts.
    if resolved_path == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")

    # A search/generation hop with no queries has nothing to do — go close.
    query_requiring_paths = (
        DRPath.INTERNAL_SEARCH,
        DRPath.INTERNET_SEARCH,
        DRPath.KNOWLEDGE_GRAPH,
        DRPath.IMAGE_GENERATION,
    )
    if resolved_path in query_requiring_paths and len(state.query_list) == 0:
        return DRPath.CLOSER

    return resolved_path
|
||||
|
||||
|
||||
def completeness_router(state: MainState) -> DRPath | str:
    """Decide whether the closer loops back to the orchestrator or ends the run."""
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")

    # Loop back only when the closer explicitly recorded the orchestrator.
    last_recorded = state.tools_used[-1]
    return (
        DRPath.ORCHESTRATOR
        if last_recorded == DRPath.ORCHESTRATOR.value
        else END
    )
|
||||
30
backend/onyx/agents/agent_search/dr/constants.py
Normal file
30
backend/onyx/agents/agent_search/dr/constants.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
|
||||
# Number of prior chat exchanges included when building prompts.
MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)

# Upper bound on search sub-queries executed in parallel per iteration.
MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)

# Prefixes that tag special orchestrator outputs in the message stream.
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "HIGH_LEVEL PLAN:"

# Relative cost estimate per tool; presumably charged against the
# DR_TIME_BUDGET_BY_TYPE budget below — confirm in the orchestrator.
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.INTERNET_SEARCH: 1.5,
    DRPath.IMAGE_GENERATION: 3.0,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}

# Total budget (in the cost units above) allotted per research type.
DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 6.0,
}
|
||||
114
backend/onyx/agents/agent_search/dr/dr_prompt_builder.py
Normal file
114
backend/onyx/agents/agent_search/dr/dr_prompt_builder.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
|
||||
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
|
||||
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
|
||||
|
||||
def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    """Build the orchestration prompt template for the given purpose.

    Selects the base prompt by ``purpose`` and ``research_type``, then
    pre-fills every tool-related placeholder: names, descriptions, costs,
    differentiation/question hints, and knowledge-graph type descriptions.

    Raises:
        ValueError: for purpose/research_type combinations without a
            template, or when the Knowledge Graph tool is available but
            entity/relationship type strings are missing.
    """
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())

    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )

    # Pairwise hints that help the LLM choose between similar tools.
    differentiation_lines = [
        TOOL_DIFFERENTIATION_HINTS[(first_tool, second_tool)]
        for first_tool in available_tools
        for second_tool in available_tools
        if (first_tool, second_tool) in TOOL_DIFFERENTIATION_HINTS
    ]
    tool_differentiation_hint_string = (
        "\n".join(differentiation_lines) or "(No differentiating hints available)"
    )
    # TODO: add tool deliniation pairs for custom tools as well

    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )

    if DRPath.KNOWLEDGE_GRAPH.value in available_tools:
        if not entity_types_string or not relationship_types_string:
            raise ValueError(
                "Entity types and relationship types must be provided if the Knowledge Graph is used."
            )
        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string,
            possible_relationships=relationship_types_string,
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"

    # THOUGHTFUL maps to the "FAST" prompt family; DEEP to the "DEEP" one.
    is_thoughtful = research_type == ResearchType.THOUGHTFUL

    if purpose == DRPromptPurpose.PLAN:
        if is_thoughtful:
            raise ValueError("plan generation is not supported for FAST time budget")
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if not is_thoughtful:
            raise ValueError(
                "reasoning is not separately required for DEEP time budget"
            )
        base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP:
        base_template = (
            ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
            if is_thoughtful
            else ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
        )

    elif purpose == DRPromptPurpose.CLARIFICATION:
        if is_thoughtful:
            raise ValueError("clarification is not supported for FAST time budget")
        base_template = GET_CLARIFICATION_PROMPT

    else:
        # for mypy, clearly a mypy bug
        raise ValueError(f"Invalid purpose: {purpose}")

    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )
|
||||
28
backend/onyx/agents/agent_search/dr/enums.py
Normal file
28
backend/onyx/agents/agent_search/dr/enums.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ResearchType(str, Enum):
    """Research type options for agent search operations"""

    # BASIC = "BASIC"
    # THOUGHTFUL gets a smaller time budget than DEEP and maps to the
    # "FAST" prompt family — see dr/constants.py and dr_prompt_builder.py.
    THOUGHTFUL = "THOUGHTFUL"
    DEEP = "DEEP"
|
||||
|
||||
|
||||
class ResearchAnswerPurpose(str, Enum):
    """Research answer purpose options for agent search operations"""

    # The final answer to the user's question.
    ANSWER = "ANSWER"
    # The agent is asking the user to clarify before researching.
    CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
|
||||
|
||||
|
||||
class DRPath(str, Enum):
    """Node names in the deep-research LangGraph; values double as node keys."""

    CLARIFIER = "Clarifier"
    ORCHESTRATOR = "Orchestrator"
    INTERNAL_SEARCH = "Internal Search"
    GENERIC_TOOL = "Generic Tool"
    KNOWLEDGE_GRAPH = "Knowledge Graph"
    INTERNET_SEARCH = "Internet Search"
    IMAGE_GENERATION = "Image Generation"
    CLOSER = "Closer"
    # Sentinel translated to langgraph END by decision_router.
    END = "End"
|
||||
80
backend/onyx/agents/agent_search/dr/graph_builder.py
Normal file
80
backend/onyx/agents/agent_search/dr/graph_builder.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
|
||||
from onyx.agents.agent_search.dr.conditional_edges import decision_router
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
|
||||
from onyx.agents.agent_search.dr.states import MainInput
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
|
||||
dr_basic_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
|
||||
dr_custom_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
|
||||
dr_image_generation_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_graph_builder import (
|
||||
dr_is_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.
    """

    graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###

    graph.add_node(DRPath.CLARIFIER, clarifier)
    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    # Each sub-agent is a compiled sub-graph that always returns control
    # to the orchestrator when finished.
    sub_agent_builders = [
        (DRPath.INTERNAL_SEARCH, dr_basic_search_graph_builder),
        (DRPath.KNOWLEDGE_GRAPH, dr_kg_search_graph_builder),
        (DRPath.INTERNET_SEARCH, dr_is_graph_builder),
        (DRPath.IMAGE_GENERATION, dr_image_generation_graph_builder),
        (DRPath.GENERIC_TOOL, dr_custom_tool_graph_builder),
    ]
    for path, builder in sub_agent_builders:
        graph.add_node(path, builder().compile())

    graph.add_node(DRPath.CLOSER, closer)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)

    # Both entry nodes route dynamically to the next tool/closer.
    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    for path, _ in sub_agent_builders:
        graph.add_edge(start_key=path, end_key=DRPath.ORCHESTRATOR)

    # The closer either loops back to the orchestrator or ends the run.
    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)

    return graph
|
||||
108
backend/onyx/agents/agent_search/dr/models.py
Normal file
108
backend/onyx/agents/agent_search/dr/models.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
class OrchestratorStep(BaseModel):
    """One planned step: which tool to run and the questions to feed it."""

    tool: str
    questions: list[str]
|
||||
|
||||
|
||||
class OrchestratorDecisonsNoPlan(BaseModel):
    """LLM decision output when no explicit plan is generated.

    NOTE(review): class name has a typo ("Decisons"); renaming would break
    importers, so it is left as-is.
    """

    reasoning: str
    next_step: OrchestratorStep
|
||||
|
||||
|
||||
class OrchestrationPlan(BaseModel):
    """High-level research plan produced by the orchestrator LLM."""

    reasoning: str
    plan: str
|
||||
|
||||
|
||||
class ClarificationGenerationResponse(BaseModel):
    """LLM verdict on whether the user's request needs clarification first."""

    clarification_needed: bool
    # Question to put to the user; presumably ignored when
    # clarification_needed is False — confirm at call sites.
    clarification_question: str
|
||||
|
||||
|
||||
class QueryEvaluationResponse(BaseModel):
    """LLM judgment on whether a query is permitted to run."""

    reasoning: str
    query_permitted: bool
|
||||
|
||||
|
||||
class OrchestrationClarificationInfo(BaseModel):
    """Pairs a clarification question with the user's eventual reply."""

    clarification_question: str
    # None until the user has responded.
    clarification_response: str | None = None
|
||||
|
||||
|
||||
class SearchAnswer(BaseModel):
    """Structured answer returned by a search sub-agent."""

    reasoning: str
    answer: str
    # Individual factual claims extracted from the answer, when available.
    claims: list[str] | None = None
|
||||
|
||||
|
||||
class TestInfoCompleteResponse(BaseModel):
    """LLM assessment of whether gathered information answers the question."""

    reasoning: str
    complete: bool
    # Missing pieces of information; presumably empty when complete — confirm.
    gaps: list[str]
|
||||
|
||||
|
||||
# TODO: revisit with custom tools implementation in v2
|
||||
# each tool should be a class with the attributes below, plus the actual tool implementation
|
||||
# this will also allow custom tools to have their own cost
|
||||
class OrchestratorTool(BaseModel):
    """Metadata the orchestrator uses to select and route to a tool."""

    tool_id: int
    name: str
    llm_path: str  # the path for the LLM to refer by
    path: DRPath  # the actual path in the graph
    description: str
    metadata: dict[str, str]
    # Relative cost estimate (see AVERAGE_TOOL_COSTS in dr/constants.py).
    cost: float
    tool_object: Tool | None = None  # None for CLOSER

    class Config:
        # Tool is not a pydantic model, so pydantic must accept it as-is.
        arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class IterationInstructions(BaseModel):
    """Orchestrator directions for one research iteration."""

    iteration_nr: int
    # High-level plan text; None when no plan was generated.
    plan: str | None
    reasoning: str
    purpose: str
|
||||
|
||||
|
||||
class IterationAnswer(BaseModel):
    """Result of one tool call within a research iteration."""

    tool: str
    tool_id: int
    iteration_nr: int
    # Index of this call within the iteration's batch of parallel calls.
    parallelization_nr: int
    question: str
    reasoning: str | None
    answer: str
    # Documents cited in the answer; presumably keyed by citation number —
    # confirm against the sub-agent that populates it.
    cited_documents: dict[int, InferenceSection]
    background_info: str | None = None
    claims: list[str] | None = None
    additional_data: dict[str, str] | None = None
|
||||
|
||||
|
||||
class AggregatedDRContext(BaseModel):
    """Combined context assembled from all iterations for answer generation."""

    context: str
    cited_documents: list[InferenceSection]
    # Flags internet-sourced documents; presumably keyed by document id —
    # confirm against the producer.
    is_internet_marker_dict: dict[str, bool]
    global_iteration_responses: list[IterationAnswer]
|
||||
|
||||
|
||||
# Prompt flavors the orchestrator can request from the prompt builder.
# Built with the Enum functional API; str mixin keeps members usable as
# plain strings, matching the original class-based definition.
DRPromptPurpose = Enum(
    "DRPromptPurpose",
    [
        ("PLAN", "PLAN"),
        ("NEXT_STEP", "NEXT_STEP"),
        ("NEXT_STEP_REASONING", "NEXT_STEP_REASONING"),
        ("NEXT_STEP_PURPOSE", "NEXT_STEP_PURPOSE"),
        ("CLARIFICATION", "CLARIFICATION"),
    ],
    type=str,
)
|
||||
|
||||
|
||||
class BaseSearchProcessingResponse(BaseModel):
    """LLM preprocessing of a search request before execution."""

    specified_source_types: list[str]
    rewritten_query: str
    # NOTE(review): the expected format of this filter string is not visible
    # here — confirm against the consumer before relying on it.
    time_filter: str
|
||||
572
backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py
Normal file
572
backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py
Normal file
@@ -0,0 +1,572 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.dr_prompt_builder import (
|
||||
get_dr_prompt_orchestration_templates,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import ClarificationGenerationResponse
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationSetup
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.constants import DocumentSourceDescription
|
||||
from onyx.db.connector import fetch_unique_document_sources
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import GENERAL_DR_ANSWER_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Module-level logger shared by this node module.
logger = setup_logger()
|
||||
|
||||
|
||||
def _format_tool_name(tool_name: str) -> str:
|
||||
"""Convert tool name to LLM-friendly format."""
|
||||
name = tool_name.replace(" ", "_")
|
||||
# take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
|
||||
name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
|
||||
return name.upper()
|
||||
|
||||
|
||||
def _get_available_tools(
    graph_config: GraphConfig, kg_enabled: bool
) -> dict[str, OrchestratorTool]:
    """Map the configured tools into OrchestratorTool entries keyed by LLM path.

    Known tool types get their dedicated DRPath; unsupported tool types are
    skipped with a warning. The CLOSER pseudo-tool is always appended.

    Raises:
        ValueError: if the knowledge graph tool is configured without the
            internal search tool.
    """
    available_tools: dict[str, OrchestratorTool] = {}
    for tool in graph_config.tooling.tools:
        # Default entry: treat the tool as a generic tool; the specialized
        # branches below override llm_path/path (and description/cost after).
        tool_info = OrchestratorTool(
            tool_id=tool.id,
            name=tool.name,
            llm_path=_format_tool_name(tool.name),
            path=DRPath.GENERIC_TOOL,
            description=tool.description,
            metadata={},
            cost=1.0,
            tool_object=tool,
        )

        # NOTE(review): branch order is significant if these tool classes share
        # a hierarchy (InternetSearchTool is tested before SearchTool) — confirm.
        if isinstance(tool, CustomTool):
            # tool_info.metadata["summary_signature"] = CUSTOM_TOOL_RESPONSE_ID
            pass
        elif isinstance(tool, InternetSearchTool):
            # tool_info.metadata["summary_signature"] = (
            #     INTERNET_SEARCH_RESPONSE_SUMMARY_ID
            # )
            tool_info.llm_path = DRPath.INTERNET_SEARCH.value
            tool_info.path = DRPath.INTERNET_SEARCH
        elif isinstance(tool, SearchTool):
            # tool_info.metadata["summary_signature"] = SEARCH_RESPONSE_SUMMARY_ID
            tool_info.llm_path = DRPath.INTERNAL_SEARCH.value
            tool_info.path = DRPath.INTERNAL_SEARCH
        elif isinstance(tool, KnowledgeGraphTool):
            # KG tool is only usable when the KG feature is enabled.
            if not kg_enabled:
                logger.warning("KG must be enabled to use KG search tool, skipping")
                continue
            tool_info.llm_path = DRPath.KNOWLEDGE_GRAPH.value
            tool_info.path = DRPath.KNOWLEDGE_GRAPH
        elif isinstance(tool, ImageGenerationTool):
            tool_info.llm_path = DRPath.IMAGE_GENERATION.value
            tool_info.path = DRPath.IMAGE_GENERATION
        else:
            logger.warning(f"Tool {tool.name} ({type(tool)}) is not supported")
            continue

        # Prefer the curated per-path description/cost over the tool's own;
        # GENERIC_TOOL entries fall back to the tool's own description.
        tool_info.description = TOOL_DESCRIPTION.get(tool_info.path, tool.description)
        tool_info.cost = AVERAGE_TOOL_COSTS[tool_info.path]

        # TODO: handle custom tools with same name as other tools (e.g., CLOSER)
        available_tools[tool_info.llm_path] = tool_info

    # make sure KG isn't enabled without internal search
    if (
        DRPath.KNOWLEDGE_GRAPH.value in available_tools
        and DRPath.INTERNAL_SEARCH.value not in available_tools
    ):
        raise ValueError(
            "The Knowledge Graph is not supported without internal search tool"
        )

    # add CLOSER tool, which is always available
    available_tools[DRPath.CLOSER.value] = OrchestratorTool(
        tool_id=-1,  # sentinel id: CLOSER is not a DB-backed tool
        name="closer",
        llm_path=DRPath.CLOSER.value,
        path=DRPath.CLOSER,
        description=TOOL_DESCRIPTION[DRPath.CLOSER],
        metadata={},
        cost=0.0,  # closing out the research consumes no time budget
        tool_object=None,
    )

    return available_tools
|
||||
|
||||
|
||||
def _get_existing_clarification_request(
    graph_config: GraphConfig,
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
    """
    Returns the clarification info, original question, and updated chat history if
    a clarification request and response exists, otherwise returns None.
    """
    # check for clarification request and response in message history:
    # only applies when the *latest* assistant message was a clarification request
    previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history

    if len(previous_raw_messages) == 0 or (
        previous_raw_messages[-1].research_answer_purpose
        != ResearchAnswerPurpose.CLARIFICATION_REQUEST
    ):
        return None

    # get the clarification request and response; the current user query is
    # interpreted as the response to the pending clarification question
    previous_messages = graph_config.inputs.prompt_builder.message_history
    last_message = previous_raw_messages[-1].message

    clarification = OrchestrationClarificationInfo(
        clarification_question=last_message.strip(),
        clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
    )
    # fallbacks, used if no prior human message is found in the loop below
    original_question = graph_config.inputs.prompt_builder.raw_user_query
    chat_history_string = "(No chat history yet available)"

    # get the original user query and chat history string before the original query
    # e.g., if history = [user query, assistant clarification request, user clarification response],
    # previous_messages = [user query, assistant clarification request], we want the user query
    for i, message in enumerate(reversed(previous_messages), 1):
        if (
            isinstance(message, HumanMessage)
            and message.content
            and isinstance(message.content, str)
        ):
            original_question = message.content
            # `i` counts from the end, so [:-i] is the history strictly
            # before the message just found (the original question)
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history[:-i],
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )
            break

    return clarification, original_question, chat_history_string
|
||||
|
||||
|
||||
# OpenAI-style function spec for a single catch-all tool. It is offered to the
# LLM only to detect WHETHER any external knowledge/action is needed at all
# (see the tool-calling branch in `clarifier`); actual tool routing happens
# later in the orchestrator.
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
    "type": "function",
    "function": {
        "name": "run_any_knowledge_retrieval_and_any_action_tool",
        "description": "Use this tool to get any external information \
that is relevant to the question, or for any action to be taken.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "The request to be made to the tool",
                },
            },
            "required": ["request"],
        },
    },
}
|
||||
|
||||
|
||||
def clarifier(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationSetup:
    """
    Entry node of the DR graph. Decides whether the question can be answered
    directly (no tools needed -> answer and END), whether a clarification
    question should be asked of the user (non-THOUGHTFUL research only), or
    whether to proceed to the orchestrator with the assembled tool set.

    Perform a quick search on the question as is and see whether a set of clarification
    questions is needed. For now this is based on the models
    """

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])

    use_tool_calling_llm = graph_config.tooling.using_tool_calling_llm
    db_session = graph_config.persistence.db_session

    original_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    message_id = graph_config.persistence.message_id

    # get the connected tools and format for the Deep Research flow
    kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
    available_tools = _get_available_tools(graph_config, kg_enabled)

    # tools other than internal search / KG; used below to decide whether the
    # question can be answered without any tooling at all
    non_internal_search_tools = [
        tool
        for tool in available_tools.values()
        if tool.path != DRPath.INTERNAL_SEARCH and tool.path != DRPath.KNOWLEDGE_GRAPH
    ]

    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)

    # NOTE(review): db_session was already assigned above — redundant re-assignment
    db_session = graph_config.persistence.db_session
    active_source_types = fetch_unique_document_sources(db_session)

    # if not active_source_types:
    #     raise ValueError("No active source types found")

    active_source_types_descriptions = [
        DocumentSourceDescription[source_type] for source_type in active_source_types
    ]

    # Resolve system/task prompts from the persona, falling back to defaults.
    if graph_config.inputs.persona and len(graph_config.inputs.persona.prompts) > 0:
        assistant_system_prompt = (
            graph_config.inputs.persona.prompts[0].system_prompt
            or DEFAULT_DR_SYSTEM_PROMPT
        ) + "\n\n"
        if graph_config.inputs.persona.prompts[0].task_prompt:
            assistant_task_prompt = (
                "\n\nHere are more specifications from the user:\n\n"
                + graph_config.inputs.persona.prompts[0].task_prompt
            )
        else:
            assistant_task_prompt = ""

    else:
        assistant_system_prompt = DEFAULT_DR_SYSTEM_PROMPT + "\n\n"
        assistant_task_prompt = ""

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    # Case 1: no usable tools (or only internal search with no indexed
    # sources) — answer directly from the LLM, persist it, and END the flow.
    if len(available_tools) == 0 or (
        len(non_internal_search_tools) == 0 and len(active_source_types) == 0
    ):
        answer_prompt = GENERAL_DR_ANSWER_PROMPT.build(
            question=original_question, chat_history_string=chat_history_string
        )

        stream = graph_config.tooling.primary_llm.stream(
            prompt=create_question_prompt(
                assistant_system_prompt, answer_prompt + assistant_task_prompt
            ),
            tools=None,
            tool_choice=(None),
            structured_response_format=None,
        )

        full_response = process_llm_stream(
            messages=stream,
            should_stream_answer=True,
            writer=writer,
            ind=0,
            generate_final_answer=True,
            chat_message_id=str(graph_config.persistence.chat_session_id),
        )

        if isinstance(full_response.full_answer, str):
            full_answer = full_response.full_answer
        else:
            full_answer = None

        update_db_session_with_messages(
            db_session=db_session,
            chat_message_id=message_id,
            chat_session_id=str(graph_config.persistence.chat_session_id),
            is_agentic=graph_config.behavior.use_agentic_search,
            message=full_answer,
            update_parent_message=True,
            research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        )

        db_session.commit()

        # END path: the answer has already been streamed and persisted
        return OrchestrationSetup(
            original_question=original_question,
            chat_history_string="",
            tools_used=[DRPath.END.value],
            query_list=[],
            assistant_system_prompt=assistant_system_prompt,
            assistant_task_prompt=assistant_task_prompt,
        )

    # Case 2 (non-tool-calling LLM): ask the LLM to either answer directly or
    # emit nothing; any non-whitespace output is treated as the final answer.
    elif not use_tool_calling_llm:
        decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
            question=original_question, chat_history_string=chat_history_string
        )

        initial_decision_tokens, _, _ = run_with_timeout(
            80,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    assistant_system_prompt + EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING,
                    decision_prompt + assistant_task_prompt,
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                max_tokens=None,
            ),
        )

        initial_decision_str = cast(str, merge_content(*initial_decision_tokens))

        if len(initial_decision_str.replace(" ", "")) > 0:
            # NOTE(review): unlike the other direct-answer branches, this one
            # does not persist the streamed answer via
            # update_db_session_with_messages — confirm this is intended.
            return OrchestrationSetup(
                original_question=original_question,
                chat_history_string="",
                tools_used=[DRPath.END.value],
                query_list=[],
                assistant_system_prompt=assistant_system_prompt,
                assistant_task_prompt=assistant_task_prompt,
            )

    # Case 3 (tool-calling LLM): offer a single artificial catch-all tool; if
    # the LLM does not call it, the streamed text is the final answer.
    else:

        decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
            question=original_question, chat_history_string=chat_history_string
        )

        stream = graph_config.tooling.primary_llm.stream(
            prompt=create_question_prompt(
                assistant_system_prompt + EVAL_SYSTEM_PROMPT_W_TOOL_CALLING,
                decision_prompt + assistant_task_prompt,
            ),
            tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
            tool_choice=(None),
            structured_response_format=graph_config.inputs.structured_response_format,
        )

        full_response = process_llm_stream(
            messages=stream,
            should_stream_answer=True,
            writer=writer,
            ind=0,
            generate_final_answer=True,
            chat_message_id=str(graph_config.persistence.chat_session_id),
        )

        # no tool call -> the model chose to answer directly; persist and END
        if len(full_response.ai_message_chunk.tool_calls) == 0:

            if isinstance(full_response.full_answer, str):
                full_answer = full_response.full_answer
            else:
                full_answer = None

            update_db_session_with_messages(
                db_session=db_session,
                chat_message_id=message_id,
                chat_session_id=str(graph_config.persistence.chat_session_id),
                is_agentic=graph_config.behavior.use_agentic_search,
                message=full_answer,
                update_parent_message=True,
                research_answer_purpose=ResearchAnswerPurpose.ANSWER,
            )

            db_session.commit()

            return OrchestrationSetup(
                original_question=original_question,
                chat_history_string="",
                tools_used=[DRPath.END.value],
                query_list=[],
                assistant_system_prompt=assistant_system_prompt,
                assistant_task_prompt=assistant_task_prompt,
            )

    # Continue, as external knowledge is required.

    clarification = None

    if research_type != ResearchType.THOUGHTFUL:
        # Deep research: first check whether the user is currently answering a
        # pending clarification request from a previous turn.
        result = _get_existing_clarification_request(graph_config)
        if result is not None:
            clarification, original_question, chat_history_string = result
        else:
            # generate clarification questions if needed
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history,
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )

            base_clarification_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.CLARIFICATION,
                research_type,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )
            clarification_prompt = base_clarification_prompt.build(
                question=original_question,
                chat_history_string=chat_history_string,
            )

            try:
                clarification_response = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=create_question_prompt(
                        assistant_system_prompt, clarification_prompt
                    ),
                    schema=ClarificationGenerationResponse,
                    timeout_override=25,
                    # max_tokens=1500,
                )
            except Exception as e:
                logger.error(f"Error in clarification generation: {e}")
                raise e

            if (
                clarification_response.clarification_needed
                and clarification_response.clarification_question
            ):
                clarification = OrchestrationClarificationInfo(
                    clarification_question=clarification_response.clarification_question,
                    clarification_response=None,
                )
                # stream the clarification question to the client as a normal
                # message (start / delta / section end / overall stop)
                write_custom_event(
                    0,
                    MessageStart(
                        content="",
                        final_documents=None,
                    ),
                    writer,
                )

                write_custom_event(
                    0,
                    MessageDelta(
                        content=clarification_response.clarification_question,
                        type="message_delta",
                    ),
                    writer,
                )

                write_custom_event(
                    0,
                    SectionEnd(
                        type="section_end",
                    ),
                    writer,
                )

                write_custom_event(
                    1,
                    OverallStop(),
                    writer,
                )

                update_db_session_with_messages(
                    db_session=db_session,
                    chat_message_id=message_id,
                    chat_session_id=str(graph_config.persistence.chat_session_id),
                    is_agentic=graph_config.behavior.use_agentic_search,
                    message=clarification_response.clarification_question,
                    update_parent_message=True,
                    research_type=research_type,
                    research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
                )

                db_session.commit()

    else:
        # THOUGHTFUL research never asks clarification questions
        chat_history_string = (
            get_chat_history_string(
                graph_config.inputs.prompt_builder.message_history,
                MAX_CHAT_HISTORY_MESSAGES,
            )
            or "(No chat history yet available)"
        )

    # A clarification question without a response means we must stop and wait
    # for the user; otherwise proceed to the orchestrator.
    if (
        clarification
        and clarification.clarification_question
        and clarification.clarification_response is None
    ):

        # NOTE(review): when the clarification was just generated above, this
        # persists the same clarification request a second time (both blocks
        # call update_db_session_with_messages + commit) — confirm intended.
        update_db_session_with_messages(
            db_session=db_session,
            chat_message_id=message_id,
            chat_session_id=str(graph_config.persistence.chat_session_id),
            is_agentic=graph_config.behavior.use_agentic_search,
            message=clarification.clarification_question,
            update_parent_message=True,
            research_type=research_type,
            research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
        )

        db_session.commit()

        next_tool = DRPath.END.value
    else:
        next_tool = DRPath.ORCHESTRATOR.value

    return OrchestrationSetup(
        original_question=original_question,
        chat_history_string=chat_history_string,
        tools_used=[next_tool],
        query_list=[],
        iteration_nr=0,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="clarifier",
                node_start_time=node_start_time,
            )
        ],
        clarification=clarification,
        available_tools=available_tools,
        active_source_types=active_source_types,
        active_source_types_descriptions="\n".join(active_source_types_descriptions),
        assistant_system_prompt=assistant_system_prompt,
        assistant_task_prompt=assistant_task_prompt,
    )
|
||||
445
backend/onyx/agents/agent_search/dr/nodes/dr_a1_orchestrator.py
Normal file
445
backend/onyx/agents/agent_search/dr/nodes/dr_a1_orchestrator.py
Normal file
@@ -0,0 +1,445 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
|
||||
from onyx.agents.agent_search.dr.constants import HIGH_LEVEL_PLAN_PREFIX
|
||||
from onyx.agents.agent_search.dr.dr_prompt_builder import (
|
||||
get_dr_prompt_orchestration_templates,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorDecisonsNoPlan
|
||||
from onyx.agents.agent_search.dr.states import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import create_tool_call_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Module-level logger shared by this node module.
logger = setup_logger()
|
||||
|
||||
|
||||
def orchestrator(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> OrchestrationUpdate:
|
||||
"""
|
||||
LangGraph node to decide the next step in the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.original_question
|
||||
if not question:
|
||||
raise ValueError("Question is required for orchestrator")
|
||||
|
||||
plan_of_record = state.plan_of_record
|
||||
clarification = state.clarification
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
iteration_nr = state.iteration_nr + 1
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
research_type = graph_config.behavior.research_type
|
||||
remaining_time_budget = state.remaining_time_budget
|
||||
chat_history_string = state.chat_history_string or "(No chat history yet available)"
|
||||
answer_history_string = (
|
||||
aggregate_context(state.iteration_responses, include_documents=True).context
|
||||
or "(No answer history yet available)"
|
||||
)
|
||||
available_tools = state.available_tools or {}
|
||||
|
||||
questions = [
|
||||
f"{iteration_response.tool}: {iteration_response.question}"
|
||||
for iteration_response in state.iteration_responses
|
||||
if len(iteration_response.question) > 0
|
||||
]
|
||||
|
||||
question_history_string = (
|
||||
"\n".join(f" - {question}" for question in questions)
|
||||
if questions
|
||||
else "(No question history yet available)"
|
||||
)
|
||||
|
||||
prompt_question = get_prompt_question(question, clarification)
|
||||
|
||||
gaps_str = (
|
||||
("\n - " + "\n - ".join(state.gaps))
|
||||
if state.gaps
|
||||
else "(No explicit gaps were pointed out so far)"
|
||||
)
|
||||
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
|
||||
# default to closer
|
||||
next_tool = DRPath.CLOSER.value
|
||||
query_list = ["Answer the question with the information you have."]
|
||||
decision_prompt = None
|
||||
|
||||
reasoning_result = "(No reasoning result provided yet.)"
|
||||
tool_calls_string = "(No tool calls provided yet.)"
|
||||
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
if iteration_nr == 1:
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]
|
||||
|
||||
elif iteration_nr > 1:
|
||||
# for each iteration past the first one, we need to see whether we
|
||||
# have enough information to answer the question.
|
||||
# if we do, we can stop the iteration and return the answer.
|
||||
# if we do not, we need to continue the iteration.
|
||||
|
||||
base_reasoning_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP_REASONING,
|
||||
ResearchType.THOUGHTFUL,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
|
||||
reasoning_prompt = base_reasoning_prompt.build(
|
||||
question=question,
|
||||
chat_history_string=chat_history_string,
|
||||
answer_history_string=answer_history_string,
|
||||
iteration_nr=str(iteration_nr),
|
||||
remaining_time_budget=str(remaining_time_budget),
|
||||
)
|
||||
|
||||
reasoning_tokens: list[str] = [""]
|
||||
|
||||
reasoning_tokens, _, _ = run_with_timeout(
|
||||
80,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
reasoning_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
answer_piece="reasoning_delta",
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
reasoning_result = cast(str, merge_content(*reasoning_tokens))
|
||||
|
||||
if SUFFICIENT_INFORMATION_STRING in reasoning_result:
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.CLOSER.value],
|
||||
current_step_nr=current_step_nr,
|
||||
query_list=[],
|
||||
iteration_nr=iteration_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=None,
|
||||
reasoning=reasoning_result,
|
||||
purpose="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
base_decision_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP,
|
||||
ResearchType.THOUGHTFUL,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
decision_prompt = base_decision_prompt.build(
|
||||
question=question,
|
||||
chat_history_string=chat_history_string,
|
||||
answer_history_string=answer_history_string,
|
||||
iteration_nr=str(iteration_nr),
|
||||
remaining_time_budget=str(remaining_time_budget),
|
||||
reasoning_result=reasoning_result,
|
||||
)
|
||||
|
||||
if remaining_time_budget > 0:
|
||||
if decision_prompt is None:
|
||||
raise ValueError("Decision prompt is required")
|
||||
try:
|
||||
orchestrator_action = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
decision_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=OrchestratorDecisonsNoPlan,
|
||||
timeout_override=35,
|
||||
# max_tokens=2500,
|
||||
)
|
||||
next_step = orchestrator_action.next_step
|
||||
next_tool = next_step.tool
|
||||
query_list = [q for q in (next_step.questions or [])]
|
||||
|
||||
tool_calls_string = create_tool_call_string(next_tool, query_list)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in approach extraction: {e}")
|
||||
raise e
|
||||
|
||||
remaining_time_budget -= available_tools[next_tool].cost
|
||||
else:
|
||||
if iteration_nr == 1 and not plan_of_record:
|
||||
# by default, we start a new iteration, but if there is a feedback request,
|
||||
# we start a new iteration 0 again (set a bit later)
|
||||
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]
|
||||
|
||||
base_plan_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.PLAN,
|
||||
ResearchType.DEEP,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
plan_generation_prompt = base_plan_prompt.build(
|
||||
question=prompt_question,
|
||||
chat_history_string=chat_history_string,
|
||||
)
|
||||
|
||||
try:
|
||||
plan_of_record = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
plan_generation_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=OrchestrationPlan,
|
||||
timeout_override=25,
|
||||
# max_tokens=3000,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in plan generation: {e}")
|
||||
raise
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(
|
||||
type="reasoning_start",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningDelta(
|
||||
reasoning=f"{HIGH_LEVEL_PLAN_PREFIX} {plan_of_record.plan}\n\n",
|
||||
type="reasoning_delta",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
current_step_nr += 1
|
||||
|
||||
if not plan_of_record:
|
||||
raise ValueError(
|
||||
"Plan information is required for iterative decision making"
|
||||
)
|
||||
|
||||
base_decision_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP,
|
||||
ResearchType.DEEP,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
decision_prompt = base_decision_prompt.build(
|
||||
answer_history_string=answer_history_string,
|
||||
question_history_string=question_history_string,
|
||||
question=prompt_question,
|
||||
iteration_nr=str(iteration_nr),
|
||||
current_plan_of_record_string=plan_of_record.plan,
|
||||
chat_history_string=chat_history_string,
|
||||
remaining_time_budget=str(remaining_time_budget),
|
||||
gaps=gaps_str,
|
||||
)
|
||||
|
||||
if remaining_time_budget > 0:
|
||||
if decision_prompt is None:
|
||||
raise ValueError("Decision prompt is required")
|
||||
try:
|
||||
orchestrator_action = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
decision_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=OrchestratorDecisonsNoPlan,
|
||||
timeout_override=15,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
next_step = orchestrator_action.next_step
|
||||
next_tool = next_step.tool
|
||||
query_list = [q for q in (next_step.questions or [])]
|
||||
reasoning_result = orchestrator_action.reasoning
|
||||
|
||||
tool_calls_string = create_tool_call_string(next_tool, query_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in approach extraction: {e}")
|
||||
raise e
|
||||
|
||||
remaining_time_budget -= available_tools[next_tool].cost
|
||||
else:
|
||||
reasoning_result = "Time to wrap up."
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(
|
||||
type="reasoning_start",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningDelta(
|
||||
reasoning=reasoning_result,
|
||||
type="reasoning_delta",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP_PURPOSE,
|
||||
ResearchType.DEEP,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
orchestration_next_step_purpose_prompt = base_next_step_purpose_prompt.build(
|
||||
question=prompt_question,
|
||||
reasoning_result=reasoning_result,
|
||||
tool_calls=tool_calls_string,
|
||||
)
|
||||
|
||||
purpose_tokens: list[str] = [""]
|
||||
|
||||
try:
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(
|
||||
type="reasoning_start",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
purpose_tokens, _, _ = run_with_timeout(
|
||||
80,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
orchestration_next_step_purpose_prompt
|
||||
+ (assistant_task_prompt or ""),
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
answer_piece="reasoning_delta",
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in orchestration next step purpose: {e}")
|
||||
raise e
|
||||
|
||||
purpose = cast(str, merge_content(*purpose_tokens))
|
||||
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[next_tool],
|
||||
query_list=query_list or [],
|
||||
iteration_nr=iteration_nr,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=plan_of_record.plan if plan_of_record else None,
|
||||
reasoning=reasoning_result,
|
||||
purpose=purpose,
|
||||
)
|
||||
],
|
||||
)
|
||||
381
backend/onyx/agents/agent_search/dr/nodes/dr_a2_closer.py
Normal file
381
backend/onyx/agents/agent_search/dr/nodes/dr_a2_closer.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
|
||||
from onyx.agents.agent_search.dr.states import FinalUpdate
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def extract_citation_numbers(text: str) -> list[int]:
    """
    Extract all citation numbers from text in the format [[<number>]] or
    [[<number_1>, <number_2>, ...]].

    Args:
        text: The text to scan for citation markers.

    Returns:
        A sorted list of all unique citation numbers found. (Sorting makes the
        result deterministic; the previous list(set(...)) order was arbitrary.
        The redundant function-local `import re` was removed — the module
        already imports `re` at the top.)
    """
    # Pattern to match [[number]] or [[number1, number2, ...]]
    pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    matches = re.findall(pattern, text)

    cited_numbers: set[int] = set()
    for match in matches:
        # Split by comma and collect every number in the group
        cited_numbers.update(int(num.strip()) for num in match.split(","))

    return sorted(cited_numbers)
|
||||
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    """
    Convert a regex citation match like [[3]] or [[3, 5, 7]] into markdown
    links pointing at the corresponding documents.

    Args:
        match: Regex match whose group 1 is the citation content,
            e.g. "3" or "3, 5, 7".
        docs: Documents the (1-based) citation numbers index into.

    Returns:
        The concatenated linked citations, e.g. "[[3]](url)[[5]](url)".
        Numbers with no corresponding document are emitted without a link.
    """
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        # Citations are 1-based; guard BOTH ends. The previous check only
        # bounded the upper end, so [[0]] silently resolved to docs[-1]
        # (the last document) via negative indexing.
        if 1 <= num <= len(docs):
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
|
||||
|
||||
|
||||
def insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Associate a chat message with each given search document.

    Adds one ChatMessage__SearchDoc row per search doc id to the session
    (rows are added, not committed here).

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate
        db_session: The database session
    """
    association_rows = [
        ChatMessage__SearchDoc(chat_message_id=message_id, search_doc_id=doc_id)
        for doc_id in search_doc_ids
    ]
    db_session.add_all(association_rows)
|
||||
|
||||
|
||||
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
) -> None:
    """
    Persist the results of a research run to the database.

    Inserts the cited search docs, links them to the chat message, updates the
    chat message (answer text, citations, research plan, final documents), and
    records one row per iteration and per iteration sub-step. Commits the
    session exactly once at the end.

    Args:
        state: Main graph state (plan of record, iteration instructions,
            iteration responses).
        graph_config: Graph config holding the db session and message id.
        final_answer: Final answer text with citations in [[n]] form.
        all_cited_documents: Inference sections cited across all iterations.
        is_internet_marker_dict: document_id -> whether the doc came from the
            internet.
        aggregated_context: Aggregated iteration responses to persist.
    """
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map search_docs to message
    insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations

    cited_doc_nrs = extract_citation_numbers(final_answer)

    citation_dict: dict[str | int, int] = {}

    for cited_doc_nr in cited_doc_nrs:
        # Citation numbers come from LLM output and are not guaranteed to be
        # valid 1-based indices into search_docs. Guard both ends so a
        # hallucinated number cannot raise IndexError (or, for 0, silently
        # map to the last doc via negative indexing). This mirrors the bounds
        # check in replace_citation_with_link.
        if 1 <= cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=str(graph_config.persistence.chat_session_id),
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
    )

    # One ResearchAgentIteration row per orchestrator iteration
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
            created_at=datetime.now(),
        )
        db_session.add(research_agent_iteration_step)

    # One ResearchAgentIterationSubStep row per (iteration, parallel branch)
    for iteration_answer in aggregated_context.global_iteration_responses:

        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            parent_question_id=None,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            additional_data=iteration_answer.additional_data,
            created_at=datetime.now(),
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
|
||||
|
||||
|
||||
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    For DEEP research, first asks the LLM whether the gathered information is
    sufficient; if not (and the retry budget MAX_NUM_CLOSER_SUGGESTIONS is not
    exhausted), returns an OrchestrationUpdate routing back to the orchestrator
    with the identified gaps. Otherwise streams the final answer to the client,
    emits citation events, persists the run via save_iteration, and returns a
    FinalUpdate.
    """

    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    research_type = graph_config.behavior.research_type

    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt

    # Merge any clarification exchange into the effective question
    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    # Collapse all iteration responses into one context string + cited docs
    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )

    iteration_responses_string = aggregated_context.context
    all_cited_documents = aggregated_context.cited_documents

    is_internet_marker_dict = aggregated_context.is_internet_marker_dict

    num_closer_suggestions = state.num_closer_suggestions

    # DEEP research only: check completeness before answering, up to
    # MAX_NUM_CLOSER_SUGGESTIONS times.
    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )

        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt,
                test_info_complete_prompt + (assistant_task_prompt or ""),
            ),
            schema=TestInfoCompleteResponse,
            timeout_override=40,
            # max_tokens=1000,
        )

        if test_info_complete_json.complete:
            pass

        else:
            # Information is incomplete: hand the gaps back to the
            # orchestrator and bump the suggestion counter.
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
            )

    retrieved_search_docs = convert_inference_sections_to_search_docs(
        all_cited_documents
    )

    # Open the answer section in the stream, with the final document list
    write_custom_event(
        current_step_nr,
        MessageStart(
            content="",
            final_documents=retrieved_search_docs,
        ),
        writer,
    )

    # THOUGHTFUL runs have no sub-answers; pick the matching prompt
    if research_type == ResearchType.THOUGHTFUL:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
    else:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS

    final_answer_prompt = final_answer_base_prompt.build(
        base_question=prompt_question,
        iteration_responses_string=iteration_responses_string,
        chat_history_string=chat_history_string,
    )

    all_context_llmdocs = [
        llm_doc_from_inference_section(inference_section)
        for inference_section in all_cited_documents
    ]

    try:
        # Stream the final answer; the outer 240s timeout bounds the whole
        # call, the inner timeout_override=60 bounds the LLM itself.
        streamed_output, _, citation_infos = run_with_timeout(
            240,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    assistant_system_prompt,
                    final_answer_prompt + (assistant_task_prompt or ""),
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                answer_piece="message_delta",
                ind=current_step_nr,
                context_docs=all_context_llmdocs,
                replace_citations=True,
                # max_tokens=None,
            ),
        )

        final_answer = "".join(streamed_output)
    except Exception as e:
        raise ValueError(f"Error in consolidate_research: {e}")

    write_custom_event(current_step_nr, SectionEnd(), writer)

    current_step_nr += 1

    # Emit the citation section, then close the overall stream
    write_custom_event(current_step_nr, CitationStart(), writer)
    write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
    write_custom_event(current_step_nr, SectionEnd(), writer)

    current_step_nr += 1
    write_custom_event(current_step_nr, OverallStop(), writer)

    # Log the research agent steps
    save_iteration(
        state,
        graph_config,
        aggregated_context,
        final_answer,
        all_cited_documents,
        is_internet_marker_dict,
    )

    return FinalUpdate(
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="closer",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
79
backend/onyx/agents/agent_search/dr/states.py
Normal file
79
backend/onyx/agents/agent_search/dr/states.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
    """Base state update carrying accumulated per-node log messages."""

    # Annotated with operator.add so LangGraph concatenates the lists
    # produced by different nodes instead of overwriting them.
    log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class OrchestrationUpdate(LoggerUpdate):
    """State update produced by the orchestrator on each iteration."""

    # Tool names chosen so far; accumulated across node updates via `add`.
    tools_used: Annotated[list[str], add] = []
    # Questions/queries to fan out to the chosen tool this iteration.
    query_list: list[str] = []
    iteration_nr: int = 0
    current_step_nr: int = 1
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
|
||||
|
||||
|
||||
class OrchestrationSetup(OrchestrationUpdate):
    """Initial orchestration state assembled before the first iteration."""

    original_question: str | None = None
    chat_history_string: str | None = None
    clarification: OrchestrationClarificationInfo | None = None
    # All tools the orchestrator may route to, keyed by tool name.
    available_tools: dict[str, OrchestratorTool] | None = None
    num_closer_suggestions: int = 0  # how many times the closer was suggested

    active_source_types: list[DocumentSource] | None = None
    active_source_types_descriptions: str | None = None
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
|
||||
|
||||
|
||||
class AnswerUpdate(LoggerUpdate):
    """State update carrying answers produced by sub-agent iterations."""

    # Accumulated via `add` across parallel sub-agent branches.
    iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class FinalUpdate(LoggerUpdate):
    """State update produced by the closer node with the final answer."""

    final_answer: str | None = None
    all_cited_documents: list[InferenceSection] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
    """Input state of the DR graph; adds no fields beyond CoreState."""

    pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
    # This includes the core state
    MainInput,
    OrchestrationSetup,
    AnswerUpdate,
    FinalUpdate,
):
    """Full DR graph state: the union of all setup/update states via inheritance."""

    pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class MainOutput(TypedDict):
    """Final output emitted by the DR graph."""

    log_messages: list[str]
    final_answer: str | None
    all_cited_documents: list[InferenceSection]
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    This branching node only logs the start of the basic search for the
    current iteration and records a timing log message.
    """
    start_time = datetime.now()

    logger.debug(
        f"Search start for Basic Search {state.iteration_nr} at {datetime.now()}"
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="branching",
        node_start_time=start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,232 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    Preprocesses the branch question with the LLM (query rewrite, source-type
    and time-filter extraction), runs the configured search tool, and — for
    DEEP research — asks the LLM to answer the branch question from the
    retrieved documents, extracting citations and claims. For other research
    types the top retrieved documents are returned without a generated answer.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The last entry in tools_used is the tool the orchestrator just chose
    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)

    # sanity check
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")

    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )

    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
    )

    try:
        # Structured LLM call: rewritten query, source-type filter, time filter
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, base_search_processing_prompt
            ),
            schema=BaseSearchProcessingResponse,
            timeout_override=5,
            # max_tokens=100,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        raise e

    rewritten_query = search_processing.rewritten_query

    implied_start_date = search_processing.time_filter

    # Validate time_filter format if it exists
    implied_time_filter = None
    if implied_start_date:

        # Check if time_filter is in YYYY-MM-DD format; anything else from the
        # LLM is silently ignored rather than crashing the search
        date_pattern = r"^\d{4}-\d{2}-\d{2}$"
        if re.match(date_pattern, implied_start_date):
            implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")

    specified_source_types: list[DocumentSource] | None = [
        DocumentSource(source_type)
        for source_type in search_processing.specified_source_types
    ]

    # An empty filter list means "no restriction" for the search tool
    if specified_source_types is not None and len(specified_source_types) == 0:
        specified_source_types = None

    logger.debug(
        f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    retrieved_docs: list[InferenceSection] = []
    callback_container: list[list[InferenceSection]] = []

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=rewritten_query,
            document_sources=specified_source_types,
            time_filter=implied_time_filter,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections

                break

    document_texts_list = []

    # Only the top 15 sections are fed to the LLM as context
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    logger.debug(
        f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Built prompt

    if research_type == ResearchType.DEEP:
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=branch_query,
            base_question=base_question,
            document_text=document_texts,
        )

        # Run LLM

        # search_answer_json = None
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=40,
            # max_tokens=1500,
        )

        logger.debug(
            f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )

        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning
        # answer_string = ""
        # claims = []

        # Normalize citations in the answer/claims and map the cited
        # (1-based) numbers back to the retrieved sections
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
        }

    else:
        # Non-DEEP research: no per-branch answer; just pass through the
        # top retrieved documents as the citations
        answer_string = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
        reasoning = ""

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=search_tool_info.llm_path,
                tool_id=search_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,99 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.dr.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
queries = [update.question for update in new_updates]
|
||||
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
|
||||
|
||||
doc_list = []
|
||||
|
||||
for xs in doc_lists:
|
||||
for x in xs:
|
||||
doc_list.append(x)
|
||||
|
||||
# Convert InferenceSections to SavedSearchDocs
|
||||
search_docs = chunks_or_sections_to_search_docs(doc_list)
|
||||
retrieved_saved_search_docs = [
|
||||
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
for search_doc in search_docs
|
||||
]
|
||||
|
||||
for retrieved_saved_search_doc in retrieved_saved_search_docs:
|
||||
retrieved_saved_search_doc.is_internet = False
|
||||
|
||||
# Write the results to the stream
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolStart(
|
||||
type="internal_search_tool_start",
|
||||
is_internet_search=False,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=queries,
|
||||
documents=retrieved_saved_search_docs,
|
||||
type="internal_search_tool_delta",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
|
||||
basic_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
|
||||
basic_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_basic_search_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Internet Search Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", basic_search_branch)
|
||||
|
||||
graph.add_node("act", basic_search)
|
||||
|
||||
graph.add_node("reducer", is_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.states import AnswerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_act(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> AnswerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
custom_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
custom_tool_name = custom_tool_info.llm_path
|
||||
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
logger.debug(
|
||||
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# get tool call args
|
||||
tool_args: dict | None = None
|
||||
if graph_config.tooling.using_tool_calling_llm:
|
||||
# get tool call args from tool-calling LLM
|
||||
tool_use_prompt = CUSTOM_TOOL_USE_PROMPT.build(
|
||||
query=branch_query,
|
||||
base_question=base_question,
|
||||
tool_response="(No tool response yet. You need to call the tool to answer the question.)",
|
||||
)
|
||||
tool_calling_msg = graph_config.tooling.primary_llm.invoke(
|
||||
tool_use_prompt,
|
||||
tools=[custom_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=40,
|
||||
)
|
||||
|
||||
# make sure we got a tool call
|
||||
if (
|
||||
isinstance(tool_calling_msg, AIMessage)
|
||||
and len(tool_calling_msg.tool_calls) == 1
|
||||
):
|
||||
tool_args = tool_calling_msg.tool_calls[0]["args"]
|
||||
else:
|
||||
logger.warning("Tool-calling LLM did not emit a tool call")
|
||||
|
||||
if tool_args is None:
|
||||
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
|
||||
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
|
||||
query=branch_query,
|
||||
history=[],
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
force_run=True,
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise ValueError("Failed to obtain tool arguments from LLM")
|
||||
|
||||
# run the tool
|
||||
response_summary: CustomToolCallSummary | None = None
|
||||
for tool_response in custom_tool.run(**tool_args):
|
||||
if tool_response.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
response_summary = cast(CustomToolCallSummary, tool_response.response)
|
||||
break
|
||||
|
||||
if not response_summary:
|
||||
raise ValueError("Custom tool did not return a valid response summary")
|
||||
|
||||
# summarise tool result
|
||||
if response_summary.response_type == "json":
|
||||
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
|
||||
elif response_summary.response_type in {"image", "csv"}:
|
||||
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
|
||||
else:
|
||||
tool_result_str = str(response_summary.tool_result)
|
||||
|
||||
tool_str = (
|
||||
f"Tool used: {custom_tool_name}\n"
|
||||
f"Description: {custom_tool_info.description}\n"
|
||||
f"Result: {tool_result_str}"
|
||||
)
|
||||
|
||||
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
|
||||
query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke(
|
||||
tool_summary_prompt, timeout_override=40
|
||||
).content
|
||||
).strip()
|
||||
|
||||
logger.debug(
|
||||
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
return AnswerUpdate(
|
||||
iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=custom_tool_name,
|
||||
tool_id=custom_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="tool_calling",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:1] # no parallel call for now
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
|
||||
custom_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
|
||||
custom_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
|
||||
custom_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_custom_tool_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Generic Tool Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", custom_tool_branch)
|
||||
|
||||
graph.add_node("act", custom_tool_act)
|
||||
|
||||
graph.add_node("reducer", custom_tool_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def image_generation_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="image_generation",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def image_generation(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
state.assistant_system_prompt
|
||||
state.assistant_task_prompt
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.prompt_builder.raw_user_query
|
||||
graph_config.behavior.research_type
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
image_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
|
||||
|
||||
logger.debug(
|
||||
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# Generate images using the image generation tool
|
||||
generated_images: list[ImageGenerationResponse] = []
|
||||
|
||||
for tool_response in image_tool.run(prompt=branch_query):
|
||||
if tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
response = cast(list[ImageGenerationResponse], tool_response.response)
|
||||
generated_images = response
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# Create answer string describing the generated images
|
||||
if generated_images:
|
||||
image_descriptions = []
|
||||
for i, img in enumerate(generated_images, 1):
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
|
||||
answer_string = (
|
||||
f"Generated {len(generated_images)} image(s) based on the request: {branch_query}\n\n"
|
||||
+ "\n".join(image_descriptions)
|
||||
)
|
||||
reasoning = f"Used image generation tool to create {len(generated_images)} image(s) based on the user's request."
|
||||
else:
|
||||
answer_string = f"Failed to generate images for request: {branch_query}"
|
||||
reasoning = "Image generation tool did not return any results."
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=image_tool_info.llm_path,
|
||||
tool_id=image_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning=reasoning,
|
||||
additional_data=(
|
||||
{"generated_images": str(len(generated_images))}
|
||||
if generated_images
|
||||
else None
|
||||
),
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="image_generation",
|
||||
node_name="generating",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
# Write the results to the stream
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ImageGenerationToolStart(
|
||||
type="image_generation_tool_start",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ImageGenerationToolDelta(
|
||||
images={},
|
||||
type="image_generation_tool_delta",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="image_generation",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
|
||||
image_generation_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
|
||||
image_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_image_generation_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Internet Search Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", image_generation_branch)
|
||||
|
||||
graph.add_node("act", image_generation)
|
||||
|
||||
graph.add_node("reducer", is_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a internet search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Internet Search {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="internet_search",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,175 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def internet_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform an internet search as part of the DR process.

    Runs the InternetSearchTool for the branch question and, for DEEP research
    only, summarizes the retrieved documents with the primary LLM. Fast paths
    return the retrieved documents without an LLM answer.

    Args:
        state: branch-level input (iteration/parallelization numbers, the
            branch question, available tools, assistant prompts).
        config: LangGraph runnable config; the GraphConfig is carried under
            config["metadata"]["config"].
        writer: LangGraph stream writer (unused in this node).

    Returns:
        BranchUpdate with a single IterationAnswer plus a timing log message.

    Raises:
        ValueError: if the branch question, persona, available tools, or the
            internet search provider is not set, or a retrieved document has
            an unexpected type.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    logger.debug(
        f"Search start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The most recently used tool is the internet search tool that routed here.
    is_tool_info = state.available_tools[state.tools_used[-1]]
    internet_search_tool = cast(InternetSearchTool, is_tool_info.tool_object)

    if internet_search_tool.provider is None:
        raise ValueError(
            "internet_search_tool.provider is not set. This should not happen."
        )

    # Update search parameters
    internet_search_tool.max_chunks = 10
    internet_search_tool.provider.num_results = 10

    retrieved_docs: list[InferenceSection] = []

    for tool_response in internet_search_tool.run(internet_search_query=search_query):
        # get retrieved docs to send to the rest of the graph
        if tool_response.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
            response = cast(SearchResponseSummary, tool_response.response)
            retrieved_docs = response.top_sections
            break

    document_texts_list = []

    # Only the first 15 documents are shown to the LLM, numbered from 1.
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    logger.debug(
        f"Search end/LLM start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if research_type == ResearchType.DEEP:
        # Build the summarization prompt over the retrieved documents.
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=search_query,
            base_question=base_question,
            document_text=document_texts,
        )

        # Run LLM
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=40,
        )

        logger.debug(
            f"LLM/all done for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )

        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""

        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)

        # Guard against hallucinated citation numbers that fall outside the
        # retrieved range (mirrors the equivalent guard in the KG search node);
        # an unguarded index would raise IndexError.
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
            if 0 < citation_number <= len(retrieved_docs)
        }

    else:
        # Non-DEEP research: no LLM summarization; cite the retrieved docs
        # directly (same 15-document window as shown above).
        answer_string = ""
        claims = []
        reasoning = ""
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,92 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the internet-search branch results of the
    current iteration and streams the queries and documents to the client.
    """

    node_start_time = datetime.now()
    current_step_nr = state.current_step_nr

    # Keep only the branch answers produced during this iteration.
    iteration_answers = [
        answer
        for answer in state.branch_iteration_responses
        if answer.iteration_nr == state.iteration_nr
    ]

    queries = [answer.question for answer in iteration_answers]

    # Flatten the cited documents of every branch, preserving branch order.
    cited_sections = [
        section
        for answer in iteration_answers
        for section in answer.cited_documents.values()
    ]

    retrieved_search_docs = convert_inference_sections_to_search_docs(
        cited_sections, is_internet=True
    )

    # Stream: open the search section, deliver queries + documents, close it.
    write_custom_event(
        current_step_nr,
        SearchToolStart(
            type="internal_search_tool_start",
            is_internet_search=True,
        ),
        writer,
    )

    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=queries,
            documents=retrieved_search_docs,
            type="internal_search_tool_delta",
        ),
        writer,
    )

    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )

    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=iteration_answers,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "act" branch per query, capped at MAX_DR_PARALLEL_SEARCH."""
    branches: list[Send | Hashable] = []
    for branch_idx, branch_query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branches.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=branch_idx,
                    branch_question=branch_query,
                    context="",
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                    assistant_system_prompt=state.assistant_system_prompt,
                    assistant_task_prompt=state.assistant_task_prompt,
                ),
            )
        )
    return branches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_1_branch import (
|
||||
is_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_2_act import (
|
||||
internet_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_is_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Internet Search Sub-Agent
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes: branch -> (parallel) act -> reducer
    graph.add_node("branch", is_branch)
    graph.add_node("act", internet_search)
    graph.add_node("reducer", is_reducer)

    # Edges
    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge("act", "reducer")
    graph.add_edge("reducer", END)

    return graph
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph branching node for the KG search sub-agent. It performs no work
    beyond logging the start of the iteration.
    """

    node_start_time = datetime.now()

    logger.debug(
        f"Search start for KG Search {state.iteration_nr} at {datetime.now()}"
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,97 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.

    Invokes the compiled KB sub-graph for the branch question and converts its
    answer and retrieved documents into an IterationAnswer.

    Args:
        state: branch-level input (iteration/parallelization numbers, question,
            available tools).
        config: LangGraph runnable config, forwarded to the KB sub-graph.
        writer: LangGraph stream writer (unused in this node).

    Returns:
        BranchUpdate with a single IterationAnswer plus a timing log message.

    Raises:
        ValueError: if the branch question or available tools are not set.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    # NOTE: the original body contained a bare `state.current_step_nr`
    # expression statement with no effect; it has been removed.
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The most recently used tool is the KG tool that routed us here.
    kg_tool_info = state.available_tools[state.tools_used[-1]]

    kb_graph = kb_graph_builder().compile()

    kb_results = kb_graph.invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )

    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])

    (
        citation_numbers,
        answer_string,
        claims,
    ) = extract_document_citations(answer_string, claims)

    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]

    # Drop citation numbers that fall outside the retrieved range.
    cited_documents = {
        citation_number: retrieved_docs[citation_number - 1]
        for citation_number in citation_numbers
        if citation_number <= len(retrieved_docs)
    }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,124 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
|
||||
|
||||
|
||||
def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates KG-search branch results for the current
    iteration and streams them to the client.

    Streams (in order): a search section with queries and retrieved documents
    (only if any documents were retrieved), then — when exactly one branch ran —
    a reasoning section carrying the (possibly truncated) KG answer. Each
    streamed section advances current_step_nr by one.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # Keep only the branch answers produced during this iteration.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    queries = [update.question for update in new_updates]
    doc_lists = [list(update.cited_documents.values()) for update in new_updates]

    # Flatten the per-branch document lists, preserving branch order.
    doc_list = []

    for xs in doc_lists:
        for x in xs:
            doc_list.append(x)

    retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)

    # Only stream a KG answer when a single branch ran; with multiple branches
    # there is no single authoritative answer to show.
    if len(queries) == 1:
        kg_answer: str | None = (
            "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
        )
    else:
        kg_answer = None

    # Search section: start -> delta (queries + docs) -> end.
    if len(retrieved_search_docs) > 0:
        write_custom_event(
            current_step_nr,
            SearchToolStart(
                type="internal_search_tool_start",
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SearchToolDelta(
                queries=queries,
                documents=retrieved_search_docs,
                type="internal_search_tool_delta",
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    if kg_answer is not None:

        # Truncate long answers for display; the full answer is still kept in
        # the iteration responses.
        kg_display_answer = (
            f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
            if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
            else kg_answer
        )

        # Reasoning section: start -> delta (answer text) -> end.
        write_custom_event(
            current_step_nr,
            ReasoningStart(),
            writer,
        )
        write_custom_event(
            current_step_nr,
            ReasoningDelta(reasoning=kg_display_answer, type="reasoning_delta"),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,27 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out "act" branches; currently limited to a single query (no parallel KG search)."""
    branches: list[Send | Hashable] = []
    # no parallel search for now: only the first query is dispatched
    for branch_idx, branch_query in enumerate(state.query_list[:1]):
        branches.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=branch_idx,
                    branch_question=branch_query,
                    context="",
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                    assistant_system_prompt=state.assistant_system_prompt,
                    assistant_task_prompt=state.assistant_task_prompt,
                ),
            )
        )
    return branches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
|
||||
kg_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
|
||||
kg_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
|
||||
kg_search_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_kg_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for KG Search Sub-Agent
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes: branch -> (conditional) act -> reducer
    graph.add_node("branch", kg_search_branch)
    graph.add_node("act", kg_search)
    graph.add_node("reducer", kg_search_reducer)

    # Edges
    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge("act", "reducer")
    graph.add_edge("reducer", END)

    return graph
|
||||
46
backend/onyx/agents/agent_search/dr/sub_agents/states.py
Normal file
46
backend/onyx/agents/agent_search/dr/sub_agents/states.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
|
||||
class SubAgentUpdate(LoggerUpdate):
    """State update emitted by a sub-agent's reducer node."""

    # Per-iteration answers; merged across updates via operator.add.
    iteration_responses: Annotated[list[IterationAnswer], add] = []
    # Next step number to use when streaming events to the client.
    current_step_nr: int = 1
|
||||
|
||||
|
||||
class BranchUpdate(LoggerUpdate):
    """State update emitted by a single parallel branch ("act" node)."""

    # Branch-level answers; merged across parallel branches via operator.add.
    branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class SubAgentInput(LoggerUpdate):
    """Input state handed to a sub-agent graph by the orchestrator."""

    # Current DR iteration number.
    iteration_nr: int = 0
    # Next step number for client streaming.
    current_step_nr: int = 1
    # Questions to fan out across parallel branches.
    query_list: list[str] = []
    # Optional context string passed to branches.
    context: str | None = None
    # Source types active for this run, if restricted — TODO confirm semantics.
    active_source_types: list[DocumentSource] | None = None
    # Names of tools used so far; merged via operator.add.
    tools_used: Annotated[list[str], add] = []
    # Tool registry, keyed by the names appearing in tools_used.
    available_tools: dict[str, OrchestratorTool] | None = None
    # Prompts inherited from the assistant/persona.
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
|
||||
|
||||
|
||||
class SubAgentMainState(
    # This includes the core state
    SubAgentInput,
    SubAgentUpdate,
    BranchUpdate,
):
    """Combined main state for a sub-agent graph (input plus all updates)."""

    pass
|
||||
|
||||
|
||||
class BranchInput(SubAgentInput):
    """Input for a single parallel branch of a sub-agent."""

    # Index of this branch within the iteration's parallel fan-out.
    parallelization_nr: int = 0
    # The branch-specific question to answer.
    branch_question: str | None = None
|
||||
|
||||
|
||||
class CustomToolBranchInput(LoggerUpdate):
    """Branch input for running a custom tool."""

    # The tool to invoke for this branch.
    tool_info: OrchestratorTool
|
||||
343
backend/onyx/agents/agent_search/dr/utils.py
Normal file
343
backend/onyx/agents/agent_search/dr/utils.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import re
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import SearchDoc
|
||||
|
||||
|
||||
CITATION_PREFIX = "CITE:"


def extract_document_citations(
    answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
    """
    Collect all citations of the form [1] or [2, 3] from the answer and claims.

    Returns the cited indices along with the answer and claims rewritten so
    each citation reads [<CITATION_PREFIX>n] — grouped citations like [1, 2]
    become [<CITATION_PREFIX>1][<CITATION_PREFIX>2]. The prefix supports
    citation deduplication later on.
    """
    seen: set[int] = set()

    # Matches a single citation [1] as well as grouped ones like [1, 2, 3].
    citation_re = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")

    def _rewrite(match: re.Match[str]) -> str:
        nums = [int(part) for part in match.group(1).split(",")]
        seen.update(nums)
        return "".join(f"[{CITATION_PREFIX}{n}]" for n in nums)

    rewritten_answer = citation_re.sub(_rewrite, answer)
    rewritten_claims = [citation_re.sub(_rewrite, claim) for claim in claims]

    return list(seen), rewritten_answer, rewritten_claims
|
||||
|
||||
|
||||
def aggregate_context(
    iteration_responses: list[IterationAnswer], include_documents: bool = True
) -> AggregatedDRContext:
    """
    Converts the iteration response into a single string with unified citations.
    For example,
    it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
    it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
    Output:
    it 1: the answer is x [1][2].
    it 2: blah blah [2][3]
    [1]: doc_xyz
    [2]: doc_abc
    [3]: doc_pqr

    Responses are processed in (iteration_nr, parallelization_nr) order.
    Document contents are appended at the end when include_documents is True.
    NOTE: cited chunks are mutated in place (score is reset to None).
    """
    # dedupe and merge inference section contents
    unrolled_inference_sections: list[InferenceSection] = []
    is_internet_marker_dict: dict[str, bool] = {}
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):

        # Mark docs that came from the internet search tool; first sighting
        # of a document id wins.
        iteration_tool = iteration_response.tool
        if iteration_tool == "InternetSearchTool":
            is_internet = True
        else:
            is_internet = False

        for cited_doc in iteration_response.cited_documents.values():
            unrolled_inference_sections.append(cited_doc)
            if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
                is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
                    is_internet
                )
            cited_doc.center_chunk.score = None  # None means maintain order

    global_documents = dedup_inference_section_list(unrolled_inference_sections)

    # Global citation numbering is 1-based, in deduped document order.
    global_citations = {
        doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
    }

    # build output string
    output_strings: list[str] = []
    global_iteration_responses: list[IterationAnswer] = []

    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        # add basic iteration info
        output_strings.append(
            f"Iteration: {iteration_response.iteration_nr}, "
            f"Question {iteration_response.parallelization_nr}"
        )
        output_strings.append(f"Tool: {iteration_response.tool}")
        output_strings.append(f"Question: {iteration_response.question}")

        # get answer and claims with global citations
        answer_str = iteration_response.answer
        claims = iteration_response.claims or []

        iteration_citations: list[int] = []
        for local_number, cited_doc in iteration_response.cited_documents.items():
            global_number = global_citations[cited_doc.center_chunk.document_id]
            # translate local citations to global citations
            # (relies on the [CITE:n] markers inserted by
            # extract_document_citations)
            answer_str = answer_str.replace(
                f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
            )
            claims = [
                claim.replace(
                    f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
                )
                for claim in claims
            ]
            iteration_citations.append(global_number)

        # add answer, claims, and citation info
        if answer_str:
            output_strings.append(f"Answer: {answer_str}")
        if claims:
            # NOTE(review): the `or "No claims provided"` fallback is
            # unreachable here — the "Claims: " prefix makes the left operand
            # always truthy, and this branch only runs when claims is
            # non-empty. Confirm intent.
            output_strings.append(
                "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
                or "No claims provided"
            )
        if not answer_str and not claims:
            output_strings.append(
                "Retrieved documents: "
                + (
                    "".join(
                        f"[{global_number}]"
                        for global_number in sorted(iteration_citations)
                    )
                    or "No documents retrieved"
                )
            )
        output_strings.append("\n---\n")

        # save global iteration response
        global_iteration_responses.append(
            IterationAnswer(
                tool=iteration_response.tool,
                tool_id=iteration_response.tool_id,
                iteration_nr=iteration_response.iteration_nr,
                parallelization_nr=iteration_response.parallelization_nr,
                question=iteration_response.question,
                reasoning=iteration_response.reasoning,
                answer=answer_str,
                cited_documents={
                    global_citations[doc.center_chunk.document_id]: doc
                    for doc in iteration_response.cited_documents.values()
                },
                background_info=iteration_response.background_info,
                claims=claims,
                additional_data=iteration_response.additional_data,
            )
        )

    # add document contents if requested
    if include_documents:
        if global_documents:
            output_strings.append("Cited document contents:")
        for doc in global_documents:
            output_strings.append(
                build_document_context(
                    doc, global_citations[doc.center_chunk.document_id]
                )
            )
        output_strings.append("\n---\n")

    return AggregatedDRContext(
        context="\n".join(output_strings),
        cited_documents=global_documents,
        is_internet_marker_dict=is_internet_marker_dict,
        global_iteration_responses=global_iteration_responses,
    )
|
||||
|
||||
|
||||
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
    """
    Get the chat history (up to max_messages USER/ASSISTANT message pairs) as
    a string, prefixed with "...\n" when older messages were truncated.
    """
    # get past max_messages USER, ASSISTANT message pairs
    past_messages = chat_history[-max_messages * 2 :]

    # BUG FIX: the original relied on a conditional expression where Python's
    # implicit literal concatenation bound `"" "\n"` together, so a truncated
    # history returned only "...\n" and dropped the messages entirely.
    truncation_marker = "...\n" if len(chat_history) > len(past_messages) else ""
    rendered_messages = "\n".join(
        ("user" if isinstance(msg, HumanMessage) else "you")
        + f": {str(msg.content).strip()}"
        for msg in past_messages
    )
    return truncation_marker + rendered_messages
|
||||
|
||||
|
||||
def get_prompt_question(
    question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
    """Return the question, augmented with the clarification Q/A when present."""
    if not clarification:
        return question

    return (
        f"Initial User Question: {question}\n"
        f"(Clarification Question: {clarification.clarification_question}\n"
        f"User Response: {clarification.clarification_response})"
    )
|
||||
|
||||
|
||||
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
    """
    Render a human-readable description of a tool call and its questions.
    """
    joined_questions = "\n - ".join(query_list)
    return "\n".join([f"Tool: {tool_name}", "", f"Questions:\n{joined_questions}"])
|
||||
|
||||
|
||||
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
    """
    Convert a numbered plan string into a dict keyed by step number.

    The text is split on numbered markers such as "1.", "2)", etc.; whatever
    follows each marker becomes the value stored under that number (with the
    trailing "." or ")" stripped). Steps with no content and any preamble text
    before the first marker are dropped.
    """
    if not plan_text:
        return {}

    # re.split with a capture group leaves the markers at odd indices and the
    # step bodies at the even indices that follow them.
    pieces = re.split(r"(\d+[.)])", plan_text)

    plan_dict: dict[str, str] = {}
    for marker, body in zip(pieces[1::2], pieces[2::2]):
        step_text = body.strip()
        if step_text:  # only keep steps that have actual content
            plan_dict[marker.rstrip(".)")] = step_text

    return plan_dict
|
||||
|
||||
|
||||
def convert_inference_sections_to_search_docs(
    inference_sections: list[InferenceSection],
    is_internet: bool = False,
) -> list[SavedSearchDoc]:
    """
    Convert InferenceSections into SavedSearchDocs.

    Each converted doc is tagged with the given is_internet flag and then
    wrapped as a SavedSearchDoc with a placeholder db_doc_id of 0.
    """
    converted_docs = chunks_or_sections_to_search_docs(inference_sections)

    saved_docs: list[SavedSearchDoc] = []
    for converted_doc in converted_docs:
        converted_doc.is_internet = is_internet
        saved_docs.append(SavedSearchDoc.from_search_doc(converted_doc, db_doc_id=0))

    return saved_docs
|
||||
|
||||
|
||||
def update_db_session_with_messages(
    db_session: Session,
    chat_message_id: int,
    chat_session_id: str,
    is_agentic: bool | None,
    message: str | None = None,
    message_type: str | None = None,
    token_count: int | None = None,
    rephrased_query: str | None = None,
    prompt_id: int | None = None,
    citations: dict[str | int, int] | None = None,
    error: str | None = None,
    alternate_assistant_id: int | None = None,
    overridden_model: str | None = None,
    research_type: str | None = None,
    research_plan: dict[str, str] | None = None,
    final_documents: list[SearchDoc] | None = None,
    update_parent_message: bool = True,
    research_answer_purpose: ResearchAnswerPurpose | None = None,
) -> None:
    """
    Selectively update fields on an existing ChatMessage row.

    Looks up the ChatMessage identified by (chat_message_id, chat_session_id)
    and copies each provided argument onto it. String-valued enum inputs
    (message_type, research_type) are converted to their enum types, and
    citation keys are coerced to int before assignment. If update_parent_message
    is True, the parent message's latest_child_message pointer is advanced to
    this message (when a parent row exists).

    NOTE(review): every field update is guarded by a truthiness check, so falsy
    values (0, "", False, empty dict/list) are silently skipped — e.g. a
    token_count of 0 or is_agentic=False will NOT be written. Confirm this is
    intended rather than `is not None` checks.

    Mutates the rows on the given db_session; does not commit or flush itself.

    Raises:
        ValueError: if no ChatMessage matches the given id and session id.
    """
    chat_message = (
        db_session.query(ChatMessage)
        .filter(
            ChatMessage.id == chat_message_id,
            ChatMessage.chat_session_id == chat_session_id,
        )
        .first()
    )
    if not chat_message:
        raise ValueError("Chat message with id not found")  # should never happen

    if message:
        chat_message.message = message
    if message_type:
        # Stored as the MessageType enum, not the raw string.
        chat_message.message_type = MessageType(message_type)
    if token_count:
        chat_message.token_count = token_count
    if rephrased_query:
        chat_message.rephrased_query = rephrased_query
    if prompt_id:
        chat_message.prompt_id = prompt_id
    if citations:
        # Convert string keys to integers to match database field type
        chat_message.citations = {int(k): v for k, v in citations.items()}
    if error:
        chat_message.error = error
    if alternate_assistant_id:
        chat_message.alternate_assistant_id = alternate_assistant_id
    if overridden_model:
        chat_message.overridden_model = overridden_model
    if research_type:
        # Stored as the ResearchType enum, not the raw string.
        chat_message.research_type = ResearchType(research_type)
    if research_plan:
        chat_message.research_plan = research_plan
    if final_documents:
        chat_message.search_docs = final_documents
    if is_agentic:
        chat_message.is_agentic = is_agentic

    if research_answer_purpose:
        chat_message.research_answer_purpose = research_answer_purpose

    if update_parent_message:
        # Point the parent (if any) at this message as its latest child so the
        # chat thread resolves to this message going forward.
        parent_chat_message = (
            db_session.query(ChatMessage)
            .filter(ChatMessage.id == chat_message.parent_message)
            .first()
        )
        if parent_chat_message:
            parent_chat_message.latest_child_message = chat_message.id

    return
|
||||
@@ -6,7 +6,12 @@ from langgraph.types import StreamWriter
|
||||
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
|
||||
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
|
||||
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import (
|
||||
BASIC_SEARCH_STEP_DESCRIPTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import (
|
||||
KG_SEARCH_STEP_DESCRIPTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -95,14 +100,14 @@ def create_minimal_connected_query_graph(
|
||||
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
|
||||
|
||||
|
||||
def stream_write_step_description(
|
||||
def stream_write_kg_search_description(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=STEP_DESCRIPTIONS[step_nr].description,
|
||||
sub_question=KG_SEARCH_STEP_DESCRIPTIONS[step_nr].description,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
@@ -113,10 +118,12 @@ def stream_write_step_description(
|
||||
sleep(0.2)
|
||||
|
||||
|
||||
def stream_write_step_activities(
|
||||
def stream_write_kg_search_activities(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
for activity_nr, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities):
|
||||
for activity_nr, activity in enumerate(
|
||||
KG_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
|
||||
):
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
@@ -129,23 +136,25 @@ def stream_write_step_activities(
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_activity_explicit(
|
||||
writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
|
||||
def stream_write_basic_search_activities(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
for activity in STEP_DESCRIPTIONS[step_nr].activities:
|
||||
for activity_nr, activity in enumerate(
|
||||
BASIC_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
|
||||
):
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=activity,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
query_id=query_id,
|
||||
query_id=activity_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_answer_explicit(
|
||||
def stream_write_kg_search_answer_explicit(
|
||||
writer: StreamWriter, step_nr: int, answer: str, level: int = 0
|
||||
) -> None:
|
||||
write_custom_event(
|
||||
@@ -160,8 +169,8 @@ def stream_write_step_answer_explicit(
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in STEP_DESCRIPTIONS.items():
|
||||
def stream_write_kg_search_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in KG_SEARCH_STEP_DESCRIPTIONS.items():
|
||||
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
@@ -173,7 +182,7 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
writer,
|
||||
)
|
||||
|
||||
for step_nr in STEP_DESCRIPTIONS.keys():
|
||||
for step_nr in KG_SEARCH_STEP_DESCRIPTIONS.keys():
|
||||
|
||||
write_custom_event(
|
||||
"stream_finished",
|
||||
@@ -195,7 +204,40 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_close_step_answer(
|
||||
def stream_write_basic_search_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in BASIC_SEARCH_STEP_DESCRIPTIONS.items():
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=step_detail.description,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
for step_nr in BASIC_SEARCH_STEP_DESCRIPTIONS:
|
||||
write_custom_event(
|
||||
"stream_finished",
|
||||
StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_kg_search_close_step_answer(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
@@ -207,7 +249,7 @@ def stream_close_step_answer(
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
|
||||
def stream_write_kg_search_close_steps(writer: StreamWriter, level: int = 0) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
@@ -355,7 +397,7 @@ def get_near_empty_step_results(
|
||||
Get near-empty step results from a list of step results.
|
||||
"""
|
||||
return SubQuestionAnswerResults(
|
||||
question=STEP_DESCRIPTIONS[step_number].description,
|
||||
question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
|
||||
question_id="0_" + str(step_number),
|
||||
answer=step_answer,
|
||||
verified_high_quality=True,
|
||||
|
||||
@@ -7,17 +7,23 @@ from langgraph.types import StreamWriter
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_activities,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_structure,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
|
||||
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
|
||||
from onyx.agents.agent_search.kb_search.models import (
|
||||
KGQuestionRelationshipExtractionResult,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@@ -42,7 +48,7 @@ logger = setup_logger()
|
||||
|
||||
def extract_ert(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ERTExtractionUpdate:
|
||||
) -> EntityRelationshipExtractionUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
@@ -68,17 +74,17 @@ def extract_ert(
|
||||
user_name = user_email.split("@")[0] or "unknown"
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
today_date = datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
|
||||
# Stream structure of substeps out to the UI
|
||||
stream_write_step_structure(writer)
|
||||
if state.individual_flow:
|
||||
# Stream structure of substeps out to the UI
|
||||
stream_write_kg_search_structure(writer)
|
||||
|
||||
# Now specify core activities in the step (step 1)
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
stream_write_kg_search_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Create temporary views. TODO: move into parallel step, if ultimately materialized
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -240,12 +246,13 @@ def extract_ert(
|
||||
step_answer = f"""Entities and relationships have been extracted from query - \n \
|
||||
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(writer, step_nr=1, answer=step_answer)
|
||||
|
||||
# Finish Step 1
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
# Finish Step 1
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ERTExtractionUpdate(
|
||||
return EntityRelationshipExtractionUpdate(
|
||||
entities_types_str=all_entity_types,
|
||||
relationship_types_str=all_relationship_types,
|
||||
extracted_entities_w_attributes=entity_extraction_result.entities,
|
||||
|
||||
@@ -9,10 +9,14 @@ from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
create_minimal_connected_query_graph,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_activities,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
|
||||
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
|
||||
@@ -141,7 +145,7 @@ def analyze(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
entities = (
|
||||
state.extracted_entities_no_attributes
|
||||
) # attribute knowledge is not required for this step
|
||||
@@ -150,7 +154,8 @@ def analyze(
|
||||
|
||||
## STEP 2 - stream out goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Continue with node
|
||||
|
||||
@@ -277,9 +282,12 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
|
||||
else:
|
||||
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, step_nr=_KG_STEP_NR, answer=step_answer
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# End node
|
||||
|
||||
|
||||
@@ -8,10 +8,14 @@ from langgraph.types import StreamWriter
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_activities,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
|
||||
@@ -33,8 +37,10 @@ from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_
|
||||
from onyx.db.kg_temp_view import drop_views
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import ENTITY_TABLE_DESCRIPTION
|
||||
from onyx.prompts.kg_prompts import RELATIONSHIP_TABLE_DESCRIPTION
|
||||
from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_ERROR_FIX_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -122,6 +128,22 @@ def _sql_is_aggregate_query(sql_statement: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _run_sql(
|
||||
sql_statement: str, rel_temp_view: str, ent_temp_view: str
|
||||
) -> list[dict[str, Any]]:
|
||||
# check sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(sql_statement, rel_temp_view, ent_temp_view)
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(text(sql_statement))
|
||||
# Handle scalar results (like COUNT)
|
||||
if sql_statement.upper().startswith("SELECT COUNT"):
|
||||
scalar_result = result.scalar()
|
||||
return [{"count": int(scalar_result)}] if scalar_result is not None else []
|
||||
# Handle regular row results
|
||||
rows = result.fetchall()
|
||||
return [dict(row._mapping) for row in rows]
|
||||
|
||||
|
||||
def _get_source_documents(
|
||||
sql_statement: str,
|
||||
llm: LLM,
|
||||
@@ -189,7 +211,7 @@ def generate_simple_sql(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
entities_types_str = state.entities_types_str
|
||||
relationship_types_str = state.relationship_types_str
|
||||
|
||||
@@ -199,7 +221,6 @@ def generate_simple_sql(
|
||||
raise ValueError("kg_doc_temp_view_name is not set")
|
||||
if state.kg_rel_temp_view_name is None:
|
||||
raise ValueError("kg_rel_temp_view_name is not set")
|
||||
|
||||
if state.kg_entity_temp_view_name is None:
|
||||
raise ValueError("kg_entity_temp_view_name is not set")
|
||||
|
||||
@@ -207,7 +228,8 @@ def generate_simple_sql(
|
||||
|
||||
## STEP 3 - articulate goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_activities(writer, _KG_STEP_NR)
|
||||
|
||||
if graph_config.tooling.search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
@@ -270,6 +292,12 @@ def generate_simple_sql(
|
||||
)
|
||||
.replace("---question---", question)
|
||||
.replace("---entity_explanation_string---", entity_explanation_str)
|
||||
.replace(
|
||||
"---query_entities_with_attributes---",
|
||||
"\n".join(state.query_graph_entities_w_attributes),
|
||||
)
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
else:
|
||||
simple_sql_prompt = (
|
||||
@@ -289,8 +317,7 @@ def generate_simple_sql(
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
|
||||
# prepare SQL query generation
|
||||
|
||||
# generate initial sql statement
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=simple_sql_prompt,
|
||||
@@ -298,7 +325,6 @@ def generate_simple_sql(
|
||||
]
|
||||
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
@@ -336,53 +362,6 @@ def generate_simple_sql(
|
||||
)
|
||||
raise e
|
||||
|
||||
if state.query_type == KGRelationshipDetection.RELATIONSHIPS.value:
|
||||
# Correction if needed:
|
||||
|
||||
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
|
||||
"---draft_sql---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=correction_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
)
|
||||
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
|
||||
)
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
# display sql statement with view names replaced by general view names
|
||||
sql_statement_display = sql_statement.replace(
|
||||
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
|
||||
@@ -437,51 +416,93 @@ def generate_simple_sql(
|
||||
|
||||
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
|
||||
|
||||
scalar_result = None
|
||||
query_results = None
|
||||
query_results = [] # if no results, will be empty (not None)
|
||||
query_generation_error = None
|
||||
|
||||
# check sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(
|
||||
sql_statement, rel_temp_view, ent_temp_view
|
||||
)
|
||||
# run sql
|
||||
try:
|
||||
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
|
||||
if not query_results:
|
||||
query_generation_error = "SQL query returned no results"
|
||||
logger.warning(f"{query_generation_error}, retrying...")
|
||||
except Exception as e:
|
||||
query_generation_error = str(e)
|
||||
logger.warning(f"Error executing SQL query: {e}, retrying...")
|
||||
|
||||
# fix sql and try one more time if sql query didn't work out
|
||||
# if the result is still empty after this, the kg probably doesn't have the answer,
|
||||
# so we update the strategy to simple and address this in the answer generation
|
||||
if query_generation_error is not None:
|
||||
sql_fix_prompt = (
|
||||
SIMPLE_SQL_ERROR_FIX_PROMPT.replace(
|
||||
"---table_description---",
|
||||
(
|
||||
ENTITY_TABLE_DESCRIPTION
|
||||
if state.query_type
|
||||
== KGRelationshipDetection.NO_RELATIONSHIPS.value
|
||||
else RELATIONSHIP_TABLE_DESCRIPTION
|
||||
),
|
||||
)
|
||||
.replace("---entity_types---", entities_types_str)
|
||||
.replace("---relationship_types---", relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
.replace("---sql_statement---", sql_statement)
|
||||
.replace("---error_message---", query_generation_error)
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
msg = [HumanMessage(content=sql_fix_prompt)]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
result = db_session.execute(text(sql_statement))
|
||||
# Handle scalar results (like COUNT)
|
||||
if sql_statement.upper().startswith("SELECT COUNT"):
|
||||
scalar_result = result.scalar()
|
||||
query_results = (
|
||||
[{"count": int(scalar_result)}]
|
||||
if scalar_result is not None
|
||||
else []
|
||||
)
|
||||
else:
|
||||
# Handle regular row results
|
||||
rows = result.fetchall()
|
||||
query_results = [dict(row._mapping) for row in rows]
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
|
||||
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
)
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
sql_statement = sql_statement.replace(
|
||||
"relationship_table", rel_temp_view
|
||||
)
|
||||
sql_statement = sql_statement.replace("entity_table", ent_temp_view)
|
||||
|
||||
reasoning = (
|
||||
cleaned_response.split("<reasoning>")[1]
|
||||
.strip()
|
||||
.split("</reasoning>")[0]
|
||||
)
|
||||
|
||||
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing SQL query even after retry: {e}")
|
||||
# TODO: raise error on frontend
|
||||
logger.error(f"Error executing SQL query: {e}")
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
raise e
|
||||
raise
|
||||
|
||||
source_document_results = None
|
||||
|
||||
if source_documents_sql is not None and source_documents_sql != sql_statement:
|
||||
|
||||
# check source document sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(
|
||||
source_documents_sql, rel_temp_view, ent_temp_view
|
||||
)
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
|
||||
try:
|
||||
result = db_session.execute(text(source_documents_sql))
|
||||
rows = result.fetchall()
|
||||
@@ -491,28 +512,16 @@ def generate_simple_sql(
|
||||
for source_document_result in query_source_document_results
|
||||
]
|
||||
except Exception as e:
|
||||
# TODO: raise error on frontend
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
# TODO: raise warning on frontend
|
||||
logger.error(f"Error executing Individualized SQL query: {e}")
|
||||
|
||||
elif state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
|
||||
# source documents should be returned for entity queries
|
||||
source_document_results = [
|
||||
x["source_document"] for x in query_results if "source_document" in x
|
||||
]
|
||||
else:
|
||||
|
||||
if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
|
||||
# source documents should be returned for entity queries
|
||||
source_document_results = [
|
||||
x["source_document"]
|
||||
for x in query_results
|
||||
if "source_document" in x
|
||||
]
|
||||
|
||||
else:
|
||||
source_document_results = None
|
||||
source_document_results = None
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
@@ -528,21 +537,25 @@ def generate_simple_sql(
|
||||
|
||||
main_sql_statement = sql_statement
|
||||
|
||||
if reasoning:
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
|
||||
if reasoning and state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, step_nr=_KG_STEP_NR, answer=reasoning
|
||||
)
|
||||
|
||||
if sql_statement_display:
|
||||
stream_write_step_answer_explicit(
|
||||
if sql_statement_display and state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer,
|
||||
step_nr=_KG_STEP_NR,
|
||||
answer=f" \n Generated SQL: {sql_statement_display}",
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
if state.individual_flow:
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# Update path if too many results are retrieved
|
||||
|
||||
if query_results and len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS:
|
||||
# Update path if too many, or no results were retrieved from sql
|
||||
if main_sql_statement and (
|
||||
not query_results or len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS
|
||||
):
|
||||
updated_strategy = KGAnswerStrategy.SIMPLE
|
||||
else:
|
||||
updated_strategy = None
|
||||
|
||||
@@ -34,7 +34,7 @@ def construct_deep_search_filters(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
|
||||
entities_types_str = state.entities_types_str
|
||||
entities = state.query_graph_entities_no_attributes
|
||||
@@ -155,7 +155,11 @@ def construct_deep_search_filters(
|
||||
|
||||
if div_con_structure:
|
||||
for entity_type in double_grounded_entity_types:
|
||||
if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
|
||||
# entity_type is guaranteed to have grounded_source_name
|
||||
if (
|
||||
cast(str, entity_type.grounded_source_name).lower()
|
||||
in div_con_structure[0].lower()
|
||||
):
|
||||
source_division = True
|
||||
break
|
||||
|
||||
|
||||
@@ -98,16 +98,17 @@ def process_individual_deep_search(
|
||||
kg_relationship_filters = None
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=research_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
if state.individual_flow:
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=research_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
|
||||
logger.debug(
|
||||
|
||||
@@ -7,9 +7,11 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
@@ -49,7 +51,7 @@ def filtered_search(
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
@@ -72,17 +74,18 @@ def filtered_search(
|
||||
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Conduct a filtered search",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
if state.individual_flow:
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Conduct a filtered search",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
@@ -165,11 +168,12 @@ def filtered_search(
|
||||
|
||||
step_answer = "Filtered search is complete."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=filtered_search_answer,
|
||||
|
||||
@@ -5,9 +5,11 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
@@ -41,11 +43,12 @@ def consolidate_individual_deep_search(
|
||||
|
||||
step_answer = "All research is complete. Consolidating results..."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=consolidated_research_object_results_str,
|
||||
|
||||
@@ -4,9 +4,11 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
|
||||
@@ -66,28 +68,26 @@ def process_kg_only_answers(
|
||||
|
||||
# we use this stream write explicitly
|
||||
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Formatted References",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
query_results_list = []
|
||||
if state.individual_flow:
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Formatted References",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if query_results:
|
||||
for query_result in query_results:
|
||||
query_results_list.append(
|
||||
str(query_result).replace("::", ":: ").capitalize()
|
||||
)
|
||||
query_results_data_str = "\n".join(
|
||||
str(query_result).replace("::", ":: ").capitalize()
|
||||
for query_result in query_results
|
||||
)
|
||||
else:
|
||||
raise ValueError("No query results were found")
|
||||
|
||||
query_results_data_str = "\n".join(query_results_list)
|
||||
logger.warning("No query results were found")
|
||||
query_results_data_str = "(No query results were found)"
|
||||
|
||||
source_reference_result_str = _get_formated_source_reference_results(
|
||||
source_document_results
|
||||
@@ -99,9 +99,12 @@ def process_kg_only_answers(
|
||||
"No further research is needed, the answer is derived from the knowledge graph."
|
||||
)
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, step_nr=_KG_STEP_NR, answer=step_answer
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ResultsDataUpdate(
|
||||
query_results_data_str=query_results_data_str,
|
||||
|
||||
@@ -7,14 +7,17 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.access.access import get_acl_for_user
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_close_steps,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import MainOutput
|
||||
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -42,7 +45,7 @@ logger = setup_logger()
|
||||
|
||||
def generate_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> MainOutput:
|
||||
) -> FinalAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
@@ -50,7 +53,9 @@ def generate_answer(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
|
||||
final_answer: str | None = None
|
||||
|
||||
user = (
|
||||
graph_config.tooling.search_tool.user
|
||||
@@ -69,7 +74,8 @@ def generate_answer(
|
||||
|
||||
# DECLARE STEPS DONE
|
||||
|
||||
stream_write_close_steps(writer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_close_steps(writer)
|
||||
|
||||
## MAIN ANSWER
|
||||
|
||||
@@ -128,16 +134,17 @@ def generate_answer(
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
if state.individual_flow:
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# continue with the answer generation
|
||||
|
||||
@@ -206,24 +213,40 @@ def generate_answer(
|
||||
)
|
||||
]
|
||||
try:
|
||||
run_with_timeout(
|
||||
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
prompt=msg,
|
||||
event_name="initial_agent_answer",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
if state.individual_flow:
|
||||
|
||||
stream_results, _, _ = run_with_timeout(
|
||||
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=msg,
|
||||
event_name="initial_agent_answer",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
|
||||
),
|
||||
)
|
||||
final_answer = "".join(stream_results)
|
||||
else:
|
||||
final_answer = get_answer_from_llm(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=output_format_prompt,
|
||||
stream=False,
|
||||
json_string_flag=False,
|
||||
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
return MainOutput(
|
||||
return FinalAnswerUpdate(
|
||||
final_answer=final_answer,
|
||||
retrieved_documents=answer_generation_documents.context_documents,
|
||||
step_results=[],
|
||||
remarks=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
|
||||
@@ -48,6 +48,8 @@ def log_data(
|
||||
)
|
||||
|
||||
return MainOutput(
|
||||
final_answer=state.final_answer,
|
||||
retrieved_documents=state.retrieved_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
|
||||
@@ -120,7 +120,7 @@ class ResearchObjectOutput(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
class ERTExtractionUpdate(LoggerUpdate):
|
||||
class EntityRelationshipExtractionUpdate(LoggerUpdate):
|
||||
entities_types_str: str = ""
|
||||
relationship_types_str: str = ""
|
||||
extracted_entities_w_attributes: list[str] = []
|
||||
@@ -144,7 +144,13 @@ class ResearchObjectUpdate(LoggerUpdate):
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
question: str
|
||||
individual_flow: bool = True # used for UI display purposes
|
||||
|
||||
|
||||
class FinalAnswerUpdate(LoggerUpdate):
|
||||
final_answer: str | None = None
|
||||
retrieved_documents: list[InferenceSection] | None = None
|
||||
|
||||
|
||||
## Graph State
|
||||
@@ -154,7 +160,7 @@ class MainState(
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
ERTExtractionUpdate,
|
||||
EntityRelationshipExtractionUpdate,
|
||||
AnalysisUpdate,
|
||||
SQLSimpleGenerationUpdate,
|
||||
ResultsDataUpdate,
|
||||
@@ -162,6 +168,7 @@ class MainState(
|
||||
DeepSearchFilterUpdate,
|
||||
ResearchObjectUpdate,
|
||||
ConsolidatedResearchUpdate,
|
||||
FinalAnswerUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -169,6 +176,8 @@ class MainState(
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
final_answer: str | None
|
||||
retrieved_documents: list[InferenceSection] | None
|
||||
|
||||
|
||||
class ResearchObjectInput(LoggerUpdate):
|
||||
@@ -179,3 +188,4 @@ class ResearchObjectInput(LoggerUpdate):
|
||||
source_division: bool | None
|
||||
source_entity_filters: list[str] | None
|
||||
segment_type: str
|
||||
individual_flow: bool = True # used for UI display purposes
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from onyx.agents.agent_search.kb_search.models import KGSteps
|
||||
|
||||
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
1: KGSteps(
|
||||
description="Analyzing the question...",
|
||||
activities=[
|
||||
@@ -27,3 +27,7 @@ STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
description="Conducting further research on source documents...", activities=[]
|
||||
),
|
||||
}
|
||||
|
||||
BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
1: KGSteps(description="Conducting a standard search...", activities=[]),
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.models import Persona
|
||||
@@ -72,6 +73,7 @@ class GraphSearchConfig(BaseModel):
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
kg_config_settings: KGConfigSettings = KGConfigSettings()
|
||||
research_type: ResearchType = ResearchType.THOUGHTFUL
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -271,7 +271,10 @@ def choose_tool(
|
||||
should_stream_answer
|
||||
and not agent_config.behavior.skip_gen_ai_answer_generation,
|
||||
writer,
|
||||
)
|
||||
).ai_message_chunk
|
||||
|
||||
if tool_message is None:
|
||||
raise ValueError("No tool message emitted by LLM")
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
|
||||
@@ -4,6 +4,7 @@ from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
@@ -21,6 +22,7 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -62,7 +64,9 @@ def basic_use_tool_response(
|
||||
for section in dedupe_documents(search_response_summary.top_sections)[0]
|
||||
]
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
new_tool_call_chunk = BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=AIMessageChunk(content=""), full_answer=None
|
||||
)
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
@@ -80,4 +84,9 @@ def basic_use_tool_response(
|
||||
displayed_search_results=initial_search_results or final_search_results,
|
||||
)
|
||||
|
||||
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
|
||||
return BasicOutput(
|
||||
tool_call_chunk=new_tool_call_chunk.ai_message_chunk,
|
||||
full_answer=new_tool_call_chunk.full_answer,
|
||||
cited_references=new_tool_call_chunk.cited_references,
|
||||
retrieved_documents=new_tool_call_chunk.retrieved_documents,
|
||||
)
|
||||
|
||||
@@ -18,79 +18,37 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
|
||||
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.models import ToolResponse
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GraphInput = BasicInput | MainInput | DCMainInput | KBMainInput | DRMainInput
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
"""
|
||||
Parse the event into a typed object.
|
||||
Return None if we are not interested in the event.
|
||||
"""
|
||||
event_type = event["event"]
|
||||
|
||||
# We always just yield the event data, but this piece is useful for two development reasons:
|
||||
# 1. It's a list of the names of every place we dispatch a custom event
|
||||
# 2. We maintain the intended types yielded by each event
|
||||
if event_type == "on_custom_event":
|
||||
if event["name"] == "decomp_qs":
|
||||
return cast(SubQuestionPiece, event["data"])
|
||||
elif event["name"] == "subqueries":
|
||||
return cast(SubQueryPiece, event["data"])
|
||||
elif event["name"] == "sub_answers":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "stream_finished":
|
||||
return cast(StreamStopInfo, event["data"])
|
||||
elif event["name"] == "initial_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "refined_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "start_refined_answer_creation":
|
||||
return cast(ToolCallKickoff, event["data"])
|
||||
elif event["name"] == "tool_response":
|
||||
return cast(ToolResponse, event["data"])
|
||||
elif event["name"] == "basic_response":
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
elif event["name"] == "refined_sub_question_creation_error":
|
||||
return cast(StreamingError, event["data"])
|
||||
else:
|
||||
logger.error(f"Unknown event name: {event['name']}")
|
||||
return None
|
||||
|
||||
logger.error(f"Unknown event type: {event_type}")
|
||||
return None
|
||||
|
||||
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
graph_input: GraphInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -104,16 +62,14 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
input: GraphInput,
|
||||
) -> AnswerStream:
|
||||
|
||||
for event in manage_sync_streaming(
|
||||
compiled_graph=compiled_graph, config=config, graph_input=input
|
||||
):
|
||||
if not (parsed_object := _parse_agent_event(event)):
|
||||
continue
|
||||
|
||||
yield parsed_object
|
||||
yield cast(Packet, event["data"])
|
||||
|
||||
|
||||
# It doesn't actually take very long to load the graph, but we'd rather
|
||||
@@ -154,16 +110,23 @@ def run_kb_graph(
|
||||
) -> AnswerStream:
|
||||
graph = kb_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = KBMainInput(log_messages=[])
|
||||
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.inputs.prompt_builder.raw_user_query},
|
||||
input = KBMainInput(
|
||||
log_messages=[], question=config.inputs.prompt_builder.raw_user_query
|
||||
)
|
||||
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dr_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = dr_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = DRMainInput(log_messages=[])
|
||||
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dc_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
|
||||
@@ -1,12 +1,32 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.stream_processing.citation_processing import CitationProcessorGraph
|
||||
from onyx.chat.stream_processing.citation_processing import LlmDoc
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
# match ```json{...}``` or ```{...}```
|
||||
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
||||
|
||||
|
||||
def stream_llm_answer(
|
||||
@@ -19,7 +39,11 @@ def stream_llm_answer(
|
||||
agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[list[str], list[float]]:
|
||||
answer_piece: str | None = None,
|
||||
ind: int | None = None,
|
||||
context_docs: list[LlmDoc] | None = None,
|
||||
replace_citations: bool = False,
|
||||
) -> tuple[list[str], list[float], list[CitationInfo]]:
|
||||
"""Stream the initial answer from the LLM.
|
||||
|
||||
Args:
|
||||
@@ -32,16 +56,32 @@ def stream_llm_answer(
|
||||
agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
|
||||
timeout_override: The LLM timeout to use.
|
||||
max_tokens: The LLM max tokens to use.
|
||||
answer_piece: The type of answer piece to write.
|
||||
ind: The index of the answer piece.
|
||||
tools: The tools to use.
|
||||
tool_choice: The tool choice to use.
|
||||
structured_response_format: The structured response format to use.
|
||||
|
||||
Returns:
|
||||
A tuple of the response and the dispatch timings.
|
||||
"""
|
||||
response: list[str] = []
|
||||
dispatch_timings: list[float] = []
|
||||
citation_infos: list[CitationInfo] = []
|
||||
|
||||
if context_docs:
|
||||
citation_processor = CitationProcessorGraph(
|
||||
context_docs=context_docs,
|
||||
)
|
||||
else:
|
||||
replace_citations = False
|
||||
|
||||
for message in llm.stream(
|
||||
prompt, timeout_override=timeout_override, max_tokens=max_tokens
|
||||
prompt,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
@@ -50,19 +90,153 @@ def stream_llm_answer(
|
||||
)
|
||||
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
event_name,
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=agent_answer_level,
|
||||
level_question_num=agent_answer_question_num,
|
||||
answer_type=agent_answer_type,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if answer_piece == "message_delta":
|
||||
if ind is None:
|
||||
raise ValueError("index is required when answer_piece is message_delta")
|
||||
|
||||
if replace_citations:
|
||||
processed_token = citation_processor.process_token(content)
|
||||
|
||||
if isinstance(processed_token, tuple):
|
||||
content = processed_token[0]
|
||||
citation_infos.extend(processed_token[1])
|
||||
elif isinstance(processed_token, str):
|
||||
content = processed_token
|
||||
else:
|
||||
continue
|
||||
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageDelta(content=content, type="message_delta"),
|
||||
writer,
|
||||
)
|
||||
|
||||
elif answer_piece == "reasoning_delta":
|
||||
if ind is None:
|
||||
raise ValueError(
|
||||
"index is required when answer_piece is reasoning_delta"
|
||||
)
|
||||
write_custom_event(
|
||||
ind,
|
||||
ReasoningDelta(reasoning=content, type="reasoning_delta"),
|
||||
writer,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid answer piece: {answer_piece}")
|
||||
|
||||
end_stream_token = datetime.now()
|
||||
|
||||
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
response.append(content)
|
||||
|
||||
return response, dispatch_timings
|
||||
return response, dispatch_timings, citation_infos
|
||||
|
||||
|
||||
def invoke_llm_json(
|
||||
llm: LLM,
|
||||
prompt: LanguageModelInput,
|
||||
schema: Type[SchemaType],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> SchemaType:
|
||||
"""
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
|
||||
# check if the model supports response_format: json_schema
|
||||
supports_json = "response_format" in (
|
||||
get_supported_openai_params(llm.config.model_name, llm.config.model_provider)
|
||||
or []
|
||||
) and supports_response_schema(llm.config.model_name, llm.config.model_provider)
|
||||
|
||||
response_content = str(
|
||||
llm.invoke(
|
||||
prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
**cast(
|
||||
dict, {"structured_response_format": schema} if supports_json else {}
|
||||
),
|
||||
).content
|
||||
)
|
||||
|
||||
if not supports_json:
|
||||
# remove newlines as they often lead to json decoding errors
|
||||
response_content = response_content.replace("\n", " ")
|
||||
# hope the prompt is structured in a way a json is outputted...
|
||||
json_block_match = JSON_PATTERN.search(response_content)
|
||||
if json_block_match:
|
||||
response_content = json_block_match.group(1)
|
||||
else:
|
||||
first_bracket = response_content.find("{")
|
||||
last_bracket = response_content.rfind("}")
|
||||
response_content = response_content[first_bracket : last_bracket + 1]
|
||||
|
||||
return schema.model_validate_json(response_content)
|
||||
|
||||
|
||||
def get_answer_from_llm(
|
||||
llm: LLM,
|
||||
prompt: str,
|
||||
timeout: int = 25,
|
||||
timeout_override: int = 5,
|
||||
max_tokens: int = 500,
|
||||
stream: bool = False,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
agent_answer_level: int = 0,
|
||||
agent_answer_question_num: int = 0,
|
||||
agent_answer_type: Literal[
|
||||
"agent_sub_answer", "agent_level_answer"
|
||||
] = "agent_level_answer",
|
||||
json_string_flag: bool = False,
|
||||
) -> str:
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=prompt,
|
||||
)
|
||||
]
|
||||
|
||||
if stream:
|
||||
# TODO - adjust for new UI. This is currently not working for current UI/Basic Search
|
||||
stream_response, _, _ = run_with_timeout(
|
||||
timeout,
|
||||
lambda: stream_llm_answer(
|
||||
llm=llm,
|
||||
prompt=msg,
|
||||
event_name="sub_answers",
|
||||
writer=writer,
|
||||
agent_answer_level=agent_answer_level,
|
||||
agent_answer_question_num=agent_answer_question_num,
|
||||
agent_answer_type=agent_answer_type,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
content = "".join(stream_response)
|
||||
else:
|
||||
llm_response = run_with_timeout(
|
||||
timeout,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
content = str(llm_response.content)
|
||||
|
||||
cleaned_response = content
|
||||
if json_string_flag:
|
||||
cleaned_response = (
|
||||
str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
return cleaned_response
|
||||
|
||||
@@ -73,6 +73,7 @@ from onyx.prompts.agent_search import (
|
||||
HISTORY_CONTEXT_SUMMARY_PROMPT,
|
||||
)
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
@@ -353,7 +354,7 @@ def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
|
||||
stream_type=StreamType.MAIN_ANSWER,
|
||||
level=level,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
write_custom_event(0, stop_event, writer)
|
||||
|
||||
|
||||
def retrieve_search_docs(
|
||||
@@ -438,9 +439,41 @@ class CustomStreamEvent(TypedDict):
|
||||
|
||||
|
||||
def write_custom_event(
|
||||
name: str, event: AnswerPacket, stream_writer: StreamWriter
|
||||
ind: int,
|
||||
event: AnswerPacket,
|
||||
stream_writer: StreamWriter,
|
||||
) -> None:
|
||||
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
|
||||
# For types that are in PacketObj, wrap in Packet
|
||||
# For types like StreamStopInfo that frontend handles directly, stream directly
|
||||
if hasattr(event, "stop_reason"): # StreamStopInfo
|
||||
stream_writer(
|
||||
CustomStreamEvent(
|
||||
event="on_custom_event",
|
||||
data=event,
|
||||
name="",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Try to wrap in Packet for types that are compatible
|
||||
pass
|
||||
|
||||
try:
|
||||
stream_writer(
|
||||
CustomStreamEvent(
|
||||
event="on_custom_event",
|
||||
data=Packet(ind=ind, obj=event),
|
||||
name="",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Fallback: stream directly if Packet wrapping fails
|
||||
stream_writer(
|
||||
CustomStreamEvent(
|
||||
event="on_custom_event",
|
||||
data=event,
|
||||
name="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def relevance_from_docs(
|
||||
|
||||
39
backend/onyx/agents/agent_search/utils.py
Normal file
39
backend/onyx/agents/agent_search/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def create_citation_format_list(
|
||||
document_citations: list[InferenceSection],
|
||||
) -> list[dict[str, Any]]:
|
||||
citation_list: list[dict[str, Any]] = []
|
||||
for document_citation in document_citations:
|
||||
document_citation_dict = {
|
||||
"link": "",
|
||||
"blurb": document_citation.center_chunk.blurb,
|
||||
"content": document_citation.center_chunk.content,
|
||||
"metadata": document_citation.center_chunk.metadata,
|
||||
"updated_at": str(document_citation.center_chunk.updated_at),
|
||||
"document_id": document_citation.center_chunk.document_id,
|
||||
"source_type": "file",
|
||||
"source_links": document_citation.center_chunk.source_links,
|
||||
"match_highlights": document_citation.center_chunk.match_highlights,
|
||||
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
|
||||
}
|
||||
|
||||
citation_list.append(document_citation_dict)
|
||||
|
||||
return citation_list
|
||||
|
||||
|
||||
def create_question_prompt(
|
||||
system_prompt: str | None, human_prompt: str
|
||||
) -> list[BaseMessage]:
|
||||
return [
|
||||
SystemMessage(content=system_prompt or ""),
|
||||
HumanMessage(content=human_prompt),
|
||||
]
|
||||
@@ -1,9 +1,11 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
@@ -12,12 +14,11 @@ from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_agent_search_graph
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dr_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
@@ -32,6 +33,7 @@ from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import Persona
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@@ -68,6 +70,8 @@ class Answer:
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
use_agentic_search: bool = False,
|
||||
research_type: ResearchType | None = None,
|
||||
research_plan: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
self._processed_stream: list[AnswerPacket] | None = None
|
||||
@@ -124,6 +128,9 @@ class Answer:
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
|
||||
kg_config_settings=get_kg_config_settings(),
|
||||
research_type=(
|
||||
ResearchType.DEEP if use_agentic_search else ResearchType.THOUGHTFUL
|
||||
),
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
@@ -138,12 +145,10 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
if self.graph_config.behavior.use_agentic_search and (
|
||||
self.graph_config.inputs.persona
|
||||
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
and self.graph_config.inputs.persona.name.startswith("KG Beta")
|
||||
):
|
||||
run_langgraph = run_kb_graph
|
||||
# TODO: add toggle in UI with customizable TimeBudget
|
||||
if self.graph_config.inputs.persona:
|
||||
run_langgraph = run_dr_graph
|
||||
|
||||
elif self.graph_config.behavior.use_agentic_search:
|
||||
run_langgraph = run_agent_search_graph
|
||||
elif (
|
||||
@@ -210,23 +215,6 @@ class Answer:
|
||||
|
||||
return citations
|
||||
|
||||
def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
|
||||
citations_by_subquestion: dict[SubQuestionKey, list[CitationInfo]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
basic_subq_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
if packet.level_question_num is not None and packet.level is not None:
|
||||
citations_by_subquestion[
|
||||
SubQuestionKey(
|
||||
level=packet.level, question_num=packet.level_question_num
|
||||
)
|
||||
].append(packet)
|
||||
elif packet.level is None:
|
||||
citations_by_subquestion[basic_subq_key].append(packet)
|
||||
return citations_by_subquestion
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
@@ -13,15 +13,16 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
try_creating_kg_source_reset_task,
|
||||
)
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
@@ -31,6 +32,7 @@ from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.prompts import get_prompts_by_ids
|
||||
@@ -42,6 +44,7 @@ from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
@@ -113,6 +116,42 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
)
|
||||
|
||||
|
||||
def saved_search_docs_from_llm_docs(
|
||||
llm_docs: list[LlmDoc] | None,
|
||||
) -> list[SavedSearchDoc]:
|
||||
"""Convert LlmDoc objects to SavedSearchDoc format."""
|
||||
if not llm_docs:
|
||||
return []
|
||||
|
||||
search_docs = []
|
||||
for i, llm_doc in enumerate(llm_docs):
|
||||
# Convert LlmDoc to SearchDoc format
|
||||
# Note: Some fields need default values as they're not in LlmDoc
|
||||
search_doc = SearchDoc(
|
||||
document_id=llm_doc.document_id,
|
||||
chunk_ind=0, # Default value as LlmDoc doesn't have chunk index
|
||||
semantic_identifier=llm_doc.semantic_identifier,
|
||||
link=llm_doc.link,
|
||||
blurb=llm_doc.blurb,
|
||||
source_type=llm_doc.source_type,
|
||||
boost=0, # Default value
|
||||
hidden=False, # Default value
|
||||
metadata=llm_doc.metadata,
|
||||
score=None, # Will be set by SavedSearchDoc
|
||||
match_highlights=llm_doc.match_highlights or [],
|
||||
updated_at=llm_doc.updated_at,
|
||||
primary_owners=None, # Default value
|
||||
secondary_owners=None, # Default value
|
||||
is_internet=False, # Default value
|
||||
)
|
||||
|
||||
# Convert SearchDoc to SavedSearchDoc
|
||||
saved_search_doc = SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
search_docs.append(saved_search_doc)
|
||||
|
||||
return search_docs
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
@@ -401,7 +440,7 @@ def process_kg_commands(
|
||||
) -> None:
|
||||
# Temporarily, until we have a draft UI for the KG Operations/Management
|
||||
# TODO: move to api endpoint once we get frontend
|
||||
if not persona_name.startswith("KG Beta"):
|
||||
if not persona_name.startswith(TMP_DRALPHA_PERSONA_NAME):
|
||||
return
|
||||
|
||||
kg_config_settings = get_kg_config_settings()
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -22,6 +20,19 @@ from onyx.context.search.models import RetrievalDocs
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -46,46 +57,6 @@ class LlmDoc(BaseModel):
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
class SubQuestionIdentifier(BaseModel):
|
||||
"""None represents references to objects in the original flow. To our understanding,
|
||||
these will not be None in the packets returned from agent search.
|
||||
"""
|
||||
|
||||
level: int | None = None
|
||||
level_question_num: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level(
|
||||
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"],
|
||||
) -> dict[int, list["SubQuestionIdentifier"]]:
|
||||
"""returns a dict of level to object list (sorted by level_question_num)
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
|
||||
# organize by level, then sort ascending by question_index
|
||||
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
|
||||
|
||||
# group by level
|
||||
for k, obj in original_dict.items():
|
||||
level = k[0]
|
||||
if level not in level_dict:
|
||||
level_dict[level] = []
|
||||
level_dict[level].append(obj)
|
||||
|
||||
# for each level, sort the group
|
||||
for k2, value2 in level_dict.items():
|
||||
# we need to handle the none case due to SubQuestionIdentifier typing
|
||||
# level_question_num as int | None, even though it should never be None here.
|
||||
level_dict[k2] = sorted(
|
||||
value2,
|
||||
key=lambda x: (x.level_question_num is None, x.level_question_num),
|
||||
)
|
||||
|
||||
# sort by level
|
||||
sorted_dict = OrderedDict(sorted(level_dict.items()))
|
||||
return sorted_dict
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
|
||||
rephrased_query: str | None = None
|
||||
@@ -164,11 +135,6 @@ class OnyxAnswerPiece(BaseModel):
|
||||
|
||||
# An intermediate representation of citations, later translated into
|
||||
# a mapping of the citation [n] number to SearchDoc
|
||||
class CitationInfo(SubQuestionIdentifier):
|
||||
citation_num: int
|
||||
document_id: str
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
citations: list[CitationInfo]
|
||||
|
||||
@@ -388,7 +354,21 @@ AgentSearchPacket = Union[
|
||||
]
|
||||
|
||||
AnswerPacket = (
|
||||
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
|
||||
AnswerQuestionPossibleReturn
|
||||
| AgentSearchPacket
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| MessageStart
|
||||
| MessageDelta
|
||||
| SectionEnd
|
||||
| ReasoningStart
|
||||
| ReasoningDelta
|
||||
| SearchToolStart
|
||||
| SearchToolDelta
|
||||
| OnyxAnswerPiece
|
||||
| CitationStart
|
||||
| CitationDelta
|
||||
| OverallStop
|
||||
)
|
||||
|
||||
|
||||
@@ -402,12 +382,12 @@ ResponsePart = (
|
||||
| AgentSearchPacket
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerPacket]
|
||||
AnswerStream = Iterator[Packet]
|
||||
|
||||
|
||||
class AnswerPostInfo(BaseModel):
|
||||
ai_message_files: list[FileDescriptor]
|
||||
qa_docs_response: QADocsResponse | None = None
|
||||
rephrased_query: str | None = None
|
||||
reference_db_search_docs: list[DbSearchDoc] | None = None
|
||||
dropped_indices: list[int] | None = None
|
||||
tool_result: ToolCallFinalResult | None = None
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from onyx.chat.models import AgenticMessageResponseIDInfo
|
||||
from onyx.chat.models import AgentSearchPacket
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import CustomToolResponse
|
||||
from onyx.chat.models import FileChatDisplay
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
COMMON_TOOL_RESPONSE_TYPES = {
|
||||
"image": ChatFileType.IMAGE,
|
||||
"csv": ChatFileType.CSV,
|
||||
}
|
||||
|
||||
# Type definitions for packet processing
|
||||
ChatPacket = Union[
|
||||
StreamingError,
|
||||
QADocsResponse,
|
||||
LLMRelevanceFilterResponse,
|
||||
FinalUsedContextDocsResponse,
|
||||
ChatMessageDetail,
|
||||
AllCitations,
|
||||
CitationInfo,
|
||||
FileChatDisplay,
|
||||
CustomToolResponse,
|
||||
MessageResponseIDInfo,
|
||||
MessageSpecificCitations,
|
||||
AgenticMessageResponseIDInfo,
|
||||
StreamStopInfo,
|
||||
AgentSearchPacket,
|
||||
UserKnowledgeFilePacket,
|
||||
Packet,
|
||||
]
|
||||
|
||||
|
||||
def process_streamed_packets(
|
||||
answer_processed_output: AnswerStream,
|
||||
) -> Generator[ChatPacket, None, None]:
|
||||
"""Process the streamed output from the answer and yield chat packets."""
|
||||
|
||||
last_index = 0
|
||||
|
||||
for packet in answer_processed_output:
|
||||
if isinstance(packet, Packet):
|
||||
if packet.ind > last_index:
|
||||
last_index = packet.ind
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
# Yield STOP packet to indicate streaming is complete
|
||||
yield Packet(ind=last_index, obj=OverallStop())
|
||||
164
backend/onyx/chat/packet_proccessing/tool_processing.py
Normal file
164
backend/onyx/chat/packet_proccessing/tool_processing.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.chat import create_search_doc_from_user_file
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.models import (
|
||||
InternetSearchResponseSummary,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.utils import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def handle_search_tool_response_summary(
|
||||
current_ind: int,
|
||||
search_response: SearchResponseSummary,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
is_extended: bool,
|
||||
dedupe_docs: bool = False,
|
||||
user_files: list[UserFile] | None = None,
|
||||
loaded_user_files: list[InMemoryChatFile] | None = None,
|
||||
) -> Generator[Packet, None, tuple[list[DbSearchDoc], list[int] | None]]:
|
||||
dropped_inds = None
|
||||
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(search_response.top_sections)
|
||||
|
||||
deduped_docs = top_docs
|
||||
if (
|
||||
dedupe_docs and not is_extended
|
||||
): # Extended tool responses are already deduped
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
doc_ids = {doc.id for doc in reference_db_search_docs}
|
||||
if user_files is not None and loaded_user_files is not None:
|
||||
for user_file in user_files:
|
||||
if user_file.id in doc_ids:
|
||||
continue
|
||||
|
||||
associated_chat_file = next(
|
||||
(
|
||||
file
|
||||
for file in loaded_user_files
|
||||
if file.file_id == str(user_file.file_id)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Use create_search_doc_from_user_file to properly add the document to the database
|
||||
if associated_chat_file is not None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_doc = create_search_doc_from_user_file(
|
||||
user_file, associated_chat_file, db_session
|
||||
)
|
||||
reference_db_search_docs.append(db_doc)
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SearchToolDelta(
|
||||
documents=response_docs,
|
||||
),
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
|
||||
return reference_db_search_docs, dropped_inds
|
||||
|
||||
|
||||
def handle_internet_search_tool_response(
|
||||
current_ind: int,
|
||||
internet_search_response: InternetSearchResponseSummary,
|
||||
) -> Generator[Packet, None, list[DbSearchDoc]]:
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SearchToolDelta(
|
||||
documents=response_docs,
|
||||
),
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
|
||||
return reference_db_search_docs
|
||||
|
||||
|
||||
def handle_image_generation_tool_response(
|
||||
current_ind: int,
|
||||
img_generation_responses: list[ImageGenerationResponse],
|
||||
) -> Generator[Packet, None, None]:
|
||||
|
||||
# Save files and get file IDs
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_responses if img.url],
|
||||
base64_files=[
|
||||
img.image_data for img in img_generation_responses if img.image_data
|
||||
],
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=ImageGenerationToolDelta(
|
||||
images=[
|
||||
{
|
||||
"id": str(file_id),
|
||||
"url": "", # URL will be constructed by frontend
|
||||
"prompt": img.revised_prompt,
|
||||
}
|
||||
for file_id, img in zip(file_ids, img_generation_responses)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Emit ImageToolEnd packet with file information
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
@@ -17,30 +16,25 @@ from onyx.chat.chat_utils import create_temporary_persona
|
||||
from onyx.chat.chat_utils import process_kg_commands
|
||||
from onyx.chat.models import AgenticMessageResponseIDInfo
|
||||
from onyx.chat.models import AgentMessageIDInfo
|
||||
from onyx.chat.models import AgentSearchPacket
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerPostInfo
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ChatOnyxBotResponse
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import CustomToolResponse
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FileChatDisplay
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import ChatPacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import (
|
||||
process_streamed_packets,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
@@ -54,22 +48,15 @@ from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.retrieval.search_runner import (
|
||||
inference_sections_from_ids,
|
||||
)
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
from onyx.context.search.utils import drop_llm_indices
|
||||
from onyx.context.search.utils import relevant_sections_to_indices
|
||||
from onyx.db.chat import attach_files_to_chat_message
|
||||
from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import create_search_doc_from_user_file
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_db_search_doc_by_id
|
||||
@@ -77,7 +64,6 @@ from onyx.db.chat import get_doc_query_identifiers_from_model
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.chat import update_chat_session_updated_at_timestamp
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
@@ -88,15 +74,12 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
@@ -107,50 +90,20 @@ from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from onyx.tools.tool_constructor import InternetSearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.models import (
|
||||
InternetSearchResponseSummary,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.utils import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -201,113 +154,6 @@ def _translate_citations(
|
||||
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
|
||||
|
||||
|
||||
def _handle_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
dedupe_docs: bool = False,
|
||||
user_files: list[UserFile] | None = None,
|
||||
loaded_user_files: list[InMemoryChatFile] | None = None,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||
response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
is_extended = isinstance(packet, ExtendedToolResponse)
|
||||
dropped_inds = None
|
||||
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(response_summary.top_sections)
|
||||
|
||||
deduped_docs = top_docs
|
||||
if (
|
||||
dedupe_docs and not is_extended
|
||||
): # Extended tool responses are already deduped
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
doc_ids = {doc.id for doc in reference_db_search_docs}
|
||||
if user_files is not None and loaded_user_files is not None:
|
||||
for user_file in user_files:
|
||||
if user_file.id in doc_ids:
|
||||
continue
|
||||
|
||||
associated_chat_file = next(
|
||||
(
|
||||
file
|
||||
for file in loaded_user_files
|
||||
if file.file_id == str(user_file.file_id)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Use create_search_doc_from_user_file to properly add the document to the database
|
||||
if associated_chat_file is not None:
|
||||
db_doc = create_search_doc_from_user_file(
|
||||
user_file, associated_chat_file, db_session
|
||||
)
|
||||
reference_db_search_docs.append(db_doc)
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
level, question_num = None, None
|
||||
if isinstance(packet, ExtendedToolResponse):
|
||||
level, question_num = packet.level, packet.level_question_num
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=response_summary.rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=response_summary.predicted_flow,
|
||||
predicted_search=response_summary.predicted_search,
|
||||
applied_source_filters=response_summary.final_filters.source_type,
|
||||
applied_time_cutoff=response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=response_summary.recency_bias_multiplier,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
dropped_inds,
|
||||
)
|
||||
|
||||
|
||||
def _handle_internet_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
internet_search_response = cast(InternetSearchResponseSummary, packet.response)
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=internet_search_response.query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.INTERNET,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
tools: list[Tool],
|
||||
@@ -392,136 +238,9 @@ def _get_persona_for_chat_session(
|
||||
return persona
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
| OnyxAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| AgenticMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| AgentSearchPacket
|
||||
| UserKnowledgeFilePacket
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
|
||||
def _process_tool_response(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_db_search_docs: list[DbSearchDoc] | None,
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
|
||||
retrieval_options: RetrievalDetails | None,
|
||||
user_file_files: list[UserFile] | None,
|
||||
user_files: list[InMemoryChatFile] | None,
|
||||
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if isinstance(packet, ExtendedToolResponse)
|
||||
else BASIC_KEY
|
||||
)
|
||||
|
||||
assert level is not None
|
||||
assert level_question_num is not None
|
||||
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
|
||||
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
info.dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=bool(retrieval_options and retrieval_options.dedupe_docs),
|
||||
user_files=[],
|
||||
loaded_user_files=[],
|
||||
)
|
||||
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if info.reference_db_search_docs is None:
|
||||
logger.warning("No reference docs found for relevance filtering")
|
||||
return info_by_subq
|
||||
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in info.reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if info.dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=info.reference_db_search_docs,
|
||||
dropped_indices=info.dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(list[ImageGenerationResponse], packet.response)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data for img in img_generation_response if img.image_data
|
||||
],
|
||||
)
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
response_type = custom_tool_response.response_type
|
||||
if response_type in COMMON_TOOL_RESPONSE_TYPES:
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=file_type)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
return info_by_subq
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
@@ -561,6 +280,7 @@ def stream_chat_message_objects(
|
||||
new_msg_req.chunks_below = 0
|
||||
|
||||
llm: LLM
|
||||
answer: Answer
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
@@ -845,6 +565,18 @@ def stream_chat_message_objects(
|
||||
error: str | None,
|
||||
tool_call: ToolCall | None,
|
||||
) -> ChatMessage:
|
||||
|
||||
is_kg_beta = parent_message.chat_session.persona.name.startswith(
|
||||
TMP_DRALPHA_PERSONA_NAME
|
||||
)
|
||||
is_basic_search = tool_call and tool_call.tool_name == SearchTool._NAME
|
||||
is_agentic_overwrite = new_msg_req.use_agentic_search and not (
|
||||
is_kg_beta and is_basic_search
|
||||
)
|
||||
|
||||
if is_kg_beta:
|
||||
is_agentic_overwrite = False
|
||||
|
||||
return create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=(
|
||||
@@ -867,11 +599,9 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
is_agentic=new_msg_req.use_agentic_search,
|
||||
is_agentic=is_agentic_overwrite,
|
||||
)
|
||||
|
||||
partial_response = create_response
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
@@ -983,7 +713,6 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
|
||||
answer = Answer(
|
||||
prompt_builder=prompt_builder,
|
||||
is_connected=is_connected,
|
||||
@@ -1013,41 +742,10 @@ def stream_chat_message_objects(
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
lambda: AnswerPostInfo(ai_message_files=[])
|
||||
# Process streamed packets using the new packet processing module
|
||||
yield from process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
)
|
||||
refined_answer_improvement = True
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
info_by_subq = yield from _process_tool_response(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_db_search_docs=selected_db_search_docs,
|
||||
info_by_subq=info_by_subq,
|
||||
retrieval_options=retrieval_options,
|
||||
user_file_files=user_file_models,
|
||||
user_files=in_memory_user_files,
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason == StreamStopReason.FINISHED:
|
||||
yield packet
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
refined_answer_improvement = packet.refined_answer_improvement
|
||||
yield packet
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if packet.level is not None
|
||||
and packet.level_question_num is not None
|
||||
else BASIC_KEY
|
||||
)
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
info.tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
@@ -1083,17 +781,6 @@ def stream_chat_message_objects(
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
yield from _post_llm_answer_processing(
|
||||
answer=answer,
|
||||
info_by_subq=info_by_subq,
|
||||
tool_dict=tool_dict,
|
||||
partial_response=partial_response,
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
)
|
||||
|
||||
|
||||
def _post_llm_answer_processing(
|
||||
answer: Answer,
|
||||
@@ -1103,7 +790,6 @@ def _post_llm_answer_processing(
|
||||
llm_tokenizer_encode_func: Callable[[str], list[int]],
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
refined_answer_improvement: bool | None,
|
||||
) -> Generator[ChatPacket, None, None]:
|
||||
"""
|
||||
Stores messages in the db and yields some final packets to the frontend
|
||||
@@ -1115,20 +801,6 @@ def _post_llm_answer_processing(
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
subq_citations = answer.citations_by_subquestion()
|
||||
for subq_key in subq_citations:
|
||||
info = info_by_subq[subq_key]
|
||||
logger.debug("Post-LLM answer processing")
|
||||
if info.reference_db_search_docs:
|
||||
info.message_specific_citations = _translate_citations(
|
||||
citations_list=subq_citations[subq_key],
|
||||
db_docs=info.reference_db_search_docs,
|
||||
)
|
||||
|
||||
# TODO: AllCitations should contain subq info?
|
||||
if not answer.is_cancelled():
|
||||
yield AllCitations(citations=subq_citations[subq_key])
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
|
||||
basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
||||
@@ -1144,9 +816,7 @@ def _post_llm_answer_processing(
|
||||
)
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
info.qa_docs_response.rephrased_query if info.qa_docs_response else None
|
||||
),
|
||||
rephrased_query=info.rephrased_query,
|
||||
reference_docs=info.reference_db_search_docs,
|
||||
files=info.ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
@@ -1205,7 +875,6 @@ def _post_llm_answer_processing(
|
||||
else None
|
||||
),
|
||||
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
is_agentic=True,
|
||||
)
|
||||
agentic_message_ids.append(
|
||||
|
||||
@@ -3,12 +3,12 @@ from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from onyx.configs.chat_configs import STOP_STREAM_PAT
|
||||
from onyx.prompts.constants import TRIPLE_BACKTICK
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -172,3 +172,151 @@ class CitationProcessor:
|
||||
)
|
||||
|
||||
return final_processed_str, final_citation_info
|
||||
|
||||
|
||||
class CitationProcessorGraph:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs # list of docs in the order the LLM sees
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.stop_stream = stop_stream
|
||||
|
||||
self.llm_out = "" # entire output so far
|
||||
self.curr_segment = "" # tokens held for citation processing
|
||||
self.hold = "" # tokens held for stop token processing
|
||||
|
||||
self.recent_cited_documents: set[str] = set() # docs recently cited
|
||||
self.cited_documents: set[str] = set() # docs cited in the entire stream
|
||||
self.non_citation_count = 0
|
||||
|
||||
# '[', '[[', '[1', '[[1', '[1,', '[1, ', '[1,2', '[1, 2,', etc.
|
||||
self.possible_citation_pattern = re.compile(r"(\[+(?:\d+,? ?)*$)")
|
||||
|
||||
# group 1: '[[1]]', [[2]], etc.
|
||||
# group 2: '[1]', '[1, 2]', '[1,2,16]', etc.
|
||||
self.citation_pattern = re.compile(r"(\[\[\d+\]\])|(\[\d+(?:, ?\d+)*\])")
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> str | tuple[str, list[CitationInfo]] | None:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
if self.stop_stream:
|
||||
next_hold = self.hold + token
|
||||
if self.stop_stream in next_hold:
|
||||
return None
|
||||
if next_hold == self.stop_stream[: len(next_hold)]:
|
||||
self.hold = next_hold
|
||||
return None
|
||||
token = next_hold
|
||||
self.hold = ""
|
||||
|
||||
self.curr_segment += token
|
||||
self.llm_out += token
|
||||
|
||||
# Handle code blocks without language tags
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
pass
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_matches = list(self.citation_pattern.finditer(self.curr_segment))
|
||||
possible_citation_found = bool(
|
||||
re.search(self.possible_citation_pattern, self.curr_segment)
|
||||
)
|
||||
|
||||
result = ""
|
||||
if citation_matches and not in_code_block(self.llm_out):
|
||||
match_idx = 0
|
||||
citation_infos = []
|
||||
for match in citation_matches:
|
||||
match_span = match.span()
|
||||
|
||||
# add stuff before/between the matches
|
||||
intermatch_str = self.curr_segment[match_idx : match_span[0]]
|
||||
self.non_citation_count += len(intermatch_str)
|
||||
match_idx = match_span[1]
|
||||
result += intermatch_str
|
||||
|
||||
# reset recent citations if no citations found for a while
|
||||
if self.non_citation_count > 5:
|
||||
self.recent_cited_documents.clear()
|
||||
|
||||
# process the citation string and emit citation info
|
||||
res, citation_info = self.process_citation(match)
|
||||
result += res
|
||||
citation_infos.extend(citation_info)
|
||||
self.non_citation_count = 0
|
||||
|
||||
# leftover could be part of next citation
|
||||
self.curr_segment = self.curr_segment[match_idx:]
|
||||
self.non_citation_count = len(self.curr_segment)
|
||||
|
||||
return result, citation_infos
|
||||
|
||||
# hold onto the current segment if potential citations found, otherwise stream
|
||||
if not possible_citation_found:
|
||||
result += self.curr_segment
|
||||
self.non_citation_count += len(self.curr_segment)
|
||||
self.curr_segment = ""
|
||||
|
||||
if result:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def process_citation(self, match: re.Match) -> tuple[str, list[CitationInfo]]:
|
||||
"""
|
||||
Process a single citation match and return the citation string and the
|
||||
citation info. The match string can look like '[1]', '[1, 13, 6], '[[4]]', etc.
|
||||
"""
|
||||
citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', etc.
|
||||
formatted = match.lastindex == 1 # True means already in the form '[[1]]'
|
||||
|
||||
final_processed_str = ""
|
||||
final_citation_info: list[CitationInfo] = []
|
||||
|
||||
# process the citation_str
|
||||
citation_content = citation_str[2:-2] if formatted else citation_str[1:-1]
|
||||
for num in (int(num) for num in citation_content.split(",")):
|
||||
# keep invalid citations as is
|
||||
if not (1 <= num <= self.max_citation_num):
|
||||
final_processed_str += f"[[{num}]]" if formatted else f"[{num}]"
|
||||
continue
|
||||
|
||||
# translate the citation number of the LLM to what the user sees
|
||||
# should always be in the display_doc_order_dict. But check anyways
|
||||
context_llm_doc = self.context_docs[num - 1]
|
||||
llm_docid = context_llm_doc.document_id
|
||||
|
||||
# skip citations of the same work if cited recently
|
||||
if llm_docid in self.recent_cited_documents:
|
||||
continue
|
||||
self.recent_cited_documents.add(llm_docid)
|
||||
|
||||
# format the citation string
|
||||
# if formatted:
|
||||
# final_processed_str += f"[[{num}]]({link})"
|
||||
# else:
|
||||
link = context_llm_doc.link or ""
|
||||
final_processed_str += f"[[{num}]]({link})"
|
||||
|
||||
# create the citation info
|
||||
if llm_docid not in self.cited_documents:
|
||||
self.cited_documents.add(llm_docid)
|
||||
final_citation_info.append(
|
||||
CitationInfo(
|
||||
citation_num=num,
|
||||
document_id=llm_docid,
|
||||
)
|
||||
)
|
||||
|
||||
return final_processed_str, final_citation_info
|
||||
|
||||
@@ -3,6 +3,7 @@ import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -138,6 +139,8 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Onyx APIs without specifying a source type
|
||||
@@ -512,3 +515,57 @@ else:
|
||||
class OnyxCallTypes(str, Enum):
|
||||
FIREFLIES = "FIREFLIES"
|
||||
GONG = "GONG"
|
||||
|
||||
|
||||
# TODO: this should be stored likely in database
|
||||
DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
# Special case, document passed in via Onyx APIs without specifying a source type
|
||||
DocumentSource.INGESTION_API: "ingestion_api",
|
||||
DocumentSource.SLACK: "slack channels",
|
||||
DocumentSource.WEB: "web pages",
|
||||
DocumentSource.GOOGLE_DRIVE: "google drive documents (docs, sheets, etc.)",
|
||||
DocumentSource.GMAIL: "email messages",
|
||||
DocumentSource.REQUESTTRACKER: "requesttracker",
|
||||
DocumentSource.GITHUB: "github data",
|
||||
DocumentSource.GITBOOK: "gitbook data",
|
||||
DocumentSource.GITLAB: "gitlab data",
|
||||
DocumentSource.GURU: "guru data",
|
||||
DocumentSource.BOOKSTACK: "bookstack data",
|
||||
DocumentSource.CONFLUENCE: "confluence data (pages, spaces, etc.)",
|
||||
DocumentSource.JIRA: "jira data (issues, tickets, projects, etc.)",
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
DocumentSource.ZULIP: "zulip data",
|
||||
DocumentSource.LINEAR: "linear data - project management tool, including tickets etc.",
|
||||
DocumentSource.HUBSPOT: "hubspot data - CRM and marketing automation data",
|
||||
DocumentSource.DOCUMENT360: "document360 data",
|
||||
DocumentSource.GONG: "gong - call transcripts",
|
||||
DocumentSource.GOOGLE_SITES: "google_sites - websites",
|
||||
DocumentSource.ZENDESK: "zendesk - customer support data",
|
||||
DocumentSource.LOOPIO: "loopio - rfp data",
|
||||
DocumentSource.DROPBOX: "dropbox - files",
|
||||
DocumentSource.SHAREPOINT: "sharepoint - files",
|
||||
DocumentSource.TEAMS: "teams - chat and collaboration",
|
||||
DocumentSource.SALESFORCE: "salesforce - CRM data",
|
||||
DocumentSource.DISCOURSE: "discourse - discussion forums",
|
||||
DocumentSource.AXERO: "axero - employee engagement data",
|
||||
DocumentSource.CLICKUP: "clickup - project management tool",
|
||||
DocumentSource.MEDIAWIKI: "mediawiki - wiki data",
|
||||
DocumentSource.WIKIPEDIA: "wikipedia - encyclopedia data",
|
||||
DocumentSource.ASANA: "asana",
|
||||
DocumentSource.S3: "s3",
|
||||
DocumentSource.R2: "r2",
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: "google_cloud_storage - cloud storage",
|
||||
DocumentSource.OCI_STORAGE: "oci_storage - cloud storage",
|
||||
DocumentSource.XENFORO: "xenforo - forum data",
|
||||
DocumentSource.DISCORD: "discord - chat and collaboration",
|
||||
DocumentSource.FRESHDESK: "freshdesk - customer support data",
|
||||
DocumentSource.FIREFLIES: "fireflies - call transcripts",
|
||||
DocumentSource.EGNYTE: "egnyte - files",
|
||||
DocumentSource.AIRTABLE: "airtable - database",
|
||||
DocumentSource.HIGHSPOT: "highspot - CRM data",
|
||||
DocumentSource.IMAP: "imap - email data",
|
||||
}
|
||||
|
||||
@@ -140,3 +140,5 @@ KG_MAX_SEARCH_DOCUMENTS: int = int(os.environ.get("KG_MAX_SEARCH_DOCUMENTS", "15
|
||||
KG_MAX_DECOMPOSITION_SEGMENTS: int = int(
|
||||
os.environ.get("KG_MAX_DECOMPOSITION_SEGMENTS", "10")
|
||||
)
|
||||
KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
|
||||
to answer questions"
|
||||
|
||||
@@ -378,6 +378,11 @@ class SavedSearchDoc(SearchDoc):
|
||||
search_doc_data["score"] = search_doc_data.get("score") or 0.0
|
||||
return cls(**search_doc_data, db_doc_id=db_doc_id)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SavedSearchDoc":
|
||||
"""Create SavedSearchDoc from serialized dictionary data (e.g., from database JSON)"""
|
||||
return cls(**data)
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, SavedSearchDoc):
|
||||
return NotImplemented
|
||||
|
||||
@@ -19,10 +19,12 @@ from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_citation_format_list
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
@@ -41,12 +43,14 @@ from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_best_persona_id_for_user
|
||||
from onyx.db.tools import get_tool_by_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
@@ -55,12 +59,211 @@ from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import SubQueryDetail
|
||||
from onyx.server.query_and_chat.models import SubQuestionDetail
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import EndStepPacketList
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
|
||||
|
||||
|
||||
def create_message_packets(
|
||||
message_text: str, final_documents: list[SavedSearchDoc] | None, step_nr: int
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=MessageStart(
|
||||
content="",
|
||||
final_documents=final_documents,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=MessageDelta(
|
||||
type="message_delta",
|
||||
content=message_text,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_citation_packets(
|
||||
citation_info_list: list[CitationInfo], step_nr: int
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=CitationStart(
|
||||
type="citation_start",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=CitationDelta(
|
||||
type="citation_delta",
|
||||
citations=citation_info_list,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_reasoning_packets(reasoning_text: str, step_nr: int) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=ReasoningStart(
|
||||
type="reasoning_start",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=ReasoningDelta(
|
||||
type="reasoning_delta",
|
||||
reasoning=reasoning_text,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_image_generation_packets(
|
||||
images: list[dict[str, str]] | None, step_nr: int
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=ImageGenerationToolStart(type="image_generation_tool_start"),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=ImageGenerationToolDelta(
|
||||
type="image_generation_tool_delta", images=images
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=step_nr,
|
||||
obj=SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_search_packets(
    search_queries: list[str],
    saved_search_docs: list[SavedSearchDoc] | None,
    is_internet_search: bool,
    step_nr: int,
) -> list[Packet]:
    """Build the start/delta/end packet triplet for a search-tool section.

    Queries and documents (possibly None) are emitted in one delta; the
    is_internet_search flag only tags the start packet. Every packet shares
    step_nr.
    """
    return [
        Packet(
            ind=step_nr,
            obj=SearchToolStart(
                type="internal_search_tool_start",
                is_internet_search=is_internet_search,
            ),
        ),
        Packet(
            ind=step_nr,
            obj=SearchToolDelta(
                type="internal_search_tool_delta",
                queries=search_queries,
                documents=saved_search_docs,
            ),
        ),
        Packet(
            ind=step_nr,
            obj=SectionEnd(type="section_end"),
        ),
    ]
|
||||
|
||||
|
||||
def get_chat_session_by_id(
|
||||
chat_session_id: UUID,
|
||||
@@ -550,11 +753,23 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
# stmt = stmt.options(
|
||||
# joinedload(ChatMessage.tool_call),
|
||||
# joinedload(ChatMessage.sub_questions).joinedload(
|
||||
# AgentSubQuestion.sub_queries
|
||||
# ),
|
||||
# )
|
||||
# result = db_session.scalars(stmt).unique().all()
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.chat_session_id == chat_session_id)
|
||||
.order_by(nullsfirst(ChatMessage.parent_message))
|
||||
)
|
||||
stmt = stmt.options(
|
||||
joinedload(ChatMessage.tool_call),
|
||||
joinedload(ChatMessage.sub_questions).joinedload(
|
||||
AgentSubQuestion.sub_queries
|
||||
),
|
||||
joinedload(ChatMessage.research_iterations).joinedload(
|
||||
ResearchAgentIteration.sub_steps
|
||||
)
|
||||
)
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
@@ -645,8 +860,9 @@ def create_new_chat_message(
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
refined_answer_improvement: bool | None = None,
|
||||
is_agentic: bool = False,
|
||||
research_type: ResearchType | None = None,
|
||||
research_plan: dict[str, Any] | None = None,
|
||||
) -> ChatMessage:
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
@@ -667,8 +883,9 @@ def create_new_chat_message(
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
existing_message.refined_answer_improvement = refined_answer_improvement
|
||||
existing_message.is_agentic = is_agentic
|
||||
existing_message.research_type = research_type
|
||||
existing_message.research_plan = research_plan
|
||||
new_chat_message = existing_message
|
||||
else:
|
||||
# Create new message
|
||||
@@ -687,8 +904,9 @@ def create_new_chat_message(
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
is_agentic=is_agentic,
|
||||
research_type=research_type,
|
||||
research_plan=research_plan,
|
||||
)
|
||||
db_session.add(new_chat_message)
|
||||
|
||||
@@ -1032,6 +1250,160 @@ def get_retrieval_docs_from_search_docs(
|
||||
return RetrievalDocs(top_documents=top_documents)
|
||||
|
||||
|
||||
def translate_db_message_to_packets(
    chat_message: ChatMessage,
    db_session: Session,
    remove_doc_content: bool = False,
    start_step_nr: int = 1,
) -> EndStepPacketList:
    """Rebuild the streamed Packet sequence for a persisted chat message.

    For assistant messages this replays, in order: reasoning/tool sections for
    each research iteration (THOUGHTFUL/DEEP research types only), the final
    message packets, the citation packets, and an OverallStop. Non-assistant
    messages produce an empty packet list.

    NOTE(review): remove_doc_content is accepted but never referenced in this
    body - confirm whether it is still needed.

    Returns an EndStepPacketList with the packets and the last step number used
    (end_step_nr), so callers can continue numbering from there.
    """

    step_nr = start_step_nr
    packet_list: list[Packet] = []

    # only stream out packets for assistant messages
    if chat_message.message_type == MessageType.ASSISTANT:

        citations = chat_message.citations

        # Get document IDs from SearchDoc table using citation mapping
        citation_info_list: list[CitationInfo] = []
        if citations:
            for citation_num, search_doc_id in citations.items():
                search_doc = get_db_search_doc_by_id(search_doc_id, db_session)
                if search_doc:
                    citation_info_list.append(
                        CitationInfo(
                            citation_num=citation_num,
                            document_id=search_doc.document_id,
                        )
                    )

        if chat_message.research_type in [ResearchType.THOUGHTFUL, ResearchType.DEEP]:
            # Replay iterations in iteration_nr order.
            research_iterations = sorted(
                chat_message.research_iterations, key=lambda x: x.iteration_nr
            )  # sorted iterations
            for research_iteration in research_iterations:

                if research_iteration.iteration_nr > 1:
                    # first iteration does not need to be reasoned for
                    packet_list.extend(
                        create_reasoning_packets(research_iteration.reasoning, step_nr)
                    )
                    step_nr += 1

                if research_iteration.purpose:
                    # The iteration's purpose is shown as a reasoning section.
                    packet_list.extend(
                        create_reasoning_packets(research_iteration.purpose, step_nr)
                    )
                    step_nr += 1

                sub_steps = research_iteration.sub_steps
                tasks = []
                tool_call_ids = []
                cited_docs: list[SavedSearchDoc] = []

                for sub_step in sub_steps:

                    tasks.append(sub_step.sub_step_instructions)
                    tool_call_ids.append(sub_step.sub_step_tool_id)

                    sub_step_cited_docs = sub_step.cited_doc_results
                    if isinstance(sub_step_cited_docs, list):
                        # Convert serialized dict data back to SavedSearchDoc objects
                        saved_search_docs = [
                            (
                                SavedSearchDoc.from_dict(doc_data)
                                if isinstance(doc_data, dict)
                                else doc_data
                            )
                            for doc_data in sub_step_cited_docs
                        ]
                        cited_docs.extend(saved_search_docs)
                    else:
                        # Non-list results cannot be rendered; show a placeholder.
                        packet_list.extend(
                            create_reasoning_packets(
                                _CANNOT_SHOW_STEP_RESULTS_STR, step_nr
                            )
                        )
                        step_nr += 1

                if len(set(tool_call_ids)) > 1:
                    # Mixed tools within one iteration cannot be rendered as a
                    # single tool section; show a placeholder instead.
                    packet_list.extend(
                        create_reasoning_packets(_CANNOT_SHOW_STEP_RESULTS_STR, step_nr)
                    )
                    step_nr += 1

                elif (
                    len(sub_steps) == 0
                ):  # no sub steps, no tool calls. But iteration can have reasoning or purpose
                    continue

                else:
                    # TODO: replace with isinstance, resolving circular imports
                    tool_id = tool_call_ids[0]
                    tool = get_tool_by_id(tool_id, db_session)
                    tool_name = tool.name

                    if tool_name in ["SearchTool", "KnowledgeGraphTool"]:

                        cited_docs = cast(list[SavedSearchDoc], cited_docs)

                        packet_list.extend(
                            create_search_packets(tasks, cited_docs, False, step_nr)
                        )
                        step_nr += 1

                    elif tool_name == "InternetSearchTool":
                        cited_docs = cast(list[SavedSearchDoc], cited_docs)
                        packet_list.extend(
                            create_search_packets(tasks, cited_docs, True, step_nr)
                        )
                        step_nr += 1

                    elif tool_name == "ImageGenerationTool":

                        if len(tasks) > 1:
                            # Multiple image-generation sub-steps cannot be
                            # rendered as one section; show a placeholder.
                            packet_list.extend(
                                create_reasoning_packets(
                                    _CANNOT_SHOW_STEP_RESULTS_STR, step_nr
                                )
                            )
                            step_nr += 1

                        else:
                            # NOTE(review): cited_docs is typed list[SavedSearchDoc]
                            # but cited_docs[0] is passed where
                            # create_image_generation_packets expects
                            # list[dict[str, str]] | None - confirm this is intended.
                            images = cited_docs[0]
                            packet_list.extend(
                                create_image_generation_packets(images, step_nr)
                            )
                            step_nr += 1

                    else:
                        raise ValueError(f"Unknown tool name: {tool_name}")

        packet_list.extend(
            create_message_packets(
                message_text=chat_message.message,
                final_documents=[
                    translate_db_search_doc_to_server_search_doc(doc)
                    for doc in chat_message.search_docs
                ],
                step_nr=step_nr,
            )
        )
        step_nr += 1

        packet_list.extend(create_citation_packets(citation_info_list, step_nr))

        step_nr += 1

        packet_list.append(Packet(ind=step_nr, obj=OverallStop()))

    return EndStepPacketList(
        end_step_nr=step_nr,
        packet_list=packet_list,
    )
|
||||
|
||||
|
||||
def translate_db_message_to_chat_message_detail(
|
||||
chat_message: ChatMessage,
|
||||
remove_doc_content: bool = False,
|
||||
@@ -1061,11 +1433,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
),
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
sub_questions=translate_db_sub_questions_to_server_objects(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
is_agentic=chat_message.is_agentic,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
@@ -1111,27 +1478,6 @@ def log_agent_sub_question_results(
|
||||
primary_message_id: int | None,
|
||||
sub_question_answer_results: list[SubQuestionAnswerResults],
|
||||
) -> None:
|
||||
def _create_citation_format_list(
|
||||
document_citations: list[InferenceSection],
|
||||
) -> list[dict[str, Any]]:
|
||||
citation_list: list[dict[str, Any]] = []
|
||||
for document_citation in document_citations:
|
||||
document_citation_dict = {
|
||||
"link": "",
|
||||
"blurb": document_citation.center_chunk.blurb,
|
||||
"content": document_citation.center_chunk.content,
|
||||
"metadata": document_citation.center_chunk.metadata,
|
||||
"updated_at": str(document_citation.center_chunk.updated_at),
|
||||
"document_id": document_citation.center_chunk.document_id,
|
||||
"source_type": "file",
|
||||
"source_links": document_citation.center_chunk.source_links,
|
||||
"match_highlights": document_citation.center_chunk.match_highlights,
|
||||
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
|
||||
}
|
||||
|
||||
citation_list.append(document_citation_dict)
|
||||
|
||||
return citation_list
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
@@ -1141,7 +1487,7 @@ def log_agent_sub_question_results(
|
||||
]
|
||||
sub_question = sub_question_answer_result.question
|
||||
sub_answer = sub_question_answer_result.answer
|
||||
sub_document_results = _create_citation_format_list(
|
||||
sub_document_results = create_citation_format_list(
|
||||
sub_question_answer_result.context_documents
|
||||
)
|
||||
|
||||
@@ -1198,3 +1544,58 @@ def update_chat_session_updated_at_timestamp(
|
||||
.values(time_updated=func.now())
|
||||
)
|
||||
# No commit - the caller is responsible for committing the transaction
|
||||
|
||||
|
||||
def create_search_doc_from_inference_section(
|
||||
inference_section: InferenceSection,
|
||||
is_internet: bool,
|
||||
db_session: Session,
|
||||
score: float = 0.0,
|
||||
is_relevant: bool | None = None,
|
||||
relevance_explanation: str | None = None,
|
||||
commit: bool = False,
|
||||
) -> SearchDoc:
|
||||
"""Create a SearchDoc in the database from an InferenceSection."""
|
||||
|
||||
db_search_doc = SearchDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
chunk_ind=inference_section.center_chunk.chunk_id,
|
||||
semantic_id=inference_section.center_chunk.semantic_identifier,
|
||||
link=(
|
||||
inference_section.center_chunk.source_links.get(0)
|
||||
if inference_section.center_chunk.source_links
|
||||
else None
|
||||
),
|
||||
blurb=inference_section.center_chunk.blurb,
|
||||
source_type=inference_section.center_chunk.source_type,
|
||||
boost=inference_section.center_chunk.boost,
|
||||
hidden=inference_section.center_chunk.hidden,
|
||||
doc_metadata=inference_section.center_chunk.metadata,
|
||||
score=score,
|
||||
is_relevant=is_relevant,
|
||||
relevance_explanation=relevance_explanation,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
updated_at=inference_section.center_chunk.updated_at,
|
||||
primary_owners=inference_section.center_chunk.primary_owners or [],
|
||||
secondary_owners=inference_section.center_chunk.secondary_owners or [],
|
||||
is_internet=is_internet,
|
||||
)
|
||||
|
||||
db_session.add(db_search_doc)
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return db_search_doc
|
||||
|
||||
|
||||
def create_search_doc_from_saved_search_doc(
|
||||
saved_search_doc: SavedSearchDoc,
|
||||
) -> SearchDoc:
|
||||
"""Convert SavedSearchDoc to SearchDoc by excluding the additional fields"""
|
||||
data = saved_search_doc.model_dump()
|
||||
# Remove the fields that are specific to SavedSearchDoc
|
||||
data.pop("db_doc_id", None)
|
||||
# Keep score since SearchDoc has it as an optional field
|
||||
return SearchDoc(**data)
|
||||
|
||||
@@ -82,6 +82,8 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -677,8 +679,8 @@ class KGEntityType(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
grounded_source_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
grounded_source_name: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=False
|
||||
)
|
||||
|
||||
entity_values: Mapped[list[str]] = mapped_column(
|
||||
@@ -2139,12 +2141,26 @@ class ChatMessage(Base):
|
||||
order_by="(AgentSubQuestion.level, AgentSubQuestion.level_question_num)",
|
||||
)
|
||||
|
||||
research_iterations: Mapped[list["ResearchAgentIteration"]] = relationship(
|
||||
"ResearchAgentIteration",
|
||||
foreign_keys="ResearchAgentIteration.primary_question_id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
|
||||
research_type: Mapped[ResearchType] = mapped_column(
|
||||
Enum(ResearchType, native_enum=False), nullable=True
|
||||
)
|
||||
research_plan: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
research_answer_purpose: Mapped[ResearchAnswerPurpose] = mapped_column(
|
||||
Enum(ResearchAnswerPurpose, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
"""For organizing chat sessions"""
|
||||
@@ -3343,3 +3359,71 @@ class TenantAnonymousUserPath(Base):
|
||||
anonymous_user_path: Mapped[str] = mapped_column(
|
||||
String, nullable=False, unique=True
|
||||
)
|
||||
|
||||
|
||||
class ResearchAgentIteration(Base):
    """One iteration of a research agent run attached to a chat message.

    Rows are keyed to the primary chat message via primary_question_id and
    ordered by iteration_nr; sub_steps holds the per-iteration tool work.
    """

    __tablename__ = "research_agent_iteration"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Chat message this iteration belongs to; deleting the message cascades
    # to its iterations (ondelete="CASCADE").
    primary_question_id: Mapped[int] = mapped_column(
        ForeignKey("chat_message.id", ondelete="CASCADE")
    )
    # Position of this iteration within the research run
    iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
    created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
    # Free-text purpose of the iteration (nullable)
    purpose: Mapped[str] = mapped_column(String, nullable=True)

    # Free-text reasoning produced during the iteration (nullable)
    reasoning: Mapped[str] = mapped_column(String, nullable=True)

    # Relationships
    primary_message: Mapped["ChatMessage"] = relationship(
        "ChatMessage",
        foreign_keys=[primary_question_id],
        back_populates="research_iterations",
    )

    # Sub-steps are joined on the composite (primary_question_id, iteration_nr)
    # pair rather than on this table's surrogate id.
    sub_steps: Mapped[list["ResearchAgentIterationSubStep"]] = relationship(
        "ResearchAgentIterationSubStep",
        primaryjoin=(
            "and_("
            "ResearchAgentIteration.primary_question_id == ResearchAgentIterationSubStep.primary_question_id, "
            "ResearchAgentIteration.iteration_nr == ResearchAgentIterationSubStep.iteration_nr"
            ")"
        ),
        foreign_keys="[ResearchAgentIterationSubStep.primary_question_id, ResearchAgentIterationSubStep.iteration_nr]",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class ResearchAgentIterationSubStep(Base):
    """A single sub-step (e.g. one tool call) within a research agent iteration.

    Linked to its iteration via the (primary_question_id, iteration_nr) pair and
    ordered within the iteration by iteration_sub_step_nr. A sub-step may also
    reference a parent sub-step, forming a tree.
    """

    __tablename__ = "research_agent_iteration_sub_step"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Chat message this sub-step belongs to; cascades on message deletion.
    primary_question_id: Mapped[int] = mapped_column(
        ForeignKey("chat_message.id", ondelete="CASCADE")
    )
    # Optional self-referential parent sub-step (nullable)
    parent_question_id: Mapped[int | None] = mapped_column(
        ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
        nullable=True,
    )
    # Which iteration this sub-step belongs to
    iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
    # Order of the sub-step within its iteration
    iteration_sub_step_nr: Mapped[int] = mapped_column(Integer, nullable=False)
    created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
    # Instructions given for this sub-step (nullable)
    sub_step_instructions: Mapped[str] = mapped_column(String, nullable=True)
    # Tool used for this sub-step, if any (nullable FK to tool.id)
    sub_step_tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), nullable=True)
    reasoning: Mapped[str] = mapped_column(String, nullable=True)
    sub_answer: Mapped[str] = mapped_column(String, nullable=True)
    # Serialized cited document results (JSONB); consumers deserialize back
    # into SavedSearchDoc objects.
    cited_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
    claims: Mapped[list[str]] = mapped_column(postgresql.JSONB(), nullable=True)
    additional_data: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)

    # Relationships
    primary_message: Mapped["ChatMessage"] = relationship(
        "ChatMessage",
        foreign_keys=[primary_question_id],
    )

    # Self-referential link to the parent sub-step (remote side is the parent id)
    parent_sub_step: Mapped["ResearchAgentIterationSubStep"] = relationship(
        "ResearchAgentIterationSubStep",
        foreign_keys=[parent_question_id],
        remote_side="ResearchAgentIterationSubStep.id",
    )
|
||||
|
||||
@@ -16,7 +16,8 @@ from onyx.db.models import User
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.tools.built_in_tools import get_search_tool
|
||||
from onyx.tools.built_in_tools import get_builtin_tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.errors import EERequiredError
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
@@ -49,9 +50,7 @@ def create_slack_channel_persona(
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
search_tool = get_search_tool(db_session)
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool not found")
|
||||
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
|
||||
|
||||
# create/update persona associated with the Slack channel
|
||||
persona_name = _build_persona_name(channel_name)
|
||||
|
||||
@@ -47,15 +47,15 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_classification_extraction_instructions() -> (
|
||||
dict[str, dict[str, KGEntityTypeInstructions]]
|
||||
dict[str | None, dict[str, KGEntityTypeInstructions]]
|
||||
):
|
||||
"""
|
||||
Prepare the classification instructions for the given source.
|
||||
"""
|
||||
|
||||
classification_instructions_dict: dict[str, dict[str, KGEntityTypeInstructions]] = (
|
||||
{}
|
||||
)
|
||||
classification_instructions_dict: dict[
|
||||
str | None, dict[str, KGEntityTypeInstructions]
|
||||
] = {}
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entity_types = get_entity_types(db_session, active=True)
|
||||
|
||||
@@ -32,9 +32,7 @@ def format_entity_id_for_models(entity_id_name: str) -> str:
|
||||
separator = entity_type = ""
|
||||
|
||||
formatted_entity_type = entity_type.strip().upper()
|
||||
formatted_entity_name = (
|
||||
entity_name.strip().replace('"', "").replace("'", "").title()
|
||||
)
|
||||
formatted_entity_name = entity_name.strip().replace('"', "").replace("'", "")
|
||||
|
||||
return f"{formatted_entity_type}{separator}{formatted_entity_name}"
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
@@ -25,6 +26,7 @@ class PreviousMessage(BaseModel):
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
refined_answer_improvement: bool | None
|
||||
research_answer_purpose: ResearchAnswerPurpose | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -52,6 +54,7 @@ class PreviousMessage(BaseModel):
|
||||
else None
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
research_answer_purpose=chat_message.research_answer_purpose,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -9,7 +9,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.webhook import WebhookClient
|
||||
|
||||
from onyx.chat.models import ChatOnyxBotResponse
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
@@ -50,6 +49,7 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import TenantSocketModeClient
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
|
||||
1249
backend/onyx/prompts/dr_prompts.py
Normal file
1249
backend/onyx/prompts/dr_prompts.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -669,8 +669,8 @@ that should be used to analyze each object/each source (or 'the object' that fit
|
||||
}}
|
||||
|
||||
Do not include any other text or explanations.
|
||||
|
||||
"""
|
||||
|
||||
SOURCE_DETECTION_PROMPT = f"""
|
||||
You are an expert in generating, understanding and analyzing SQL statements.
|
||||
|
||||
@@ -773,11 +773,29 @@ Please structure your answer using <reasoning>, </reasoning>,<sql>, </sql> start
|
||||
""".strip()
|
||||
|
||||
|
||||
SIMPLE_SQL_PROMPT = f"""
|
||||
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
|
||||
between TWO ENTITIES. The table has the following structure:
|
||||
ENTITY_TABLE_DESCRIPTION = f"""\
|
||||
- Table name: entity_table
|
||||
- Columns:
|
||||
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
|
||||
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
|
||||
- entity_type (str): the type of the entity [example: ACCOUNT].
|
||||
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
|
||||
- source_document (str): the id of the document that contains the entity. Note that the combination of \
|
||||
id_name and source_document IS UNIQUE!
|
||||
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
|
||||
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
|
||||
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
|
||||
the entity type may also often be referred to.
|
||||
{SEPARATOR_LINE}
|
||||
---entity_types---
|
||||
{SEPARATOR_LINE}
|
||||
"""
|
||||
|
||||
RELATIONSHIP_TABLE_DESCRIPTION = f"""\
|
||||
- Table name: relationship_table
|
||||
- Columns:
|
||||
- relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \
|
||||
@@ -803,17 +821,27 @@ id_name and source_document IS UNIQUE!
|
||||
|
||||
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
|
||||
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
|
||||
their values, if provided.
|
||||
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
|
||||
the entity type may also often be referred to.
|
||||
{SEPARATOR_LINE}
|
||||
---entity_types---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>:
|
||||
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>.
|
||||
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by '::<entity_name>' \
|
||||
in the relationship id as shown above.
|
||||
{SEPARATOR_LINE}
|
||||
---relationship_types---
|
||||
{SEPARATOR_LINE}
|
||||
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by ':<entity_name>' in the \
|
||||
relationship id as shown above..
|
||||
"""
|
||||
|
||||
|
||||
SIMPLE_SQL_PROMPT = f"""
|
||||
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
|
||||
between TWO ENTITIES. The table has the following structure:
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
{RELATIONSHIP_TABLE_DESCRIPTION}
|
||||
|
||||
Here is the question you are supposed to translate into a SQL statement:
|
||||
{SEPARATOR_LINE}
|
||||
@@ -936,7 +964,7 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
|
||||
<sql>[the SQL statement that you generate to satisfy the task]</sql>
|
||||
""".strip()
|
||||
|
||||
|
||||
# TODO: remove following before merging after enough testing
|
||||
SIMPLE_SQL_CORRECTION_PROMPT = f"""
|
||||
You are an expert in reviewing and fixing SQL statements.
|
||||
|
||||
@@ -949,7 +977,7 @@ Guidance:
|
||||
SELECT statement as well! And it needs to be in the EXACT FORM! So if a \
|
||||
conversion took place, make sure to include the conversion in the SELECT and the ORDER BY clause!
|
||||
- never should 'source_document' be in the SELECT clause! Remove if present!
|
||||
- if there are joins, they must be on entities, never sour ce documents
|
||||
- if there are joins, they must be on entities, never source documents
|
||||
- if there are joins, consider the possibility that the second entity does not exist for all examples.\
|
||||
Therefore consider using LEFT joins (or RIGHT joins) as appropriate.
|
||||
|
||||
@@ -969,26 +997,7 @@ You are an expert in generating a SQL statement that only uses ONE TABLE that ca
|
||||
and their attributes and other data. The table has the following structure:
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
- Table name: entity_table
|
||||
- Columns:
|
||||
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
|
||||
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
|
||||
- entity_type (str): the type of the entity [example: ACCOUNT].
|
||||
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
|
||||
- source_document (str): the id of the document that contains the entity. Note that the combination of \
|
||||
id_name and source_document IS UNIQUE!
|
||||
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
|
||||
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
|
||||
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
|
||||
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
|
||||
the entity type may also often be referred to.
|
||||
{SEPARATOR_LINE}
|
||||
---entity_types---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
{ENTITY_TABLE_DESCRIPTION}
|
||||
|
||||
Here is the question you are supposed to translate into a SQL statement:
|
||||
{SEPARATOR_LINE}
|
||||
@@ -1077,33 +1086,55 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
|
||||
<sql>[the SQL statement that you generate to satisfy the task]</sql>
|
||||
""".strip()
|
||||
|
||||
SIMPLE_SQL_ERROR_FIX_PROMPT = f"""
|
||||
You are an expert at fixing SQL statements. You will be provided with a SQL statement that aims to address \
|
||||
a question, but it contains an error. Your task is to fix the SQL statement, based on the error message.
|
||||
|
||||
SQL_AGGREGATION_REMOVAL_PROMPT = f"""
|
||||
You are a SQL expert. You were provided with a SQL statement that returns an aggregation, and you are \
|
||||
tasked to show the underlying objects that were aggregated. For this you need to remove the aggregate functions \
|
||||
from the SQL statement in the correct way.
|
||||
Here is the description of the table that the SQL statement is supposed to use:
|
||||
---table_description---
|
||||
|
||||
Additional rules:
|
||||
- if you see a 'select count(*)', you should NOT convert \
|
||||
that to 'select *...', but rather return the corresponding id_name, entity_type_id_name, name, and document_id. \
|
||||
As in: 'select <table, if necessary>.id_name, <table, if necessary>.entity_type_id_name, \
|
||||
<table, if necessary>.name, <table, if necessary>.document_id ...'. \
|
||||
The id_name is always the primary index, and those should be returned, along with the type (entity_type_id_name), \
|
||||
the name (name) of the objects, and the document_id (document_id) of the object.
|
||||
- Add a limit of 30 to the select statement.
|
||||
- Don't change anything else.
|
||||
- The final select statement needs obviously to be a valid SQL statement.
|
||||
Here is the question you are supposed to translate into a SQL statement:
|
||||
{SEPARATOR_LINE}
|
||||
---question---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here is the SQL statement you are supposed to remove the aggregate functions from:
|
||||
Here is the SQL statement that you should fix:
|
||||
{SEPARATOR_LINE}
|
||||
---sql_statement---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here is the error message that was returned:
|
||||
{SEPARATOR_LINE}
|
||||
---error_message---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Note that in the case the error states the sql statement did not return any results, it is possible that the \
|
||||
sql statement is correct, but the question is not addressable with the information in the knowledge graph. \
|
||||
If you are absolutely certain that is the case, you may return the original sql statement.
|
||||
|
||||
Here are a couple common errors that you may encounter:
|
||||
- source_document is in the SELECT clause -> remove it
|
||||
- columns used in ORDER BY must also appear in the SELECT DISTINCT clause
|
||||
- consider carefully the type of the columns you are using, especially for attributes. You may need to cast them
|
||||
- dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like the attributes! \
|
||||
So please use that format, particularly if you use data comparisons (>, <, ...)
|
||||
- attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \
|
||||
"attributes ->> '<attribute>' = '<attribute value>'" (or "attributes ? '<attribute>'" to check for existence).
|
||||
- if you are using joins and the sql returned no joins, make sure you are using the appropriate join type (LEFT, RIGHT, etc.) \
|
||||
it is possible that the second entity does not exist for all examples.
|
||||
- (ignore if using entity_table) if using the relationship_table and the sql returned no results, make sure you are \
|
||||
selecting the correct column! Use the available relationship types to determine whether to use the source or target entity.
|
||||
|
||||
APPROACH:
|
||||
Please think through this step by step. Please also bear in mind that the sql statement is written in postgres syntax.
|
||||
|
||||
Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---.
|
||||
|
||||
Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> start and end tags as in:
|
||||
|
||||
<reasoning>[your short step-by step thinking]</reasoning>
|
||||
<sql>[the SQL statement without the aggregate functions]</sql>
|
||||
""".strip()
|
||||
<reasoning>[think through the logic but do so extremely briefly! Not more than 3-4 sentences.]</reasoning>
|
||||
<sql>[the SQL statement that you generate to satisfy the task]</sql>
|
||||
"""
|
||||
|
||||
|
||||
SEARCH_FILTER_CONSTRUCTION_PROMPT = f"""
|
||||
|
||||
43
backend/onyx/prompts/prompt_template.py
Normal file
43
backend/onyx/prompts/prompt_template.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import re
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
"""
|
||||
A class for building prompt templates with placeholders.
|
||||
Useful when building templates with json schemas, as {} will not work with f-strings.
|
||||
Unlike string.replace, this class will raise an error if the fields are missing.
|
||||
"""
|
||||
|
||||
DEFAULT_PATTERN = r"---([a-zA-Z0-9_]+)---"
|
||||
|
||||
def __init__(self, template: str, pattern: str = DEFAULT_PATTERN):
|
||||
self._pattern_str = pattern
|
||||
self._pattern = re.compile(pattern)
|
||||
self._template = template
|
||||
self._fields: set[str] = set(self._pattern.findall(template))
|
||||
|
||||
def build(self, **kwargs: str) -> str:
|
||||
"""
|
||||
Build the prompt template with the given fields.
|
||||
Will raise an error if the fields are missing.
|
||||
Will ignore fields that are not in the template.
|
||||
"""
|
||||
missing = self._fields - set(kwargs.keys())
|
||||
if missing:
|
||||
raise ValueError(f"Missing required fields: {missing}.")
|
||||
return self._replace_fields(kwargs)
|
||||
|
||||
def partial_build(self, **kwargs: str) -> "PromptTemplate":
|
||||
"""
|
||||
Returns another PromptTemplate with the given fields replaced.
|
||||
Will ignore fields that are not in the template.
|
||||
"""
|
||||
new_template = self._replace_fields(kwargs)
|
||||
return PromptTemplate(new_template, self._pattern_str)
|
||||
|
||||
def _replace_fields(self, field_vals: dict[str, str]) -> str:
|
||||
def repl(match: re.Match) -> str:
|
||||
key = match.group(1)
|
||||
return field_vals.get(key, match.group(0))
|
||||
|
||||
return self._pattern.sub(repl, self._template)
|
||||
@@ -3,6 +3,8 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
@@ -31,12 +33,13 @@ from onyx.server.kg.models import KGConfig
|
||||
from onyx.server.kg.models import KGConfig as KGConfigAPIModel
|
||||
from onyx.server.kg.models import SourceAndEntityTypeView
|
||||
from onyx.server.kg.models import SourceStatistics
|
||||
from onyx.tools.built_in_tools import get_search_tool
|
||||
from onyx.tools.built_in_tools import get_builtin_tool
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
_KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
|
||||
to answer questions"
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/kg")
|
||||
|
||||
|
||||
@@ -95,12 +98,9 @@ def enable_or_disable_kg(
|
||||
enable_kg(enable_req=req)
|
||||
populate_missing_default_entity_types__commit(db_session=db_session)
|
||||
|
||||
# Create or restore KG Beta persona
|
||||
|
||||
# Get the search tool
|
||||
search_tool = get_search_tool(db_session=db_session)
|
||||
if not search_tool:
|
||||
raise RuntimeError("SearchTool not found in the database.")
|
||||
# Get the search and knowledge graph tools
|
||||
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
|
||||
kg_tool = get_builtin_tool(db_session=db_session, tool_type=KnowledgeGraphTool)
|
||||
|
||||
# Check if we have a previously created persona
|
||||
kg_config_settings = get_kg_config_settings()
|
||||
@@ -132,8 +132,8 @@ def enable_or_disable_kg(
|
||||
is_public = len(user_ids) == 0
|
||||
|
||||
persona_request = PersonaUpsertRequest(
|
||||
name="KG Beta",
|
||||
description=_KG_BETA_ASSISTANT_DESCRIPTION,
|
||||
name=TMP_DRALPHA_PERSONA_NAME,
|
||||
description=KG_BETA_ASSISTANT_DESCRIPTION,
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
@@ -145,7 +145,7 @@ def enable_or_disable_kg(
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
prompt_ids=[0],
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -47,6 +47,7 @@ from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_message_to_packets
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
from onyx.db.connector import create_connector
|
||||
@@ -92,6 +93,8 @@ from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SearchFeedbackRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from onyx.utils.file_types import UploadMimeTypes
|
||||
from onyx.utils.headers import get_custom_tool_additional_request_headers
|
||||
@@ -233,6 +236,24 @@ def get_chat_session(
|
||||
prefetch_tool_calls=True,
|
||||
)
|
||||
|
||||
# Convert messages to ChatMessageDetail format
|
||||
chat_message_details = [
|
||||
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
|
||||
]
|
||||
|
||||
simplified_packet_lists: list[list[Packet]] = []
|
||||
end_step_nr = 1
|
||||
for msg in session_messages:
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
msg_packet_object = translate_db_message_to_packets(
|
||||
msg, db_session=db_session, start_step_nr=end_step_nr
|
||||
)
|
||||
end_step_nr = msg_packet_object.end_step_nr
|
||||
msg_packet_list = msg_packet_object.packet_list
|
||||
|
||||
msg_packet_list.append(Packet(ind=end_step_nr, obj=OverallStop()))
|
||||
simplified_packet_lists.append(msg_packet_list)
|
||||
|
||||
return ChatSessionDetailResponse(
|
||||
chat_session_id=session_id,
|
||||
description=chat_session.description,
|
||||
@@ -245,13 +266,13 @@ def get_chat_session(
|
||||
chat_session.persona.icon_shape if chat_session.persona else None
|
||||
),
|
||||
current_alternate_model=chat_session.current_alternate_model,
|
||||
messages=[
|
||||
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
|
||||
],
|
||||
messages=chat_message_details,
|
||||
time_created=chat_session.time_created,
|
||||
shared_status=chat_session.shared_status,
|
||||
current_temperature_override=chat_session.temperature_override,
|
||||
deleted=chat_session.deleted,
|
||||
# specifically for the Onyx Chat UI
|
||||
packets=simplified_packet_lists,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.db.enums import ChatSessionSharedStatus
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
|
||||
|
||||
@@ -240,11 +241,8 @@ class ChatMessageDetail(BaseModel):
|
||||
chat_session_id: UUID | None = None
|
||||
# Dict mapping citation number to db_doc_id
|
||||
citations: dict[int, int] | None = None
|
||||
sub_questions: list[SubQuestionDetail] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
refined_answer_improvement: bool | None = None
|
||||
is_agentic: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
@@ -274,6 +272,8 @@ class ChatSessionDetailResponse(BaseModel):
|
||||
current_temperature_override: float | None
|
||||
deleted: bool = False
|
||||
|
||||
packets: list[list[Packet]]
|
||||
|
||||
|
||||
# This one is not used anymore
|
||||
class QueryValidationResponse(BaseModel):
|
||||
|
||||
190
backend/onyx/server/query_and_chat/streaming_models.py
Normal file
190
backend/onyx/server/query_and_chat/streaming_models.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from typing import Annotated
|
||||
from typing import Literal
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
|
||||
|
||||
class BaseObj(BaseModel):
|
||||
type: str = ""
|
||||
|
||||
|
||||
"""Basic Message Packets"""
|
||||
|
||||
|
||||
class MessageStart(BaseObj):
|
||||
type: Literal["message_start"] = "message_start"
|
||||
|
||||
# Merged set of all documents considered
|
||||
final_documents: list[SavedSearchDoc] | None
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class MessageDelta(BaseObj):
|
||||
content: str
|
||||
type: Literal["message_delta"] = "message_delta"
|
||||
|
||||
|
||||
"""Control Packets"""
|
||||
|
||||
|
||||
class OverallStop(BaseObj):
|
||||
type: Literal["stop"] = "stop"
|
||||
|
||||
|
||||
class SectionEnd(BaseObj):
|
||||
type: Literal["section_end"] = "section_end"
|
||||
|
||||
|
||||
"""Tool Packets"""
|
||||
|
||||
|
||||
class SearchToolStart(BaseObj):
|
||||
type: Literal["internal_search_tool_start"] = "internal_search_tool_start"
|
||||
|
||||
is_internet_search: bool = False
|
||||
|
||||
|
||||
class SearchToolDelta(BaseObj):
|
||||
type: Literal["internal_search_tool_delta"] = "internal_search_tool_delta"
|
||||
|
||||
queries: list[str] | None = None
|
||||
documents: list[SavedSearchDoc] | None = None
|
||||
|
||||
|
||||
class ImageGenerationToolStart(BaseObj):
|
||||
type: Literal["image_generation_tool_start"] = "image_generation_tool_start"
|
||||
|
||||
|
||||
class ImageGenerationToolDelta(BaseObj):
|
||||
type: Literal["image_generation_tool_delta"] = "image_generation_tool_delta"
|
||||
|
||||
images: list[dict[str, str]] | None = None
|
||||
|
||||
|
||||
class CustomToolStart(BaseObj):
|
||||
type: Literal["custom_tool_start"] = "custom_tool_start"
|
||||
|
||||
tool_name: str
|
||||
|
||||
|
||||
class CustomToolDelta(BaseObj):
|
||||
type: Literal["custom_tool_delta"] = "custom_tool_delta"
|
||||
|
||||
tool_name: str
|
||||
response_type: str
|
||||
# For non-file responses
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
# For file-based responses like image/csv
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
"""Reasoning Packets"""
|
||||
|
||||
|
||||
class ReasoningStart(BaseObj):
|
||||
type: Literal["reasoning_start"] = "reasoning_start"
|
||||
|
||||
|
||||
class ReasoningDelta(BaseObj):
|
||||
type: Literal["reasoning_delta"] = "reasoning_delta"
|
||||
|
||||
reasoning: str
|
||||
|
||||
|
||||
"""Citation Packets"""
|
||||
|
||||
|
||||
class CitationStart(BaseObj):
|
||||
type: Literal["citation_start"] = "citation_start"
|
||||
|
||||
|
||||
class SubQuestionIdentifier(BaseModel):
|
||||
"""None represents references to objects in the original flow. To our understanding,
|
||||
these will not be None in the packets returned from agent search.
|
||||
"""
|
||||
|
||||
level: int | None = None
|
||||
level_question_num: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level(
|
||||
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"],
|
||||
) -> dict[int, list["SubQuestionIdentifier"]]:
|
||||
"""returns a dict of level to object list (sorted by level_question_num)
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
|
||||
# organize by level, then sort ascending by question_index
|
||||
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
|
||||
|
||||
# group by level
|
||||
for k, obj in original_dict.items():
|
||||
level = k[0]
|
||||
if level not in level_dict:
|
||||
level_dict[level] = []
|
||||
level_dict[level].append(obj)
|
||||
|
||||
# for each level, sort the group
|
||||
for k2, value2 in level_dict.items():
|
||||
# we need to handle the none case due to SubQuestionIdentifier typing
|
||||
# level_question_num as int | None, even though it should never be None here.
|
||||
level_dict[k2] = sorted(
|
||||
value2,
|
||||
key=lambda x: (x.level_question_num is None, x.level_question_num),
|
||||
)
|
||||
|
||||
# sort by level
|
||||
sorted_dict = OrderedDict(sorted(level_dict.items()))
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class CitationInfo(SubQuestionIdentifier):
|
||||
citation_num: int
|
||||
document_id: str
|
||||
|
||||
|
||||
class CitationDelta(BaseObj):
|
||||
type: Literal["citation_delta"] = "citation_delta"
|
||||
|
||||
citations: list[CitationInfo] | None = None
|
||||
|
||||
|
||||
"""Packet"""
|
||||
|
||||
# Discriminated union of all possible packet object types
|
||||
PacketObj = Annotated[
|
||||
Union[
|
||||
MessageStart,
|
||||
MessageDelta,
|
||||
OverallStop,
|
||||
SectionEnd,
|
||||
SearchToolStart,
|
||||
SearchToolDelta,
|
||||
ImageGenerationToolStart,
|
||||
ImageGenerationToolDelta,
|
||||
CustomToolStart,
|
||||
CustomToolDelta,
|
||||
ReasoningStart,
|
||||
ReasoningDelta,
|
||||
CitationStart,
|
||||
CitationDelta,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class Packet(BaseModel):
|
||||
ind: int
|
||||
obj: PacketObj
|
||||
|
||||
|
||||
class EndStepPacketList(BaseModel):
|
||||
end_step_nr: int
|
||||
packet_list: list[Packet]
|
||||
318
backend/onyx/server/query_and_chat/streaming_utils.py
Normal file
318
backend/onyx/server/query_and_chat/streaming_utils.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
|
||||
|
||||
def create_simplified_packets_for_message(
|
||||
message: ChatMessageDetail, packet_index_start: int = 0
|
||||
) -> list[Packet]:
|
||||
"""
|
||||
Convert a ChatMessageDetail into simplified streaming packets that represent
|
||||
what would have been sent during the original streaming response.
|
||||
|
||||
Args:
|
||||
message: The chat message to convert to packets
|
||||
packet_index_start: Starting index for packet numbering
|
||||
|
||||
Returns:
|
||||
List of simplified packets representing the message
|
||||
"""
|
||||
packets: list[Packet] = []
|
||||
current_index = packet_index_start
|
||||
|
||||
# Only create packets for assistant messages
|
||||
if message.message_type != MessageType.ASSISTANT:
|
||||
return packets
|
||||
|
||||
# Handle all tool-related packets in one unified block
|
||||
# Check for tool calls first, then fall back to inferred tools from context/files
|
||||
if message.tool_call:
|
||||
tool_call = message.tool_call
|
||||
|
||||
# Handle different tool types based on tool name
|
||||
if tool_call.tool_name == "run_search":
|
||||
# Handle search tools - create search tool packets
|
||||
# Use context docs if available, otherwise use tool result
|
||||
if message.context_docs and message.context_docs.top_documents:
|
||||
search_docs = message.context_docs.top_documents
|
||||
|
||||
# Start search tool
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=current_index,
|
||||
obj=SearchToolStart(),
|
||||
)
|
||||
)
|
||||
|
||||
# Include queries and documents in the delta
|
||||
if message.rephrased_query and message.rephrased_query.strip():
|
||||
queries = [str(message.rephrased_query)]
|
||||
else:
|
||||
queries = [message.message]
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=current_index,
|
||||
obj=SearchToolDelta(
|
||||
queries=queries,
|
||||
documents=search_docs,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# End search tool
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=current_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
current_index += 1
|
||||
|
||||
elif tool_call.tool_name == "run_image_generation":
|
||||
# Handle image generation tools - create image generation packets
|
||||
# Use files if available, otherwise create from tool result
|
||||
if message.files:
|
||||
image_files = [
|
||||
f for f in message.files if f["type"] == ChatFileType.IMAGE
|
||||
]
|
||||
if image_files:
|
||||
# Start image tool
|
||||
image_tool_start = ImageGenerationToolStart()
|
||||
packets.append(Packet(ind=current_index, obj=image_tool_start))
|
||||
|
||||
# Send images via tool delta
|
||||
images = []
|
||||
for file in image_files:
|
||||
images.append(
|
||||
{
|
||||
"id": file["id"],
|
||||
"url": "", # URL will be constructed by frontend
|
||||
"prompt": file.get("name") or "Generated image",
|
||||
}
|
||||
)
|
||||
|
||||
image_tool_delta = ImageGenerationToolDelta(images=images)
|
||||
packets.append(Packet(ind=current_index, obj=image_tool_delta))
|
||||
|
||||
# End image tool
|
||||
image_tool_end = SectionEnd()
|
||||
packets.append(Packet(ind=current_index, obj=image_tool_end))
|
||||
current_index += 1
|
||||
|
||||
elif tool_call.tool_name == "run_internet_search":
|
||||
# Internet search tools return document data, but should be treated as custom tools
|
||||
# for packet purposes since they have a different data structure
|
||||
# Start custom tool
|
||||
custom_tool_start = CustomToolStart(tool_name=tool_call.tool_name)
|
||||
packets.append(Packet(ind=current_index, obj=custom_tool_start))
|
||||
|
||||
# Send internet search results as custom tool data
|
||||
custom_tool_delta = CustomToolDelta(
|
||||
tool_name=tool_call.tool_name,
|
||||
response_type="json",
|
||||
data=tool_call.tool_result,
|
||||
file_ids=None,
|
||||
)
|
||||
packets.append(Packet(ind=current_index, obj=custom_tool_delta))
|
||||
|
||||
# End custom tool
|
||||
custom_tool_end = SectionEnd()
|
||||
packets.append(Packet(ind=current_index, obj=custom_tool_end))
|
||||
current_index += 1
|
||||
|
||||
else:
|
||||
# Handle custom tools and any other tool types
|
||||
# Start custom tool
|
||||
custom_tool_start = CustomToolStart(tool_name=tool_call.tool_name)
|
||||
packets.append(Packet(ind=current_index, obj=custom_tool_start))
|
||||
|
||||
# Determine response type and data from tool result
|
||||
response_type = "json" # default
|
||||
data = None
|
||||
file_ids = None
|
||||
|
||||
if tool_call.tool_result:
|
||||
# Check if it's a custom tool call summary (most common case)
|
||||
if isinstance(tool_call.tool_result, dict):
|
||||
# Try to extract response_type if it's structured like CustomToolCallSummary
|
||||
if "response_type" in tool_call.tool_result:
|
||||
response_type = tool_call.tool_result["response_type"]
|
||||
tool_result = tool_call.tool_result.get("tool_result")
|
||||
|
||||
# Handle file-based responses
|
||||
if isinstance(tool_result, dict) and "file_ids" in tool_result:
|
||||
file_ids = tool_result["file_ids"]
|
||||
else:
|
||||
data = tool_result
|
||||
else:
|
||||
# Plain dict response
|
||||
data = tool_call.tool_result
|
||||
else:
|
||||
# Non-dict response (string, number, etc.)
|
||||
data = tool_call.tool_result
|
||||
|
||||
# Send tool response via tool delta
|
||||
custom_tool_delta = CustomToolDelta(
|
||||
tool_name=tool_call.tool_name,
|
||||
response_type=response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
packets.append(Packet(ind=current_index, obj=custom_tool_delta))
|
||||
|
||||
# End custom tool
|
||||
custom_tool_end = SectionEnd()
|
||||
packets.append(Packet(ind=current_index, obj=custom_tool_end))
|
||||
current_index += 1
|
||||
|
||||
# Fallback handling for when there's no explicit tool_call but we have tool-related data
|
||||
elif message.context_docs and message.context_docs.top_documents:
|
||||
# Handle search results without explicit tool call (legacy support)
|
||||
search_docs = message.context_docs.top_documents
|
||||
|
||||
# Start search tool
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=current_index,
|
||||
obj=SearchToolStart(),
|
||||
)
|
||||
)
|
||||
|
||||
# Include queries and documents in the delta
|
||||
if message.rephrased_query and message.rephrased_query.strip():
|
||||
queries = [str(message.rephrased_query)]
|
||||
else:
|
||||
queries = [message.message]
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=current_index,
|
||||
obj=SearchToolDelta(
|
||||
queries=queries,
|
||||
documents=search_docs,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# End search tool
|
||||
packets.append(
|
||||
Packet(
|
||||
ind=current_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
current_index += 1
|
||||
|
||||
# Handle image files without explicit tool call (legacy support)
|
||||
if message.files:
|
||||
image_files = [f for f in message.files if f["type"] == ChatFileType.IMAGE]
|
||||
if image_files and not message.tool_call:
|
||||
# Only create image packets if there's no tool call that might have handled them
|
||||
# Start image tool
|
||||
image_tool_start = ImageGenerationToolStart()
|
||||
packets.append(Packet(ind=current_index, obj=image_tool_start))
|
||||
|
||||
# Send images via tool delta
|
||||
images = []
|
||||
for file in image_files:
|
||||
images.append(
|
||||
{
|
||||
"id": file["id"],
|
||||
"url": "", # URL will be constructed by frontend
|
||||
"prompt": file.get("name") or "Generated image",
|
||||
}
|
||||
)
|
||||
|
||||
image_tool_delta = ImageGenerationToolDelta(images=images)
|
||||
packets.append(Packet(ind=current_index, obj=image_tool_delta))
|
||||
|
||||
# End image tool
|
||||
image_tool_end = SectionEnd()
|
||||
packets.append(Packet(ind=current_index, obj=image_tool_end))
|
||||
current_index += 1
|
||||
|
||||
# Create Citation packets if there are citations
|
||||
if message.citations:
|
||||
# Start citation flow
|
||||
citation_start = CitationStart()
|
||||
packets.append(Packet(ind=current_index, obj=citation_start))
|
||||
|
||||
# Create citation data
|
||||
# Convert dict[int, int] to list[StreamingCitation] format
|
||||
citations_list: list[CitationInfo] = []
|
||||
for citation_num, doc_id in message.citations.items():
|
||||
citation = CitationInfo(citation_num=citation_num, document_id=str(doc_id))
|
||||
citations_list.append(citation)
|
||||
|
||||
# Send citations via citation delta
|
||||
citation_delta = CitationDelta(citations=citations_list)
|
||||
packets.append(Packet(ind=current_index, obj=citation_delta))
|
||||
|
||||
# End citation flow
|
||||
citation_end = SectionEnd()
|
||||
packets.append(Packet(ind=current_index, obj=citation_end))
|
||||
current_index += 1
|
||||
|
||||
# Create MESSAGE_START packet
|
||||
message_start = MessageStart(
|
||||
content="",
|
||||
final_documents=(
|
||||
message.context_docs.top_documents if message.context_docs else None
|
||||
),
|
||||
)
|
||||
packets.append(Packet(ind=current_index, obj=message_start))
|
||||
|
||||
# Create MESSAGE_DELTA packet with the full message content
|
||||
# In a real streaming scenario, this would be broken into multiple deltas
|
||||
if message.message:
|
||||
message_delta = MessageDelta(content=message.message)
|
||||
packets.append(Packet(ind=current_index, obj=message_delta))
|
||||
|
||||
# Create MESSAGE_END packet
|
||||
message_end = SectionEnd()
|
||||
packets.append(Packet(ind=current_index, obj=message_end))
|
||||
current_index += 1
|
||||
|
||||
# Create STOP packet
|
||||
stop = OverallStop()
|
||||
packets.append(Packet(ind=current_index, obj=stop))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_simplified_packets_for_session(
|
||||
messages: list[ChatMessageDetail],
|
||||
) -> list[list[Packet]]:
|
||||
"""
|
||||
Convert a list of chat messages into simplified streaming packets organized by message.
|
||||
Each inner list contains packets for a single assistant message.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages from the session
|
||||
|
||||
Returns:
|
||||
List of lists of simplified packets, where each inner list represents one assistant message
|
||||
"""
|
||||
packets_by_message: list[list[Packet]] = []
|
||||
|
||||
for message in messages:
|
||||
if message.message_type == MessageType.ASSISTANT:
|
||||
message_packets = create_simplified_packets_for_message(message, 0)
|
||||
if message_packets: # Only add if there are actual packets
|
||||
packets_by_message.append(message_packets)
|
||||
|
||||
return packets_by_message
|
||||
@@ -17,6 +17,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
|
||||
from onyx.tools.tool_implementations.internet_search.providers import (
|
||||
get_available_providers,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -63,6 +66,15 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
if (bool(get_available_providers()))
|
||||
else []
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=KnowledgeGraphTool,
|
||||
description=(
|
||||
"The Knowledge Graph Search Action allows the assistant to search the knowledge graph for information."
|
||||
"This tool should only be used by the Deep Research Agent, not via tool calling."
|
||||
),
|
||||
in_code_tool_id=KnowledgeGraphTool.__name__,
|
||||
display_name=KnowledgeGraphTool._DISPLAY_NAME,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -106,27 +118,37 @@ def load_builtin_tools(db_session: Session) -> None:
|
||||
logger.notice("All built-in tools are loaded/verified.")
|
||||
|
||||
|
||||
def get_search_tool(db_session: Session) -> ToolDBModel | None:
|
||||
def get_builtin_tool(
|
||||
db_session: Session,
|
||||
tool_type: Type[
|
||||
SearchTool | ImageGenerationTool | InternetSearchTool | KnowledgeGraphTool
|
||||
],
|
||||
) -> ToolDBModel:
|
||||
"""
|
||||
Retrieves for the SearchTool from the BUILT_IN_TOOLS list.
|
||||
Retrieves a built-in tool from the database based on the tool type.
|
||||
"""
|
||||
search_tool_id = next(
|
||||
tool_id = next(
|
||||
(
|
||||
tool["in_code_tool_id"]
|
||||
for tool in BUILT_IN_TOOLS
|
||||
if tool["cls"].__name__ == SearchTool.__name__
|
||||
if tool["cls"].__name__ == tool_type.__name__
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not search_tool_id:
|
||||
raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")
|
||||
if not tool_id:
|
||||
raise RuntimeError(
|
||||
f"Tool type {tool_type.__name__} not found in the BUILT_IN_TOOLS list."
|
||||
)
|
||||
|
||||
search_tool = db_session.execute(
|
||||
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
|
||||
db_tool = db_session.execute(
|
||||
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == tool_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
return search_tool
|
||||
if not db_tool:
|
||||
raise RuntimeError(f"Tool type {tool_type.__name__} not found in the database.")
|
||||
|
||||
return db_tool
|
||||
|
||||
|
||||
def auto_add_search_tool_to_personas(db_session: Session) -> None:
|
||||
@@ -136,10 +158,7 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None:
|
||||
Persona objects that were created before the concept of Tools were added.
|
||||
"""
|
||||
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
|
||||
search_tool = get_search_tool(db_session)
|
||||
|
||||
if not search_tool:
|
||||
raise RuntimeError("SearchTool not found in the database.")
|
||||
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
|
||||
|
||||
# Fetch all Personas that need the SearchTool added
|
||||
personas_to_update = (
|
||||
|
||||
@@ -20,6 +20,11 @@ OVERRIDE_T = TypeVar("OVERRIDE_T")
|
||||
|
||||
|
||||
class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
|
||||
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
@@ -41,6 +42,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import compute_all_tool_tokens
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
@@ -265,6 +269,14 @@ def construct_tools(
|
||||
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
|
||||
)
|
||||
|
||||
# Handle KG Tool
|
||||
elif tool_cls.__name__ == KnowledgeGraphTool.__name__:
|
||||
if persona.name != TMP_DRALPHA_PERSONA_NAME:
|
||||
raise ValueError(
|
||||
f"Knowledge Graph Tool should only be used by the '{TMP_DRALPHA_PERSONA_NAME}' Agent."
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [KnowledgeGraphTool()]
|
||||
|
||||
# Handle custom tools
|
||||
elif db_tool_model.openapi_schema:
|
||||
if not custom_tool_config:
|
||||
|
||||
@@ -17,6 +17,8 @@ from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
@@ -77,6 +79,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
class CustomTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
@@ -86,6 +89,7 @@ class CustomTool(BaseTool):
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
self._user_oauth_token = user_oauth_token
|
||||
self._id = id
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self._description = self._method_spec.summary
|
||||
@@ -107,6 +111,10 @@ class CustomTool(BaseTool):
|
||||
if self._user_oauth_token:
|
||||
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
@@ -382,11 +390,27 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
openapi_schema_str = json.dumps(openapi_schema)
|
||||
|
||||
with get_session_with_current_tenant() as temp_db_session:
|
||||
tools = get_tools(temp_db_session)
|
||||
tool_id: int | None = None
|
||||
for tool in tools:
|
||||
if tool.openapi_schema and (
|
||||
json.dumps(tool.openapi_schema) == openapi_schema_str
|
||||
):
|
||||
tool_id = tool.id
|
||||
break
|
||||
if not tool_id:
|
||||
raise ValueError(f"Tool with openapi_schema {openapi_schema_str} not found")
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
method_spec,
|
||||
url,
|
||||
custom_headers,
|
||||
id=tool_id,
|
||||
method_spec=method_spec,
|
||||
base_url=url,
|
||||
custom_headers=custom_headers,
|
||||
user_oauth_token=user_oauth_token,
|
||||
)
|
||||
for method_spec in method_specs
|
||||
|
||||
@@ -13,6 +13,8 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
@@ -112,6 +114,22 @@ class ImageGenerationTool(Tool[None]):
|
||||
self.additional_headers = additional_headers
|
||||
self.output_format = output_format
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == ImageGenerationTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError(
|
||||
"Image Generation tool not found. This should never happen."
|
||||
)
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
@@ -143,8 +144,23 @@ class InternetSearchTool(Tool[None]):
|
||||
)
|
||||
)
|
||||
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == InternetSearchTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError(
|
||||
"Internet Search tool not found. This should never happen."
|
||||
)
|
||||
self._id = tool_id
|
||||
|
||||
"""For explicit tool calling"""
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
QUERY_FIELD = "query"
|
||||
|
||||
|
||||
class KnowledgeGraphTool(Tool[None]):
|
||||
_NAME = "run_kg_search"
|
||||
_DESCRIPTION = "Search the knowledge graph for information. Never call this tool."
|
||||
_DISPLAY_NAME = "Knowledge Graph Search"
|
||||
|
||||
def __init__(self) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == KnowledgeGraphTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError(
|
||||
"Knowledge Graph tool not found. This should never happen."
|
||||
)
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
QUERY_FIELD: {
|
||||
"type": "string",
|
||||
"description": "What to search for",
|
||||
},
|
||||
},
|
||||
"required": [QUERY_FIELD],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def run(
|
||||
self, override_kwargs: None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
@@ -34,6 +34,7 @@ from onyx.context.search.models import UserFileFilters
|
||||
from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.pipeline import section_relevance_list_impl
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
@@ -162,6 +163,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == SearchTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError("Search tool not found. This should never happen.")
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user