mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-05 23:12:43 +00:00
Compare commits
92 Commits
cli/v0.1.2
...
dr_v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff0c78bb02 | ||
|
|
73cecf8c05 | ||
|
|
516ae99225 | ||
|
|
1d7ec49e55 | ||
|
|
de82ad97e0 | ||
|
|
7b37e72b9d | ||
|
|
09d672ff22 | ||
|
|
b028b25737 | ||
|
|
07768d5484 | ||
|
|
5ca8ca2b1e | ||
|
|
62872e58ae | ||
|
|
5f66a27c67 | ||
|
|
c21fa21958 | ||
|
|
cd6577c3ca | ||
|
|
16406f0ebd | ||
|
|
4ae5bb1e6b | ||
|
|
b0c95ec876 | ||
|
|
397d30c802 | ||
|
|
f13b08b461 | ||
|
|
e66245ec13 | ||
|
|
c64c6368c1 | ||
|
|
b2fe55c8f8 | ||
|
|
2b661441d7 | ||
|
|
f83f06228b | ||
|
|
fabfa8d166 | ||
|
|
994e7f7666 | ||
|
|
c81a7e1ef2 | ||
|
|
1d7d2f06d8 | ||
|
|
916d6cb119 | ||
|
|
6d3542ded1 | ||
|
|
e5dbfc34c3 | ||
|
|
1aad7f44d2 | ||
|
|
a0d6d0b922 | ||
|
|
588023a1f6 | ||
|
|
e4c2427728 | ||
|
|
bf77da26fc | ||
|
|
abfecde097 | ||
|
|
3f4936ad0a | ||
|
|
3b8d16a136 | ||
|
|
322e8668da | ||
|
|
d1dcad60d6 | ||
|
|
7b3bdbdf83 | ||
|
|
8b09fb0cef | ||
|
|
a2dd1bbf4f | ||
|
|
828231815a | ||
|
|
d48cbc2b79 | ||
|
|
991bd4f8bf | ||
|
|
74418b84a2 | ||
|
|
df1c40c791 | ||
|
|
c253844500 | ||
|
|
e972fb3e07 | ||
|
|
726211c27d | ||
|
|
c0435ddfd6 | ||
|
|
48dc934c35 | ||
|
|
3a575a92d5 | ||
|
|
de4a9e4687 | ||
|
|
c330152417 | ||
|
|
dca39f27a6 | ||
|
|
d3cc27846a | ||
|
|
fedc665b88 | ||
|
|
614672f357 | ||
|
|
6aca9ee005 | ||
|
|
f9f64fb1a5 | ||
|
|
4a63e631cd | ||
|
|
3d5586d623 | ||
|
|
6c4eb17b5d | ||
|
|
0917d9acd3 | ||
|
|
89ea0f8d48 | ||
|
|
31ae6f1eb1 | ||
|
|
1b8d246afb | ||
|
|
05e55559d8 | ||
|
|
241b8d062c | ||
|
|
6359d2f2d6 | ||
|
|
83325f9012 | ||
|
|
0b26ed602d | ||
|
|
2b69d1ba52 | ||
|
|
27cd1d44dc | ||
|
|
b5ddf31742 | ||
|
|
ce1c80148b | ||
|
|
bb95c46015 | ||
|
|
e8a593c315 | ||
|
|
bb1b12988c | ||
|
|
72bbcabedf | ||
|
|
2ee98ba795 | ||
|
|
594bbdb167 | ||
|
|
d5c67b6f50 | ||
|
|
9c7638ceba | ||
|
|
b1488ddccc | ||
|
|
d9a9818b9a | ||
|
|
4bd3b8b0bb | ||
|
|
da3979fc41 | ||
|
|
ffed8b4300 |
@@ -0,0 +1,91 @@
|
||||
"""add research agent database tables and chat message research fields
|
||||
|
||||
Revision ID: 5ae8240accb3
|
||||
Revises: 62c3a055a141
|
||||
Create Date: 2025-08-06 14:29:24.691388
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5ae8240accb3"
|
||||
down_revision = "62c3a055a141"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the research-agent schema.

    Adds research_type/research_plan columns to chat_message and creates the
    research_agent_iteration and research_agent_iteration_sub_step tables.
    """
    # Add research_type and research_plan columns to chat_message table
    op.add_column(
        "chat_message",
        sa.Column("research_type", sa.String(), nullable=True),
    )
    op.add_column(
        "chat_message",
        # research_plan is free-form structured data, hence JSONB
        sa.Column("research_plan", postgresql.JSONB(), nullable=True),
    )

    # Create research_agent_iteration table
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            # iterations are deleted together with their owning chat message
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        # NOTE(review): no unique constraint on (primary_question_id, iteration_nr)
        # — confirm duplicates are prevented at the application layer
    )

    # Create research_agent_iteration_sub_step table
    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "parent_question_id",
            sa.Integer(),
            # NOTE(review): self-referential FK (sub-step -> sub-step) despite the
            # "question" name — confirm this is intended rather than a reference
            # to research_agent_iteration
            sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
            nullable=True,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column(
            "sub_step_tool_id",
            sa.Integer(),
            # no ondelete: deleting a tool with referencing sub-steps will fail
            sa.ForeignKey("tool.id"),
            nullable=True,
        ),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the research-agent schema changes from upgrade()."""
    # Drop tables in reverse order (sub_step first: it has a self-referential FK
    # and both tables reference chat_message)
    op.drop_table("research_agent_iteration_sub_step")
    op.drop_table("research_agent_iteration")

    # Remove columns from chat_message table
    op.drop_column("chat_message", "research_plan")
    op.drop_column("chat_message", "research_type")
|
||||
12
backend/onyx/agents/agent_search/basic/models.py
Normal file
12
backend/onyx/agents/agent_search/basic/models.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class BasicSearchProcessedStreamResults(BaseModel):
    """Aggregated results of processing a basic-search LLM stream."""

    # tool-call chunks merged from the stream; empty-content chunk when none
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    # the full streamed answer text, if one was produced
    full_answer: str | None = None
    # NOTE: pydantic deep-copies these list defaults per instance, so the
    # shared-mutable-default pitfall does not apply here
    cited_references: list[InferenceSection] = []
    retrieved_documents: list[LlmDoc] = []
|
||||
@@ -6,6 +6,9 @@ from pydantic import BaseModel
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
# States contain values that change over the course of graph execution,
|
||||
# Config is for values that are set at the start and never change.
|
||||
@@ -18,11 +21,15 @@ class BasicInput(BaseModel):
|
||||
# Langgraph needs a nonempty input, but we pass in all static
|
||||
# data through a RunnableConfig.
|
||||
unused: bool = True
|
||||
query_override: str | None = None
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class BasicOutput(TypedDict):
|
||||
tool_call_chunk: AIMessageChunk
|
||||
full_answer: str | None
|
||||
cited_references: list[InferenceSection] | None
|
||||
retrieved_documents: list[LlmDoc] | None
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
@@ -5,6 +5,7 @@ from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
@@ -24,7 +25,7 @@ def process_llm_stream(
|
||||
writer: StreamWriter,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
) -> BasicSearchProcessedStreamResults:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
if final_search_results and displayed_search_results:
|
||||
@@ -61,4 +62,6 @@ def process_llm_stream(
|
||||
)
|
||||
|
||||
logger.debug(f"Full answer: {full_answer}")
|
||||
return cast(AIMessageChunk, tool_call_chunk)
|
||||
return BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
|
||||
)
|
||||
|
||||
49
backend/onyx/agents/agent_search/dr/conditional_edges.py
Normal file
49
backend/onyx/agents/agent_search/dr/conditional_edges.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.graph import END
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
|
||||
|
||||
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    """Route to the next graph node based on the most recently chosen tool.

    The last entry of ``state.tools_used`` names the destination; names that
    are not known ``DRPath`` values are treated as generic (custom) tools.
    """
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")

    latest_tool = state.tools_used[-1]

    # Map the tool name onto a graph path; unknown names fall back to the
    # generic-tool sub-graph.
    try:
        resolved_path = DRPath(latest_tool)
    except ValueError:
        resolved_path = DRPath.GENERIC_TOOL

    if resolved_path == DRPath.END:
        return END

    # the clarifier only runs once, before iteration starts
    if resolved_path == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")

    # Search-style tools need at least one query; without any, wrap up instead.
    query_driven_paths = {
        DRPath.INTERNAL_SEARCH,
        DRPath.INTERNET_SEARCH,
        DRPath.KNOWLEDGE_GRAPH,
    }
    if resolved_path in query_driven_paths and not state.query_list:
        return DRPath.CLOSER

    return resolved_path
|
||||
|
||||
|
||||
def completeness_router(state: MainState) -> DRPath | str:
    """After the closer runs, either loop back to the orchestrator or end."""
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")

    # The closer records ORCHESTRATOR as the last tool when it wants another
    # iteration; any other value terminates the graph.
    wants_another_round = state.tools_used[-1] == DRPath.ORCHESTRATOR.value
    return DRPath.ORCHESTRATOR if wants_another_round else END
|
||||
29
backend/onyx/agents/agent_search/dr/constants.py
Normal file
29
backend/onyx/agents/agent_search/dr/constants.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
|
||||
# Number of prior exchanges included when building prompt chat history.
MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)

# Fan-out limit for parallel search sub-questions (name suggests; confirm at
# call sites).
MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)

# Marker prefixes used to recognize special assistant messages in history.
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "HIGH_LEVEL PLAN:"

# Relative time-budget cost charged per tool invocation; keys are graph paths.
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.INTERNET_SEARCH: 1.5,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,  # closing is free
}

# Total time budget available per research depth, in the same units as the
# per-tool costs above.
DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 6.0,
}
|
||||
114
backend/onyx/agents/agent_search/dr/dr_prompt_builder.py
Normal file
114
backend/onyx/agents/agent_search/dr/dr_prompt_builder.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
|
||||
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
|
||||
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
|
||||
|
||||
def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    """Select and partially fill the orchestration prompt for a given purpose.

    Picks a base template based on ``purpose`` and ``research_type``, then
    substitutes tool metadata (names, descriptions, costs, hints), KG type
    descriptions, and any prior reasoning/tool-call context.

    Raises:
        ValueError: for purpose/research_type combinations that have no
            template, or when KG is available but entity/relationship type
            strings are missing.
    """
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())
    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )

    # Collect pairwise disambiguation hints for every ordered pair of
    # available tools that has an entry in TOOL_DIFFERENTIATION_HINTS.
    tool_differentiations: list[str] = []
    for tool_1 in available_tools:
        for tool_2 in available_tools:
            if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS:
                tool_differentiations.append(
                    TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
                )
    tool_differentiation_hint_string = (
        "\n".join(tool_differentiations) or "(No differentiating hints available)"
    )
    # TODO: add tool deliniation pairs for custom tools as well

    # Example questions per tool, where available.
    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )

    # KG prompts need the active entity/relationship type listings.
    if DRPath.KNOWLEDGE_GRAPH.value in available_tools:
        if not entity_types_string or not relationship_types_string:
            raise ValueError(
                "Entity types and relationship types must be provided if the Knowledge Graph is used."
            )
        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string,
            possible_relationships=relationship_types_string,
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"

    # Template dispatch. NOTE(review): several error messages below say
    # "FAST"/"DEEP time budget" while the checks compare against
    # ResearchType.THOUGHTFUL — presumably the enum members were renamed;
    # confirm and align the messages.
    if purpose == DRPromptPurpose.PLAN:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError("plan generation is not supported for FAST time budget")
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
        else:
            raise ValueError(
                "reasoning is not separately required for DEEP time budget"
            )

    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
        else:
            base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT

    elif purpose == DRPromptPurpose.CLARIFICATION:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError("clarification is not supported for FAST time budget")
        base_template = GET_CLARIFICATION_PROMPT

    else:
        # for mypy, clearly a mypy bug
        raise ValueError(f"Invalid purpose: {purpose}")

    # Fill everything known now; remaining placeholders are bound by callers.
    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )
|
||||
20
backend/onyx/agents/agent_search/dr/enums.py
Normal file
20
backend/onyx/agents/agent_search/dr/enums.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ResearchType(str, Enum):
    """Research type options for agent search operations."""

    # BASIC = "BASIC"
    THOUGHTFUL = "THOUGHTFUL"  # single-pass, lighter-weight research
    DEEP = "DEEP"  # multi-iteration deep research
|
||||
|
||||
|
||||
class DRPath(str, Enum):
    """Node names / routing targets in the deep-research LangGraph graph."""

    CLARIFIER = "CLARIFIER"  # entry node; may ask the user to clarify
    ORCHESTRATOR = "ORCHESTRATOR"  # decides which tool runs next
    INTERNAL_SEARCH = "INTERNAL_SEARCH"  # internal document search sub-graph
    GENERIC_TOOL = "GENERIC_TOOL"  # custom tool sub-graph
    KNOWLEDGE_GRAPH = "KNOWLEDGE_GRAPH"  # KG search sub-graph
    INTERNET_SEARCH = "INTERNET_SEARCH"  # internet search sub-graph
    CLOSER = "CLOSER"  # composes the final answer
    END = "END"  # sentinel mapped to langgraph END by the routers
|
||||
73
backend/onyx/agents/agent_search/dr/graph_builder.py
Normal file
73
backend/onyx/agents/agent_search/dr/graph_builder.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
|
||||
from onyx.agents.agent_search.dr.conditional_edges import decision_router
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
|
||||
from onyx.agents.agent_search.dr.states import MainInput
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
|
||||
dr_basic_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
|
||||
dr_custom_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_graph_builder import (
|
||||
dr_is_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.
    """

    graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###

    graph.add_node(DRPath.CLARIFIER, clarifier)
    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    # Every sub-agent runs as its own compiled sub-graph.
    sub_agent_builders = [
        (DRPath.INTERNAL_SEARCH, dr_basic_search_graph_builder),
        (DRPath.KNOWLEDGE_GRAPH, dr_kg_search_graph_builder),
        (DRPath.INTERNET_SEARCH, dr_is_graph_builder),
        (DRPath.GENERIC_TOOL, dr_custom_tool_graph_builder),
    ]
    for sub_agent_path, build_sub_graph in sub_agent_builders:
        graph.add_node(sub_agent_path, build_sub_graph().compile())

    graph.add_node(DRPath.CLOSER, closer)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)

    # Both entry nodes route dynamically via the decision router.
    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    # Every sub-agent reports back to the orchestrator.
    for sub_agent_path, _ in sub_agent_builders:
        graph.add_edge(start_key=sub_agent_path, end_key=DRPath.ORCHESTRATOR)

    # The closer either loops back to the orchestrator or ends the run.
    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)

    return graph
|
||||
115
backend/onyx/agents/agent_search/dr/models.py
Normal file
115
backend/onyx/agents/agent_search/dr/models.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
class OrchestratorStep(BaseModel):
    """One scheduled step: a tool (LLM-facing name) and the questions for it."""

    tool: str
    questions: list[str]


class OrchestratorDecisonsNoPlan(BaseModel):
    """Orchestrator decision output: reasoning plus the next step.

    NOTE(review): "Decisons" is a typo; the name is kept as-is since other
    modules may reference this schema by name.
    """

    reasoning: str
    next_step: OrchestratorStep


class OrchestrationPlan(BaseModel):
    """High-level plan produced at the start of a deep-research run."""

    reasoning: str
    plan: str


class ClarificationGenerationResponse(BaseModel):
    """LLM answer to "does the user question need clarification?"."""

    clarification_needed: bool
    # question to ask the user; presumably ignored when clarification_needed
    # is False — confirm at call sites
    clarification_question: str


class QueryEvaluationResponse(BaseModel):
    """LLM verdict on whether a query is permitted to be answered."""

    reasoning: str
    query_permitted: bool


class OrchestrationClarificationInfo(BaseModel):
    """A clarification request and, once received, the user's response."""

    clarification_question: str
    clarification_response: str | None = None


class SearchAnswer(BaseModel):
    """A sub-agent's answer to one search question."""

    reasoning: str
    answer: str
    claims: list[str] | None = None


class TestInfoCompleteResponse(BaseModel):
    """LLM assessment of whether gathered information is complete."""

    reasoning: str
    complete: bool
    gaps: list[str]  # remaining open questions when not complete
|
||||
|
||||
|
||||
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
    """Metadata wrapper the orchestrator uses to reason about one tool."""

    tool_id: int
    name: str
    llm_path: str  # the path for the LLM to refer by
    path: DRPath  # the actual path in the graph
    description: str
    metadata: dict[str, str]
    cost: float  # relative time-budget cost per invocation
    tool_object: Tool | None = None  # None for CLOSER

    class Config:
        # Tool is not a pydantic model, so allow it as a field type
        arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class IterationInstructions(BaseModel):
    """What the orchestrator asked a given iteration to do, and why."""

    iteration_nr: int
    plan: str | None  # overall plan text, when one exists
    reasoning: str
    purpose: str


class IterationAnswer(BaseModel):
    """Result of one tool call within one (possibly parallelized) iteration."""

    tool: str
    tool_id: int
    iteration_nr: int
    parallelization_nr: int  # index within the iteration's parallel fan-out
    question: str
    reasoning: str | None
    answer: str
    # keyed by citation number as used inside `answer`
    cited_documents: dict[int, InferenceSection]
    background_info: str | None = None
    claims: list[str] | None = None
    additional_data: dict[str, str] | None = None


class AggregatedDRContext(BaseModel):
    """Everything accumulated across iterations, ready for the closer."""

    context: str
    cited_documents: list[InferenceSection]
    global_iteration_responses: list[IterationAnswer]
|
||||
|
||||
|
||||
class ResearchType(str, Enum):
    """Time budget options for agent search operations.

    NOTE(review): this duplicates — and conflicts with — ResearchType in
    onyx.agents.agent_search.dr.enums (THOUGHTFUL/DEEP vs fast/shallow/deep).
    Confirm which definition callers import before consolidating.
    """

    FAST = "fast"
    SHALLOW = "shallow"
    DEEP = "deep"


class DRPromptPurpose(str, Enum):
    """Which orchestration prompt template is being requested."""

    PLAN = "PLAN"
    NEXT_STEP = "NEXT_STEP"
    NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
    NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
    CLARIFICATION = "CLARIFICATION"


class BaseSearchProcessingResponse(BaseModel):
    """LLM preprocessing of a search query: filters plus a rewritten query."""

    specified_source_types: list[str]
    rewritten_query: str
    time_filter: str
|
||||
461
backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py
Normal file
461
backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
|
||||
from onyx.agents.agent_search.dr.constants import CLARIFICATION_REQUEST_PREFIX
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.dr_prompt_builder import (
|
||||
get_dr_prompt_orchestration_templates,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import ClarificationGenerationResponse
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.configs.constants import DocumentSourceDescription
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.connector import fetch_unique_document_sources
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _format_tool_name(tool_name: str) -> str:
|
||||
"""Convert tool name to LLM-friendly format."""
|
||||
name = tool_name.replace(" ", "_")
|
||||
# take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
|
||||
name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
|
||||
return name.upper()
|
||||
|
||||
|
||||
def _get_available_tools(
    graph_config: GraphConfig, kg_enabled: bool
) -> dict[str, OrchestratorTool]:
    """Build the orchestrator's tool registry, keyed by LLM-facing tool name.

    Wraps every connected tool in an OrchestratorTool, mapping known tool
    classes onto their dedicated graph paths and everything else onto the
    generic-tool path. Always appends the CLOSER pseudo-tool.

    Raises:
        ValueError: if the KG tool is present without internal search.
    """

    available_tools: dict[str, OrchestratorTool] = {}
    for tool in graph_config.tooling.tools:
        # start with generic-tool defaults; specialized below per tool class
        tool_info = OrchestratorTool(
            tool_id=tool.id,
            name=tool.name,
            llm_path=_format_tool_name(tool.name),
            path=DRPath.GENERIC_TOOL,
            description=tool.description,
            metadata={},
            cost=1.0,
            tool_object=tool,
        )

        # NOTE(review): branch order matters if these classes share a
        # hierarchy (e.g. InternetSearchTool is checked before SearchTool) —
        # confirm before reordering.
        if isinstance(tool, CustomTool):
            # tool_info.metadata["summary_signature"] = CUSTOM_TOOL_RESPONSE_ID
            pass
        elif isinstance(tool, InternetSearchTool):
            # tool_info.metadata["summary_signature"] = (
            #     INTERNET_SEARCH_RESPONSE_SUMMARY_ID
            # )
            tool_info.llm_path = DRPath.INTERNET_SEARCH.value
            tool_info.path = DRPath.INTERNET_SEARCH
        elif isinstance(tool, SearchTool):
            # tool_info.metadata["summary_signature"] = SEARCH_RESPONSE_SUMMARY_ID
            tool_info.llm_path = DRPath.INTERNAL_SEARCH.value
            tool_info.path = DRPath.INTERNAL_SEARCH
        elif isinstance(tool, KnowledgeGraphTool):
            if not kg_enabled:
                logger.warning("KG must be enabled to use KG search tool, skipping")
                continue
            tool_info.llm_path = DRPath.KNOWLEDGE_GRAPH.value
            tool_info.path = DRPath.KNOWLEDGE_GRAPH
        else:
            logger.warning(f"Tool {tool.name} ({type(tool)}) is not supported")
            continue

        # prefer the curated per-path description/cost over the tool's own
        tool_info.description = TOOL_DESCRIPTION.get(tool_info.path, tool.description)
        tool_info.cost = AVERAGE_TOOL_COSTS[tool_info.path]

        # TODO: handle custom tools with same name as other tools (e.g., CLOSER)
        available_tools[tool_info.llm_path] = tool_info

    # make sure KG isn't enabled without internal search
    if (
        DRPath.KNOWLEDGE_GRAPH.value in available_tools
        and DRPath.INTERNAL_SEARCH.value not in available_tools
    ):
        raise ValueError(
            "The Knowledge Graph is not supported without internal search tool"
        )

    # add CLOSER tool, which is always available
    available_tools[DRPath.CLOSER.value] = OrchestratorTool(
        tool_id=-1,  # sentinel: the closer is not a real DB tool
        name="closer",
        llm_path=DRPath.CLOSER.value,
        path=DRPath.CLOSER,
        description=TOOL_DESCRIPTION[DRPath.CLOSER],
        metadata={},
        cost=0.0,
        tool_object=None,
    )

    return available_tools
|
||||
|
||||
|
||||
def _get_existing_clarification_request(
    graph_config: GraphConfig,
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
    """
    Returns the clarification info, original question, and updated chat history if
    a clarification request and response exists, otherwise returns None.
    """
    # check for clarification request and response in message history:
    # the latest message must be an assistant message containing the
    # clarification marker, and there must be a user message before it
    previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
    if (
        len(previous_raw_messages) < 2
        or previous_raw_messages[-1].message_type != MessageType.ASSISTANT
        or CLARIFICATION_REQUEST_PREFIX not in previous_raw_messages[-1].message
    ):
        return None

    # get the clarification request and response
    previous_messages = graph_config.inputs.prompt_builder.message_history
    last_message = previous_raw_messages[-1].message

    clarification = OrchestrationClarificationInfo(
        # everything after the marker is the question that was asked
        clarification_question=last_message.split(CLARIFICATION_REQUEST_PREFIX, 1)[
            1
        ].strip(),
        # the current user query is their reply to that question
        clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
    )
    # fallbacks, used if no earlier human message is found below
    original_question = graph_config.inputs.prompt_builder.raw_user_query
    chat_history_string = "(No chat history yet available)"

    # get the original user query and chat history string before the original query
    # e.g., if history = [user query, assistant clarification request, user clarification response],
    # previous_messages = [user query, assistant clarification request], we want the user query
    for i, message in enumerate(reversed(previous_messages), 1):
        if (
            isinstance(message, HumanMessage)
            and message.content
            and isinstance(message.content, str)
        ):
            original_question = message.content
            # history up to (but excluding) the original question
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history[:-i],
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )
            break

    return clarification, original_question, chat_history_string
|
||||
|
||||
|
||||
# A single catch-all tool definition (OpenAI function-calling schema) offered
# to the LLM so it expresses "I need external info/an action" without choosing
# a concrete tool; the orchestrator does the real tool routing afterwards.
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
    "type": "function",
    "function": {
        "name": "run_any_knowledge_retrieval_and_any_action_tool",
        "description": "Use this tool to get any external information \
that is relevant to the question, or for any action to be taken.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "The request to be made to the tool",
                },
            },
            "required": ["request"],
        },
    },
}
|
||||
|
||||
|
||||
def clarifier(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
    """
    Perform a quick search on the question as is and see whether a set of
    clarification questions is needed. For now this is based on the models.

    Flow as implemented below:
      1. Gather config, available tools, KG entity/relationship types, and the
         active document sources (raises if none are configured).
      2. Ask the LLM whether the question can be answered *without* any external
         tool or knowledge. Two code paths exist: one for LLMs without native
         tool calling (prompt-only decision) and one using a synthetic probe
         tool. If the model answers directly, route straight to END.
      3. Otherwise, for non-THOUGHTFUL research types, either pick up an
         existing clarification exchange from the chat history or ask the LLM
         whether a clarification question should be posed to the user (and
         stream it out if so).
      4. Route to END (awaiting the user's clarification response) or to the
         ORCHESTRATOR node.
    """

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])

    use_tool_calling_llm = graph_config.tooling.using_tool_calling_llm

    original_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    # get the connected tools and format for the Deep Research flow
    kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
    available_tools = _get_available_tools(graph_config, kg_enabled)

    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)

    db_session = graph_config.persistence.db_session
    active_source_types = fetch_unique_document_sources(db_session)

    # Without at least one connected source there is nothing to research over.
    if not active_source_types:
        raise ValueError("No active source types found")

    active_source_types_descriptions = [
        DocumentSourceDescription[source_type] for source_type in active_source_types
    ]

    # Quick evaluation whether the query should be answered or rejected
    # (disabled for now — kept for reference):
    # query_evaluation_prompt = QUERY_EVALUATION_PROMPT.replace(
    #     "---query---", original_question
    # )

    # try:
    #     evaluation_response = invoke_llm_json(
    #         llm=graph_config.tooling.primary_llm,
    #         prompt=query_evaluation_prompt,
    #         schema=QueryEvaluationResponse,
    #         timeout_override=10,
    #         max_tokens=100,
    #     )
    # except Exception as e:
    #     logger.error(f"Error in clarification generation: {e}")
    #     raise e

    # if not evaluation_response.query_permitted:

    #     rejection_prompt = QUERY_REJECTION_PROMPT.replace(
    #         "---query---", original_question
    #     ).replace("---reasoning---", evaluation_response.reasoning)

    #     _ = run_with_timeout(
    #         80,
    #         lambda: stream_llm_answer(
    #             llm=graph_config.tooling.primary_llm,
    #             prompt=rejection_prompt,
    #             event_name="basic_response",
    #             writer=writer,
    #             agent_answer_level=0,
    #             agent_answer_question_num=0,
    #             agent_answer_type="agent_level_answer",
    #             timeout_override=60,
    #             max_tokens=None,
    #         ),
    #     )

    #     logger.info(
    #         f"""Rejected query: {original_question}, \
    # Rejection reason: {evaluation_response.reasoning}"""
    #     )

    #     return OrchestrationUpdate(
    #         original_question=original_question,
    #         chat_history_string="",
    #         tools_used=[DRPath.END.value],
    #         query_list=[],
    #     )

    # get clarification (unless time budget is FAST)

    # Verification of whether the LLM can answer the question without any external tool or knowledge

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    if not use_tool_calling_llm:
        # Prompt-only decision path: the model either streams a direct answer
        # or produces (effectively) no output, meaning research is needed.
        decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
            question=original_question, chat_history_string=chat_history_string
        )

        initial_decision_tokens, _ = run_with_timeout(
            80,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING, decision_prompt
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                max_tokens=None,
            ),
        )

        initial_decision_str = cast(str, merge_content(*initial_decision_tokens))

        # Non-empty output (ignoring spaces) means a direct answer was already
        # streamed to the user above — no research needed, end the flow.
        if len(initial_decision_str.replace(" ", "")) > 0:
            return OrchestrationUpdate(
                original_question=original_question,
                chat_history_string="",
                tools_used=[DRPath.END.value],
                query_list=[],
            )

    else:

        # Tool-calling decision path: offer the synthetic catch-all tool and
        # see whether the model chooses to call it.
        decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
            question=original_question, chat_history_string=chat_history_string
        )

        stream = graph_config.tooling.primary_llm.stream(
            prompt=create_question_prompt(
                EVAL_SYSTEM_PROMPT_W_TOOL_CALLING, decision_prompt
            ),
            tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
            tool_choice=(None),
            structured_response_format=graph_config.inputs.structured_response_format,
        )

        tool_message = process_llm_stream(
            stream,
            True,
            writer,
        ).ai_message_chunk

        # No tool call emitted -> the model answered directly (already streamed
        # by process_llm_stream); end the flow.
        if tool_message is None or len(tool_message.tool_calls) == 0:
            return OrchestrationUpdate(
                original_question=original_question,
                chat_history_string="",
                tools_used=[DRPath.END.value],
                query_list=[],
            )

    # Continue, as external knowledge is required.
    clarification = None
    if research_type != ResearchType.THOUGHTFUL:
        # If the previous assistant turn was a clarification request, the
        # current user message is the response to it — reuse that exchange.
        result = _get_existing_clarification_request(graph_config)
        if result is not None:
            clarification, original_question, chat_history_string = result
        else:
            # generate clarification questions if needed
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history,
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )

            base_clarification_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.CLARIFICATION,
                research_type,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )
            clarification_prompt = base_clarification_prompt.build(
                question=original_question,
                chat_history_string=chat_history_string,
            )

            try:
                clarification_response = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=clarification_prompt,
                    schema=ClarificationGenerationResponse,
                    timeout_override=25,
                    # max_tokens=1500,
                )
            except Exception as e:
                logger.error(f"Error in clarification generation: {e}")
                raise e

            if (
                clarification_response.clarification_needed
                and clarification_response.clarification_question
            ):
                # Record the open clarification (response=None marks it as
                # unanswered) and stream the question to the user.
                clarification = OrchestrationClarificationInfo(
                    clarification_question=clarification_response.clarification_question,
                    clarification_response=None,
                )
                write_custom_event(
                    "basic_response",
                    AgentAnswerPiece(
                        answer_piece=(
                            f"{CLARIFICATION_REQUEST_PREFIX} "
                            f"{clarification.clarification_question}\n\n"
                        ),
                        level=0,
                        level_question_num=0,
                        answer_type="agent_level_answer",
                    ),
                    writer,
                )
    else:
        # THOUGHTFUL runs never ask for clarification; just (re)compute the
        # history string for the downstream nodes.
        chat_history_string = (
            get_chat_history_string(
                graph_config.inputs.prompt_builder.message_history,
                MAX_CHAT_HISTORY_MESSAGES,
            )
            or "(No chat history yet available)"
        )

    # An unanswered clarification question means we must stop and wait for the
    # user's reply; otherwise hand off to the orchestrator.
    if (
        clarification
        and clarification.clarification_question
        and clarification.clarification_response is None
    ):
        next_tool = DRPath.END.value
    else:
        next_tool = DRPath.ORCHESTRATOR.value

    return OrchestrationUpdate(
        original_question=original_question,
        chat_history_string=chat_history_string,
        tools_used=[next_tool],
        query_list=[],
        iteration_nr=0,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="clarifier",
                node_start_time=node_start_time,
            )
        ],
        clarification=clarification,
        available_tools=available_tools,
        active_source_types=active_source_types,
        active_source_types_descriptions="\n".join(active_source_types_descriptions),
    )
|
||||
364
backend/onyx/agents/agent_search/dr/nodes/dr_a1_orchestrator.py
Normal file
364
backend/onyx/agents/agent_search/dr/nodes/dr_a1_orchestrator.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
|
||||
from onyx.agents.agent_search.dr.constants import HIGH_LEVEL_PLAN_PREFIX
|
||||
from onyx.agents.agent_search.dr.dr_prompt_builder import (
|
||||
get_dr_prompt_orchestration_templates,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorDecisonsNoPlan
|
||||
from onyx.agents.agent_search.dr.states import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import create_tool_call_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def orchestrator(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
    """
    LangGraph node to decide the next step in the DR process.

    THOUGHTFUL runs: after the first iteration, a reasoning LLM call decides
    whether enough information has been gathered (early return to CLOSER);
    otherwise a structured decision call picks the next tool and queries.

    DEEP runs: on the first iteration a high-level plan is generated (and
    streamed to the user); each iteration then makes a structured decision call
    against that plan, the answer/question history, and any gaps the closer
    reported. In both modes the chosen tool's cost is deducted from the
    remaining time budget, and when the budget is exhausted the node defaults
    to routing to CLOSER. Finally, a short "purpose" blurb for the upcoming
    iteration is streamed and recorded in the iteration instructions.
    """

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.original_question
    if not question:
        raise ValueError("Question is required for orchestrator")

    plan_of_record = state.plan_of_record
    clarification = state.clarification
    iteration_nr = state.iteration_nr + 1  # this node starts a new iteration
    research_type = graph_config.behavior.research_type
    remaining_time_budget = state.remaining_time_budget
    chat_history_string = state.chat_history_string or "(No chat history yet available)"
    answer_history_string = (
        aggregate_context(state.iteration_responses, include_documents=True).context
        or "(No answer history yet available)"
    )
    available_tools = state.available_tools or {}

    # "tool: question" lines for every sub-question asked so far.
    questions = [
        f"{iteration_response.tool}: {iteration_response.question}"
        for iteration_response in state.iteration_responses
        if len(iteration_response.question) > 0
    ]

    question_history_string = (
        "\n".join(f" - {question}" for question in questions)
        if questions
        else "(No question history yet available)"
    )

    prompt_question = get_prompt_question(question, clarification)

    # Gaps reported by the closer on a previous pass, formatted as a bullet list.
    gaps_str = (
        ("\n - " + "\n - ".join(state.gaps))
        if state.gaps
        else "(No explicit gaps were pointed out so far)"
    )

    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)

    # default to closer (used whenever the budget is spent or no decision is made)
    next_tool = DRPath.CLOSER.value
    query_list = ["Answer the question with the information you have."]
    decision_prompt = None

    reasoning_result = "(No reasoning result provided yet.)"
    tool_calls_string = "(No tool calls provided yet.)"

    if research_type == ResearchType.THOUGHTFUL:
        if iteration_nr == 1:
            # First iteration: initialize the time budget for THOUGHTFUL runs.
            remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]

        elif iteration_nr > 1:
            # for each iteration past the first one, we need to see whether we
            # have enough information to answer the question.
            # if we do, we can stop the iteration and return the answer.
            # if we do not, we need to continue the iteration.

            base_reasoning_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.NEXT_STEP_REASONING,
                ResearchType.THOUGHTFUL,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )

            write_custom_event(
                "basic_response",
                AgentAnswerPiece(
                    answer_piece="\n\n\nREASONING TO STOP/CONTINUE:\n\n\n",
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )

            reasoning_prompt = base_reasoning_prompt.build(
                question=question,
                chat_history_string=chat_history_string,
                answer_history_string=answer_history_string,
                iteration_nr=str(iteration_nr),
                remaining_time_budget=str(remaining_time_budget),
            )

            # NOTE(review): this pre-seed is dead code — it is overwritten by
            # the run_with_timeout call immediately below.
            reasoning_tokens: list[str] = [""]

            reasoning_tokens, _ = run_with_timeout(
                80,
                lambda: stream_llm_answer(
                    llm=graph_config.tooling.primary_llm,
                    prompt=reasoning_prompt,
                    event_name="basic_response",
                    writer=writer,
                    agent_answer_level=0,
                    agent_answer_question_num=0,
                    agent_answer_type="agent_level_answer",
                    timeout_override=60,
                    # max_tokens=None,
                ),
            )
            reasoning_result = cast(str, merge_content(*reasoning_tokens))

            # Sentinel substring in the reasoning output means "enough info":
            # route straight to the closer without a decision call.
            if SUFFICIENT_INFORMATION_STRING in reasoning_result:
                return OrchestrationUpdate(
                    tools_used=[DRPath.CLOSER.value],
                    query_list=[],
                    iteration_nr=iteration_nr,
                    log_messages=[
                        get_langgraph_node_log_string(
                            graph_component="main",
                            node_name="orchestrator",
                            node_start_time=node_start_time,
                        )
                    ],
                    clarification=clarification,
                    plan_of_record=plan_of_record,
                    remaining_time_budget=remaining_time_budget,
                )

        base_decision_prompt = get_dr_prompt_orchestration_templates(
            DRPromptPurpose.NEXT_STEP,
            ResearchType.THOUGHTFUL,
            entity_types_string=all_entity_types,
            relationship_types_string=all_relationship_types,
            available_tools=available_tools,
        )
        decision_prompt = base_decision_prompt.build(
            question=question,
            chat_history_string=chat_history_string,
            answer_history_string=answer_history_string,
            iteration_nr=str(iteration_nr),
            remaining_time_budget=str(remaining_time_budget),
            reasoning_result=reasoning_result,
        )

        if remaining_time_budget > 0:
            try:
                orchestrator_action = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=decision_prompt,
                    schema=OrchestratorDecisonsNoPlan,
                    timeout_override=35,
                    # max_tokens=2500,
                )
                next_step = orchestrator_action.next_step
                next_tool = next_step.tool
                query_list = [q for q in (next_step.questions or [])]

                tool_calls_string = create_tool_call_string(next_tool, query_list)

            except Exception as e:
                logger.error(f"Error in approach extraction: {e}")
                raise e

            # NOTE(review): raises KeyError if the LLM returns a tool name not
            # present in available_tools — consider validating the decision.
            remaining_time_budget -= available_tools[next_tool].cost
    else:
        # DEEP research path.
        if iteration_nr == 1 and not plan_of_record:
            # by default, we start a new iteration, but if there is a feedback request,
            # we start a new iteration 0 again (set a bit later)

            remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]

            base_plan_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.PLAN,
                ResearchType.DEEP,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )
            plan_generation_prompt = base_plan_prompt.build(
                question=prompt_question,
                chat_history_string=chat_history_string,
            )

            try:
                plan_of_record = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=plan_generation_prompt,
                    schema=OrchestrationPlan,
                    timeout_override=25,
                    # max_tokens=3000,
                )
            except Exception as e:
                logger.error(f"Error in plan generation: {e}")
                raise

            # Stream the high-level plan to the user.
            write_custom_event(
                "basic_response",
                AgentAnswerPiece(
                    answer_piece=f"{HIGH_LEVEL_PLAN_PREFIX} {plan_of_record.plan}\n\n",
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )

        if not plan_of_record:
            raise ValueError(
                "Plan information is required for iterative decision making"
            )

        base_decision_prompt = get_dr_prompt_orchestration_templates(
            DRPromptPurpose.NEXT_STEP,
            ResearchType.DEEP,
            entity_types_string=all_entity_types,
            relationship_types_string=all_relationship_types,
            available_tools=available_tools,
        )
        decision_prompt = base_decision_prompt.build(
            answer_history_string=answer_history_string,
            question_history_string=question_history_string,
            question=prompt_question,
            iteration_nr=str(iteration_nr),
            current_plan_of_record_string=plan_of_record.plan,
            chat_history_string=chat_history_string,
            remaining_time_budget=str(remaining_time_budget),
            gaps=gaps_str,
        )

        if remaining_time_budget > 0:
            try:
                orchestrator_action = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=decision_prompt,
                    schema=OrchestratorDecisonsNoPlan,
                    timeout_override=15,
                    # max_tokens=1500,
                )
                next_step = orchestrator_action.next_step
                next_tool = next_step.tool
                query_list = [q for q in (next_step.questions or [])]
                reasoning_result = orchestrator_action.reasoning

                tool_calls_string = create_tool_call_string(next_tool, query_list)
            except Exception as e:
                logger.error(f"Error in approach extraction: {e}")
                raise e

            # NOTE(review): same potential KeyError as in the THOUGHTFUL branch.
            remaining_time_budget -= available_tools[next_tool].cost
        else:
            # Budget exhausted: keep the default CLOSER routing.
            reasoning_result = "Time to wrap up."

    base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
        DRPromptPurpose.NEXT_STEP_PURPOSE,
        ResearchType.DEEP,
        entity_types_string=all_entity_types,
        relationship_types_string=all_relationship_types,
        available_tools=available_tools,
    )
    orchestration_next_step_purpose_prompt = base_next_step_purpose_prompt.build(
        question=prompt_question,
        reasoning_result=reasoning_result,
        tool_calls=tool_calls_string,
    )

    purpose_tokens: list[str] = [""]

    # Write short purpose
    write_custom_event(
        "basic_response",
        AgentAnswerPiece(
            answer_piece=f"\n\n\nITERATION {iteration_nr}:\n\n\n",
            level=0,
            level_question_num=0,
            answer_type="agent_level_answer",
        ),
        writer,
    )

    try:
        purpose_tokens, _ = run_with_timeout(
            80,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=orchestration_next_step_purpose_prompt,
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                # max_tokens=None,
            ),
        )
    except Exception as e:
        logger.error(f"Error in orchestration next step purpose: {e}")
        raise e

    purpose = cast(str, merge_content(*purpose_tokens))

    return OrchestrationUpdate(
        tools_used=[next_tool],
        query_list=query_list or [],
        iteration_nr=iteration_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="orchestrator",
                node_start_time=node_start_time,
            )
        ],
        clarification=clarification,
        plan_of_record=plan_of_record,
        remaining_time_budget=remaining_time_budget,
        iteration_instructions=[
            IterationInstructions(
                iteration_nr=iteration_nr,
                plan=plan_of_record.plan if plan_of_record else None,
                reasoning=reasoning_result,
                purpose=purpose,
            )
        ],
    )
|
||||
270
backend/onyx/agents/agent_search/dr/nodes/dr_a2_closer.py
Normal file
270
backend/onyx/agents/agent_search/dr/nodes/dr_a2_closer.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
|
||||
from onyx.agents.agent_search.dr.states import FinalUpdate
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_citation_format_list
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def save_iteration(
    state: MainState, graph_config: GraphConfig, aggregated_context: AggregatedDRContext
) -> None:
    """Persist the research run to the database.

    Updates the originating ChatMessage with the research type and the parsed
    plan, then records one ResearchAgentIteration row per iteration instruction
    and one ResearchAgentIterationSubStep row per (parallelized) sub-answer,
    committing everything in a single transaction.

    Args:
        state: final graph state holding the plan, iteration instructions and
            responses.
        graph_config: carries the DB session, message id, and research type.
        aggregated_context: aggregated per-iteration answers/citations.

    Raises:
        ValueError: if the chat message referenced by the config does not exist
            (should never happen in practice).
    """
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type
    db_session = graph_config.persistence.db_session

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # update chat message and save iterations
    chat_message = (
        db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
    )
    if not chat_message:
        raise ValueError("Chat message with id not found")  # should never happen

    chat_message.research_type = research_type
    chat_message.research_plan = plan_of_record_dict

    # One row per iteration-level instruction (reasoning/purpose).
    # NOTE(review): datetime.now() is naive — confirm the DB expects local time.
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
            created_at=datetime.now(),
        )
        db_session.add(research_agent_iteration_step)

    # One row per sub-step answer across all iterations.
    for iteration_answer in aggregated_context.global_iteration_responses:
        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            parent_question_id=None,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            # list(...) instead of the redundant identity comprehension
            # [doc for doc in ...values()] used previously.
            cited_doc_results=create_citation_format_list(
                list(iteration_answer.cited_documents.values())
            ),
            additional_data=iteration_answer.additional_data,
            created_at=datetime.now(),
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
|
||||
|
||||
|
||||
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    For DEEP research runs (while the retry budget allows), first asks the LLM
    whether the gathered information suffices; if not, routes back to the
    ORCHESTRATOR with the identified gaps. Otherwise streams the cited
    documents and the final answer to the client, persists the run via
    save_iteration, and returns a FinalUpdate.

    Raises:
        ValueError: if the state carries no original question, or if final
            answer generation fails.
    """

    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    research_type = graph_config.behavior.research_type

    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )

    iteration_responses_string = aggregated_context.context
    all_cited_documents = aggregated_context.cited_documents

    num_closer_suggestions = state.num_closer_suggestions

    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )

        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=test_info_complete_prompt,
            schema=TestInfoCompleteResponse,
            timeout_override=40,
            # max_tokens=1000,
        )

        # Information is not yet complete: hand the gaps back to the
        # orchestrator for another pass. (Replaces the previous
        # `if complete: pass / else: ...` anti-idiom with a guard clause.)
        if not test_info_complete_json.complete:
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
            )

    # Stream out docs - TODO: Improve this with new frontend
    write_custom_event(
        "tool_response",
        ExtendedToolResponse(
            id=SEARCH_RESPONSE_SUMMARY_ID,
            response=SearchResponseSummary(
                rephrased_query=base_question,
                top_sections=all_cited_documents,
                predicted_flow=QueryFlow.QUESTION_ANSWER,
                predicted_search=SearchType.KEYWORD,  # unused
                final_filters=IndexFilters(access_control_list=None),  # unused
                recency_bias_multiplier=1.0,  # unused
            ),
            level=0,
            level_question_num=0,  # 0, 0 is the base question
        ),
        writer,
    )

    # Change status from "searching for" to "searched for"
    write_custom_event(
        "tool_response",
        ToolCallFinalResult(
            tool_name=SearchTool._NAME,
            tool_args={"query": base_question},
            tool_result=[],  # unused
        ),
        writer,
    )

    # Generate final answer
    write_custom_event(
        "basic_response",
        AgentAnswerPiece(
            answer_piece="\n\n\nFINAL ANSWER:\n\n\n",
            level=0,
            level_question_num=0,
            answer_type="agent_level_answer",
        ),
        writer,
    )

    final_answer_prompt = FINAL_ANSWER_PROMPT.build(
        base_question=prompt_question,
        iteration_responses_string=iteration_responses_string,
        chat_history_string=chat_history_string,
    )

    try:
        # TODO: fix citations for non-document returning tools (right now, it cites iteration nr)
        streamed_output = run_with_timeout(
            240,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=final_answer_prompt,
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                # max_tokens=None,
            ),
        )

        final_answer = "".join(streamed_output[0])
    except Exception as e:
        # Fixed: the message previously named "consolidate_research" (a
        # different node, apparent copy-paste); also chain the cause so the
        # original traceback is preserved.
        raise ValueError(f"Error in closer: {e}") from e

    dispatch_main_answer_stop_info(level=0, writer=writer)

    # Log the research agent steps
    save_iteration(state, graph_config, aggregated_context)

    return FinalUpdate(
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="closer",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
71
backend/onyx/agents/agent_search/dr/states.py
Normal file
71
backend/onyx/agents/agent_search/dr/states.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
    """Base state update carrying accumulated per-node log messages.

    `log_messages` is annotated with the `add` reducer so that parallel
    LangGraph branches append entries instead of overwriting each other.
    """

    log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class OrchestrationUpdate(LoggerUpdate):
    """Orchestrator-produced state: routing, plan, and budget for a DR iteration.

    Fields annotated with the `add` reducer accumulate across iterations /
    parallel branches rather than being replaced.
    """

    original_question: str | None = None
    chat_history_string: str | None = None
    # tool names chosen so far; sub-agents read the most recent entry
    # (tools_used[-1]) as the tool to execute this iteration
    tools_used: Annotated[list[str], add] = []
    query_list: list[str] = []
    iteration_nr: int = 0
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    clarification: OrchestrationClarificationInfo | None = None
    # registry of tools the orchestrator may dispatch to, keyed by tool name
    available_tools: dict[str, OrchestratorTool] | None = None
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    active_source_types: list[DocumentSource] | None = None
    active_source_types_descriptions: str | None = None
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
|
||||
|
||||
|
||||
class AnswerUpdate(LoggerUpdate):
    """State update emitted by sub-agents with their per-iteration answers."""

    iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class FinalUpdate(LoggerUpdate):
    """Terminal state update holding the consolidated answer and its citations."""

    final_answer: str | None = None
    all_cited_documents: list[InferenceSection] = []
|
||||
|
||||
|
||||
## Graph Input State
class MainInput(CoreState):
    """Input state of the main DR graph; adds no fields beyond CoreState."""

    pass
|
||||
|
||||
|
||||
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    OrchestrationUpdate,
    AnswerUpdate,
    FinalUpdate,
):
    """Full working state of the main DR graph: the union of the input state
    and every update type the graph's nodes can emit."""

    pass
|
||||
|
||||
|
||||
## Graph Output State
class MainOutput(TypedDict):
    """Output payload returned by the main DR graph."""

    log_messages: list[str]
    final_answer: str | None
    all_cited_documents: list[InferenceSection]
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node entered before the basic-search fan-out.

    Performs no work beyond emitting a timing/log entry for the branch step.
    """

    started_at = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")

    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,254 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.states import AnswerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> AnswerUpdate:
    """
    LangGraph node to perform a standard (internal) search as part of the DR process.

    Flow: rewrite the branch query and extract source-type / time filters via
    an LLM, run the configured SearchTool, then — for DEEP research only —
    answer the branch question from the retrieved documents with a second LLM
    call and extract citations. For non-DEEP research the raw top documents
    are returned without an LLM answer.

    Raises:
        ValueError: if the branch question or available tools are unset, if the
            resolved tool differs from the configured search tool, or if a
            retrieved document has an unexpected type.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # the tool selected for this iteration is the last entry in tools_used
    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)

    # sanity check
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")

    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )

    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
    )

    try:
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=base_search_processing_prompt,
            schema=BaseSearchProcessingResponse,
            timeout_override=5,
            # max_tokens=100,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        raise e

    rewritten_query = search_processing.rewritten_query

    implied_start_date = search_processing.time_filter

    # Validate time_filter format if it exists; malformed LLM output is
    # silently dropped (no time filter applied).
    implied_time_filter = None
    if implied_start_date:

        # Check if time_filter is in YYYY-MM-DD format
        date_pattern = r"^\d{4}-\d{2}-\d{2}$"
        if re.match(date_pattern, implied_start_date):
            # NOTE(review): produces a naive (timezone-unaware) datetime —
            # assumes downstream filtering expects naive timestamps; confirm.
            implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")

    specified_source_types: list[DocumentSource] | None = [
        DocumentSource(source_type)
        for source_type in search_processing.specified_source_types
    ]

    # an empty list means "no restriction" for the search tool
    # (the is-not-None guard is redundant here; kept as written)
    if specified_source_types is not None and len(specified_source_types) == 0:
        specified_source_types = None

    # write_custom_event(
    #     "basic_response",
    #     AgentAnswerPiece(
    #         answer_piece=(
    #             f"SUB-QUESTION {iteration_nr}.{parallelization_nr} "
    #             f"(SEARCH): {branch_query}\n\n"
    #             f"REWRITTEN QUERY: {rewritten_query}\n\n"
    #             f"PREDICTED SOURCE TYPES: {specified_source_types}\n\n"
    #             f"PREDICTED TIME FILTER: {implied_time_filter}\n\n"
    #             " --- \n\n"
    #         ),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    logger.debug(
        f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    retrieved_docs: list[InferenceSection] = []
    callback_container: list[list[InferenceSection]] = []

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=rewritten_query,
            document_sources=specified_source_types,
            time_filter=implied_time_filter,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
            ),
        ):
            # get retrieved docs to send to the rest of the graph; stop at the
            # first summary response, ignoring any later tool responses
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections

                break

    document_texts_list = []

    # only the top 15 sections are passed to the answer LLM
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    logger.debug(
        f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Built prompt

    if research_type == ResearchType.DEEP:
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=branch_query,
            base_question=base_question,
            document_text=document_texts,
        )

        # Run LLM

        # search_answer_json = None
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=search_prompt,
            schema=SearchAnswer,
            timeout_override=40,
            # max_tokens=1500,
        )

        logger.debug(
            f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )

        # write_custom_event(
        #     "basic_response",
        #     AgentAnswerPiece(
        #         answer_piece=f"ANSWERED {iteration_nr}.{parallelization_nr}\n\n",
        #         level=0,
        #         level_question_num=0,
        #         answer_type="agent_level_answer",
        #     ),
        #     writer,
        # )

        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning
        # answer_string = ""
        # claims = []

        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        # citation numbers are 1-based indexes into retrieved_docs
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
        }

    else:
        # non-DEEP research: no LLM answer, just return the top documents
        answer_string = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
        reasoning = ""

    return AnswerUpdate(
        iteration_responses=[
            IterationAnswer(
                tool=search_tool_info.llm_path,
                tool_id=search_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the parallel basic-search branch results,
    forwarding only the responses produced during the current iteration.
    """

    started_at = datetime.now()
    current_iteration = state.iteration_nr

    # keep only the branch responses that belong to this iteration
    relevant_responses = [
        response
        for response in state.branch_iteration_responses
        if response.iteration_nr == current_iteration
    ]

    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="consolidation",
        node_start_time=started_at,
    )
    return SubAgentUpdate(
        iteration_responses=relevant_responses,
        log_messages=[log_entry],
    )
|
||||
@@ -0,0 +1,27 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out each query (capped at MAX_DR_PARALLEL_SEARCH) to a parallel
    'act' node run, carrying over the shared orchestration fields."""
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_state = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_state))
    return dispatches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
|
||||
basic_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
|
||||
basic_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_basic_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Basic (internal) Search Sub-Agent.

    Topology: START -> branch -> (fan-out via branching_router) act -> reducer -> END.
    """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###

    graph.add_node("branch", basic_search_branch)

    graph.add_node("act", basic_search)

    graph.add_node("reducer", is_reducer)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")

    # branching_router returns one Send per query, spawning parallel "act" runs
    graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)

    return graph
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node entered before the custom-tool fan-out.

    Performs no work beyond emitting a timing/log entry for the branch step.
    """

    started_at = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")

    log_entry = get_langgraph_node_log_string(
        graph_component="custom_tool",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.states import AnswerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> AnswerUpdate:
    """
    LangGraph node to perform a generic (custom) tool call as part of the DR process.

    Obtains tool-call arguments from the LLM — preferring the native
    tool-calling path and falling back to the non-tool-calling extraction —
    runs the tool, has the LLM summarize the tool result into an answer, and
    returns it as this iteration's response.

    Raises:
        ValueError: if required state is missing, tool arguments cannot be
            obtained, or the tool yields no response summary.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # the tool selected for this iteration is the last entry in tools_used
    custom_tool_info = state.available_tools[state.tools_used[-1]]
    custom_tool_name = custom_tool_info.llm_path
    custom_tool = cast(CustomTool, custom_tool_info.tool_object)

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_USE_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_response="(No tool response yet. You need to call the tool to answer the question.)",
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke(
            tool_use_prompt,
            tools=[custom_tool.tool_definition()],
            tool_choice="required",
            timeout_override=40,
        )

        # make sure we got a tool call; exactly one is expected
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            # fall through to the non-tool-calling extraction below
            logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = custom_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # run the tool; stop at the first response summary
    response_summary: CustomToolCallSummary | None = None
    for tool_response in custom_tool.run(**tool_args):
        if tool_response.id == CUSTOM_TOOL_RESPONSE_ID:
            response_summary = cast(CustomToolCallSummary, tool_response.response)
            break

    if not response_summary:
        raise ValueError("Custom tool did not return a valid response summary")

    # summarise tool result; render it as text appropriate to its type
    if response_summary.response_type == "json":
        tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    elif response_summary.response_type in {"image", "csv"}:
        # file-producing tools: only the file ids are surfaced to the LLM
        tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
    else:
        tool_result_str = str(response_summary.tool_result)

    tool_str = (
        f"Tool used: {custom_tool_name}\n"
        f"Description: {custom_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    # second LLM pass: turn the raw tool result into an answer to the query
    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke(
            tool_summary_prompt, timeout_override=40
        ).content
    ).strip()

    logger.debug(
        f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return AnswerUpdate(
        iteration_responses=[
            IterationAnswer(
                tool=custom_tool_name,
                tool_id=custom_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                # custom tools produce no document claims or citations
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the custom-tool branch results,
    forwarding only the responses produced during the current iteration.
    """

    started_at = datetime.now()

    # keep only the branch responses that belong to this iteration
    responses_this_iteration = []
    for branch_update in state.branch_iteration_responses:
        if branch_update.iteration_nr == state.iteration_nr:
            responses_this_iteration.append(branch_update)

    log_entry = get_langgraph_node_log_string(
        graph_component="custom_tool",
        node_name="consolidation",
        node_start_time=started_at,
    )
    return SubAgentUpdate(
        iteration_responses=responses_this_iteration,
        log_messages=[log_entry],
    )
|
||||
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Dispatch queries to the 'act' node; capped at one (no parallel calls
    for now), carrying over the shared orchestration fields."""
    dispatches: list[Send | Hashable] = []
    # no parallel call for now — only the first query is dispatched
    for parallelization_nr, query in enumerate(state.query_list[:1]):
        branch_state = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_state))
    return dispatches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
|
||||
custom_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
|
||||
custom_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
|
||||
custom_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_custom_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Generic (custom) Tool Sub-Agent.

    Topology: START -> branch -> (fan-out via branching_router) act -> reducer -> END.
    """

    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # nodes
    builder.add_node("branch", custom_tool_branch)
    builder.add_node("act", custom_tool_act)
    builder.add_node("reducer", custom_tool_reducer)

    # edges
    builder.add_edge(start_key=START, end_key="branch")
    # branching_router returns one Send per dispatched query
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge(start_key="act", end_key="reducer")
    builder.add_edge(start_key="reducer", end_key=END)

    return builder
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node entered before the internet-search fan-out.

    Performs no work beyond emitting a timing/log entry for the branch step.
    """

    started_at = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Internet Search {iteration_nr} at {datetime.now()}")

    log_entry = get_langgraph_node_log_string(
        graph_component="internet_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,194 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def internet_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform an internet search as part of the DR process.

    Runs the branch question against the configured internet search provider.
    For DEEP research the retrieved documents are additionally summarized by
    the LLM into an answer with claims and citations; otherwise the documents
    themselves are returned as the cited context.

    Raises:
        ValueError: if the branch question, persona, available tools, or the
            search provider are missing.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    logger.debug(
        f"Search start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # the most recently chosen tool is the internet search tool for this branch
    is_tool_info = state.available_tools[state.tools_used[-1]]
    internet_search_tool = cast(InternetSearchTool, is_tool_info.tool_object)

    if internet_search_tool.provider is None:
        raise ValueError(
            "internet_search_tool.provider is not set. This should not happen."
        )

    # Update search parameters
    internet_search_tool.max_chunks = 10
    internet_search_tool.provider.num_results = 10

    retrieved_docs: list[InferenceSection] = []

    for tool_response in internet_search_tool.run(internet_search_query=search_query):
        # get retrieved docs to send to the rest of the graph
        if tool_response.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
            response = cast(SearchResponseSummary, tool_response.response)
            retrieved_docs = response.top_sections
            break

    document_texts_list = []

    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    logger.debug(
        f"Search end/LLM start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if research_type == ResearchType.DEEP:
        # build prompt and summarize the retrieved documents with the LLM
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=search_query,
            base_question=base_question,
            document_text=document_texts,
        )

        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=search_prompt,
            schema=SearchAnswer,
            timeout_override=40,
        )

        logger.debug(
            f"LLM/all done for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )

        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""

        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        # guard against hallucinated citation numbers outside the range of
        # retrieved docs (mirrors the KG search sub-agent) — previously an
        # out-of-range citation raised an IndexError here
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
            if 0 < citation_number <= len(retrieved_docs)
        }

    else:
        # non-DEEP research: skip the LLM step and cite the documents directly
        answer_string = ""
        claims = []
        reasoning = ""
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the internet search branch results of
    the current iteration into the sub-agent's iteration responses.
    """

    node_start_time = datetime.now()
    current_iteration = state.iteration_nr

    # keep only the branch answers produced during this iteration
    relevant_responses = [
        response
        for response in state.branch_iteration_responses
        if response.iteration_nr == current_iteration
    ]

    return SubAgentUpdate(
        iteration_responses=relevant_responses,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,26 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """
    Fan out one "act" branch per query, capped at MAX_DR_PARALLEL_SEARCH
    parallel searches.
    """
    branches: list[Send | Hashable] = []
    for branch_nr, branch_query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branches.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=branch_nr,
                    branch_question=branch_query,
                    context="",
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                ),
            )
        )
    return branches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_1_branch import (
|
||||
is_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_2_act import (
|
||||
internet_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_is_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Internet Search Sub-Agent

    Wires the branch fan-out node to parallel "act" search nodes and a final
    consolidation reducer.
    """
    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    builder.add_node("branch", is_branch)
    builder.add_node("act", internet_search)
    builder.add_node("reducer", is_reducer)

    ### Add edges ###
    builder.add_edge(START, "branch")
    # the router emits one Send per query, so "act" runs in parallel branches
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("act", "reducer")
    builder.add_edge("reducer", END)

    return builder
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node that starts the KG search branching step of the DR process.
    Logs the start of the iteration and records the node timing.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,96 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.

    Compiles and invokes the knowledge-graph sub-graph for the branch
    question, then maps any citations in the answer onto the retrieved
    documents.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    kg_tool_info = state.available_tools[state.tools_used[-1]]

    # run the knowledge-graph sub-graph for this branch question
    kb_results = kb_graph_builder().compile().invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )

    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])

    citation_numbers, answer_string, claims = extract_document_citations(
        answer_string, claims
    )

    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = list(range(1, len(retrieved_docs) + 1))

    cited_documents = {
        citation_nr: retrieved_docs[citation_nr - 1]
        for citation_nr in citation_numbers
        if citation_nr <= len(retrieved_docs)
    }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the KG search branch results of the
    current iteration into the sub-agent's iteration responses.
    """

    node_start_time = datetime.now()
    current_iteration = state.iteration_nr

    # keep only the branch answers produced during this iteration
    relevant_responses = [
        response
        for response in state.branch_iteration_responses
        if response.iteration_nr == current_iteration
    ]

    return SubAgentUpdate(
        iteration_responses=relevant_responses,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,25 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """
    Fan out "act" branches for the KG search queries.

    Only the first query is dispatched for now (no parallel KG search yet).
    """
    branches: list[Send | Hashable] = []
    # no parallel search for now — only the first query is used
    for branch_nr, branch_query in enumerate(state.query_list[:1]):
        branches.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=branch_nr,
                    branch_question=branch_query,
                    context="",
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                ),
            )
        )
    return branches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
|
||||
kg_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
|
||||
kg_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
|
||||
kg_search_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_kg_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for KG Search Sub-Agent

    Wires the branch fan-out node to the "act" KG search node and a final
    consolidation reducer.
    """
    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    builder.add_node("branch", kg_search_branch)
    builder.add_node("act", kg_search)
    builder.add_node("reducer", kg_search_reducer)

    ### Add edges ###
    builder.add_edge(START, "branch")
    # the router emits one Send per dispatched query
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("act", "reducer")
    builder.add_edge("reducer", END)

    return builder
|
||||
42
backend/onyx/agents/agent_search/dr/sub_agents/states.py
Normal file
42
backend/onyx/agents/agent_search/dr/sub_agents/states.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
|
||||
class SubAgentUpdate(LoggerUpdate):
    """State update emitted by a sub-agent's reducer node."""

    # answers produced this iteration; merged across updates via the add reducer
    iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class BranchUpdate(LoggerUpdate):
    """State update emitted by a single parallel search branch ("act" node)."""

    # per-branch answers; merged across parallel branches via the add reducer
    branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class SubAgentInput(LoggerUpdate):
    """Input state handed to a sub-agent graph by the DR orchestrator."""

    # current DR iteration number
    iteration_nr: int = 0
    # queries to fan out across parallel branches
    query_list: list[str] = []
    # optional context string passed down to the branches
    context: str | None = None
    # sources currently active for this run, if restricted
    active_source_types: list[DocumentSource] | None = None
    # names of tools invoked so far; merged via the add reducer
    # (the last entry is the tool selected for the current iteration)
    tools_used: Annotated[list[str], add] = []
    # tool name -> tool info, as prepared by the orchestrator
    available_tools: dict[str, OrchestratorTool] | None = None
|
||||
|
||||
|
||||
class SubAgentMainState(
    # This includes the core state
    SubAgentInput,
    SubAgentUpdate,
    BranchUpdate,
):
    """Combined sub-agent graph state: inputs plus all update channels."""

    pass
|
||||
|
||||
|
||||
class BranchInput(SubAgentInput):
    """Input state for a single parallel branch ("act" node)."""

    # index of this branch within the iteration's parallel fan-out
    parallelization_nr: int = 0
    # the sub-question this branch is answering
    branch_question: str | None = None
|
||||
|
||||
|
||||
class CustomToolBranchInput(LoggerUpdate):
    """Input state for a custom-tool branch."""

    # the tool to invoke for this branch
    tool_info: OrchestratorTool
|
||||
229
backend/onyx/agents/agent_search/dr/utils.py
Normal file
229
backend/onyx/agents/agent_search/dr/utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import re
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
CITATION_PREFIX = "CITE:"


def extract_document_citations(
    answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
    """
    Finds all citations of the form [1], [2, 3], etc. and returns the sorted
    list of cited indices, as well as the answer and claims with the citations
    replaced with [<CITATION_PREFIX>1], etc., to help with citation
    deduplication later on.
    """
    citations: set[int] = set()

    # Pattern to match both single citations [1] and multiple citations [1, 2, 3]:
    # one number, optionally followed by comma-separated numbers, in brackets.
    pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")

    def _extract_and_replace(match: re.Match[str]) -> str:
        numbers = [int(num) for num in match.group(1).split(",")]
        citations.update(numbers)
        # split grouped citations into one bracketed token per number
        return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)

    new_answer = pattern.sub(_extract_and_replace, answer)
    new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]

    # sorted for deterministic output (set iteration order is not guaranteed)
    return sorted(citations), new_answer, new_claims
|
||||
|
||||
|
||||
def aggregate_context(
    iteration_responses: list[IterationAnswer], include_documents: bool = False
) -> AggregatedDRContext:
    """
    Converts the iteration response into a single string with unified citations.

    Each iteration's local citation tokens ([CITE:n]) are rewritten to a
    single global numbering over the deduplicated set of cited documents.

    For example,
    it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
    it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
    Output:
    it 1: the answer is x [1][2].
    it 2: blah blah [2][3]
    [1]: doc_xyz
    [2]: doc_abc
    [3]: doc_pqr
    """
    # dedupe and merge inference section contents
    unrolled_inference_sections: list[InferenceSection] = []
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        for cited_doc in iteration_response.cited_documents.values():
            unrolled_inference_sections.append(cited_doc)
            cited_doc.center_chunk.score = None  # None means maintain order

    # global numbering: document_id -> 1-based citation number, in first-seen order
    global_documents = dedup_inference_section_list(unrolled_inference_sections)
    global_citations = {
        doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
    }

    # build output string
    output_strings: list[str] = []
    global_iteration_responses: list[IterationAnswer] = []

    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        # add basic iteration info
        output_strings.append(
            f"Iteration: {iteration_response.iteration_nr}, "
            f"Question {iteration_response.parallelization_nr}"
        )
        output_strings.append(f"Tool: {iteration_response.tool}")
        output_strings.append(f"Question: {iteration_response.question}")

        # get answer and claims with global citations
        answer_str = iteration_response.answer
        claims = iteration_response.claims or []

        iteration_citations: list[int] = []
        for local_number, cited_doc in iteration_response.cited_documents.items():
            global_number = global_citations[cited_doc.center_chunk.document_id]
            # translate local citations to global citations
            answer_str = answer_str.replace(
                f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
            )
            claims = [
                claim.replace(
                    f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
                )
                for claim in claims
            ]
            iteration_citations.append(global_number)

        # add answer, claims, and citation info
        if answer_str:
            output_strings.append(f"Answer: {answer_str}")
        if claims:
            # NOTE(review): due to precedence this parses as
            # ("Claims: " + joined) or "No claims provided"; the left side is
            # always truthy, so the fallback string can never be produced —
            # confirm whether the fallback was intended.
            output_strings.append(
                "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
                or "No claims provided"
            )
        if not answer_str and not claims:
            output_strings.append(
                "Retrieved documents: "
                + (
                    "".join(
                        f"[{global_number}]"
                        for global_number in sorted(iteration_citations)
                    )
                    or "No documents retrieved"
                )
            )
        output_strings.append("\n---\n")

        # save global iteration response
        global_iteration_responses.append(
            IterationAnswer(
                tool=iteration_response.tool,
                tool_id=iteration_response.tool_id,
                iteration_nr=iteration_response.iteration_nr,
                parallelization_nr=iteration_response.parallelization_nr,
                question=iteration_response.question,
                reasoning=iteration_response.reasoning,
                answer=answer_str,
                cited_documents={
                    global_citations[doc.center_chunk.document_id]: doc
                    for doc in iteration_response.cited_documents.values()
                },
                background_info=iteration_response.background_info,
                claims=claims,
                additional_data=iteration_response.additional_data,
            )
        )

    # add document contents if requested
    if include_documents:
        if global_documents:
            output_strings.append("Cited document contents:")
            for doc in global_documents:
                output_strings.append(
                    build_document_context(
                        doc, global_citations[doc.center_chunk.document_id]
                    )
                )
            output_strings.append("\n---\n")

    return AggregatedDRContext(
        context="\n".join(output_strings),
        cited_documents=global_documents,
        global_iteration_responses=global_iteration_responses,
    )
|
||||
|
||||
|
||||
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
    """
    Get the chat history (up to max_messages USER, ASSISTANT message pairs)
    as a string.

    If older messages were truncated, the result is prefixed with "..." on
    its own line, followed by the rendered messages.
    """
    # get past max_messages USER, ASSISTANT message pairs
    past_messages = chat_history[-max_messages * 2 :]

    # Bug fix: the previous conditional expression returned ONLY "...\n"
    # when the history was truncated (the messages were only rendered in the
    # else-branch, via implicit string-literal concatenation); the marker is
    # now prepended to the rendered messages instead.
    truncation_marker = "...\n" if len(chat_history) > len(past_messages) else ""

    rendered_messages = "\n".join(
        ("user" if isinstance(msg, HumanMessage) else "you")
        + f": {str(msg.content).strip()}"
        for msg in past_messages
    )

    return truncation_marker + rendered_messages
|
||||
|
||||
|
||||
def get_prompt_question(
    question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
    """
    Render the user question for prompting, folding in the clarification
    exchange if one took place.
    """
    if not clarification:
        return question

    return (
        f"Initial User Question: {question}\n"
        f"(Clarification Question: {clarification.clarification_question}\n"
        f"User Response: {clarification.clarification_response})"
    )
|
||||
|
||||
|
||||
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
    """
    Create a string representation of the tool call.

    Each question is rendered as its own " - " bullet. (Previously the first
    question was missing its bullet marker because the separator was only
    joined between items.)
    """
    questions_str = "\n".join(f" - {query}" for query in query_list)
    return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"
|
||||
|
||||
|
||||
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
    """
    Convert a numbered plan string (e.g. "1. do x 2) do y") into a dict
    mapping the step number (as a string) to the step text.
    """
    if not plan_text:
        return {}

    # split on numbered markers like "1." or "2)", keeping the markers
    pieces = re.split(r"(\d+[.)])", plan_text)

    plan_dict: dict[str, str] = {}
    # pieces[0] is any text before the first marker; afterwards the list
    # alternates marker, text, marker, text, ...
    for marker, body in zip(pieces[1::2], pieces[2::2]):
        step_number = marker.rstrip(".)")  # drop the trailing dot/parenthesis
        step_text = body.strip()
        if step_text:  # Only add if there's actual content
            plan_dict[step_number] = step_text

    return plan_dict
|
||||
@@ -6,7 +6,12 @@ from langgraph.types import StreamWriter
|
||||
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
|
||||
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
|
||||
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import (
|
||||
BASIC_SEARCH_STEP_DESCRIPTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import (
|
||||
KG_SEARCH_STEP_DESCRIPTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -95,14 +100,14 @@ def create_minimal_connected_query_graph(
|
||||
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
|
||||
|
||||
|
||||
def stream_write_step_description(
|
||||
def stream_write_kg_search_description(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=STEP_DESCRIPTIONS[step_nr].description,
|
||||
sub_question=KG_SEARCH_STEP_DESCRIPTIONS[step_nr].description,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
@@ -113,10 +118,12 @@ def stream_write_step_description(
|
||||
sleep(0.2)
|
||||
|
||||
|
||||
def stream_write_step_activities(
|
||||
def stream_write_kg_search_activities(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
for activity_nr, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities):
|
||||
for activity_nr, activity in enumerate(
|
||||
KG_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
|
||||
):
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
@@ -129,23 +136,25 @@ def stream_write_step_activities(
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_activity_explicit(
|
||||
writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
|
||||
def stream_write_basic_search_activities(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
for activity in STEP_DESCRIPTIONS[step_nr].activities:
|
||||
for activity_nr, activity in enumerate(
|
||||
BASIC_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
|
||||
):
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=activity,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
query_id=query_id,
|
||||
query_id=activity_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_answer_explicit(
|
||||
def stream_write_kg_search_answer_explicit(
|
||||
writer: StreamWriter, step_nr: int, answer: str, level: int = 0
|
||||
) -> None:
|
||||
write_custom_event(
|
||||
@@ -160,8 +169,8 @@ def stream_write_step_answer_explicit(
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in STEP_DESCRIPTIONS.items():
|
||||
def stream_write_kg_search_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in KG_SEARCH_STEP_DESCRIPTIONS.items():
|
||||
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
@@ -173,7 +182,7 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
writer,
|
||||
)
|
||||
|
||||
for step_nr in STEP_DESCRIPTIONS.keys():
|
||||
for step_nr in KG_SEARCH_STEP_DESCRIPTIONS.keys():
|
||||
|
||||
write_custom_event(
|
||||
"stream_finished",
|
||||
@@ -195,7 +204,40 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_close_step_answer(
|
||||
def stream_write_basic_search_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in BASIC_SEARCH_STEP_DESCRIPTIONS.items():
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=step_detail.description,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
for step_nr in BASIC_SEARCH_STEP_DESCRIPTIONS:
|
||||
write_custom_event(
|
||||
"stream_finished",
|
||||
StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_kg_search_close_step_answer(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
@@ -207,7 +249,7 @@ def stream_close_step_answer(
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
|
||||
def stream_write_kg_search_close_steps(writer: StreamWriter, level: int = 0) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
@@ -355,7 +397,7 @@ def get_near_empty_step_results(
|
||||
Get near-empty step results from a list of step results.
|
||||
"""
|
||||
return SubQuestionAnswerResults(
|
||||
question=STEP_DESCRIPTIONS[step_number].description,
|
||||
question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
|
||||
question_id="0_" + str(step_number),
|
||||
answer=step_answer,
|
||||
verified_high_quality=True,
|
||||
|
||||
@@ -7,17 +7,23 @@ from langgraph.types import StreamWriter
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_activities,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_structure,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
|
||||
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
|
||||
from onyx.agents.agent_search.kb_search.models import (
|
||||
KGQuestionRelationshipExtractionResult,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@@ -42,7 +48,7 @@ logger = setup_logger()
|
||||
|
||||
def extract_ert(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ERTExtractionUpdate:
|
||||
) -> EntityRelationshipExtractionUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
@@ -68,17 +74,17 @@ def extract_ert(
|
||||
user_name = user_email.split("@")[0] or "unknown"
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
today_date = datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
|
||||
# Stream structure of substeps out to the UI
|
||||
stream_write_step_structure(writer)
|
||||
if state.individual_flow:
|
||||
# Stream structure of substeps out to the UI
|
||||
stream_write_kg_search_structure(writer)
|
||||
|
||||
# Now specify core activities in the step (step 1)
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
stream_write_kg_search_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Create temporary views. TODO: move into parallel step, if ultimately materialized
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -240,12 +246,13 @@ def extract_ert(
|
||||
step_answer = f"""Entities and relationships have been extracted from query - \n \
|
||||
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(writer, step_nr=1, answer=step_answer)
|
||||
|
||||
# Finish Step 1
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
# Finish Step 1
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ERTExtractionUpdate(
|
||||
return EntityRelationshipExtractionUpdate(
|
||||
entities_types_str=all_entity_types,
|
||||
relationship_types_str=all_relationship_types,
|
||||
extracted_entities_w_attributes=entity_extraction_result.entities,
|
||||
|
||||
@@ -9,10 +9,14 @@ from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
create_minimal_connected_query_graph,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_activities,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
|
||||
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
|
||||
@@ -141,7 +145,7 @@ def analyze(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
entities = (
|
||||
state.extracted_entities_no_attributes
|
||||
) # attribute knowledge is not required for this step
|
||||
@@ -150,7 +154,8 @@ def analyze(
|
||||
|
||||
## STEP 2 - stream out goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Continue with node
|
||||
|
||||
@@ -277,9 +282,12 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
|
||||
else:
|
||||
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, step_nr=_KG_STEP_NR, answer=step_answer
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# End node
|
||||
|
||||
|
||||
@@ -8,10 +8,14 @@ from langgraph.types import StreamWriter
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_activities,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
|
||||
@@ -33,8 +37,10 @@ from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_
|
||||
from onyx.db.kg_temp_view import drop_views
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import ENTITY_TABLE_DESCRIPTION
|
||||
from onyx.prompts.kg_prompts import RELATIONSHIP_TABLE_DESCRIPTION
|
||||
from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_ERROR_FIX_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -122,6 +128,22 @@ def _sql_is_aggregate_query(sql_statement: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _run_sql(
|
||||
sql_statement: str, rel_temp_view: str, ent_temp_view: str
|
||||
) -> list[dict[str, Any]]:
|
||||
# check sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(sql_statement, rel_temp_view, ent_temp_view)
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(text(sql_statement))
|
||||
# Handle scalar results (like COUNT)
|
||||
if sql_statement.upper().startswith("SELECT COUNT"):
|
||||
scalar_result = result.scalar()
|
||||
return [{"count": int(scalar_result)}] if scalar_result is not None else []
|
||||
# Handle regular row results
|
||||
rows = result.fetchall()
|
||||
return [dict(row._mapping) for row in rows]
|
||||
|
||||
|
||||
def _get_source_documents(
|
||||
sql_statement: str,
|
||||
llm: LLM,
|
||||
@@ -189,7 +211,7 @@ def generate_simple_sql(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
entities_types_str = state.entities_types_str
|
||||
relationship_types_str = state.relationship_types_str
|
||||
|
||||
@@ -199,7 +221,6 @@ def generate_simple_sql(
|
||||
raise ValueError("kg_doc_temp_view_name is not set")
|
||||
if state.kg_rel_temp_view_name is None:
|
||||
raise ValueError("kg_rel_temp_view_name is not set")
|
||||
|
||||
if state.kg_entity_temp_view_name is None:
|
||||
raise ValueError("kg_entity_temp_view_name is not set")
|
||||
|
||||
@@ -207,7 +228,8 @@ def generate_simple_sql(
|
||||
|
||||
## STEP 3 - articulate goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_activities(writer, _KG_STEP_NR)
|
||||
|
||||
if graph_config.tooling.search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
@@ -270,6 +292,12 @@ def generate_simple_sql(
|
||||
)
|
||||
.replace("---question---", question)
|
||||
.replace("---entity_explanation_string---", entity_explanation_str)
|
||||
.replace(
|
||||
"---query_entities_with_attributes---",
|
||||
"\n".join(state.query_graph_entities_w_attributes),
|
||||
)
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
else:
|
||||
simple_sql_prompt = (
|
||||
@@ -289,8 +317,7 @@ def generate_simple_sql(
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
|
||||
# prepare SQL query generation
|
||||
|
||||
# generate initial sql statement
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=simple_sql_prompt,
|
||||
@@ -298,7 +325,6 @@ def generate_simple_sql(
|
||||
]
|
||||
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
@@ -336,53 +362,6 @@ def generate_simple_sql(
|
||||
)
|
||||
raise e
|
||||
|
||||
if state.query_type == KGRelationshipDetection.RELATIONSHIPS.value:
|
||||
# Correction if needed:
|
||||
|
||||
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
|
||||
"---draft_sql---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=correction_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
)
|
||||
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
|
||||
)
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
# display sql statement with view names replaced by general view names
|
||||
sql_statement_display = sql_statement.replace(
|
||||
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
|
||||
@@ -437,51 +416,93 @@ def generate_simple_sql(
|
||||
|
||||
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
|
||||
|
||||
scalar_result = None
|
||||
query_results = None
|
||||
query_results = [] # if no results, will be empty (not None)
|
||||
query_generation_error = None
|
||||
|
||||
# check sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(
|
||||
sql_statement, rel_temp_view, ent_temp_view
|
||||
)
|
||||
# run sql
|
||||
try:
|
||||
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
|
||||
if not query_results:
|
||||
query_generation_error = "SQL query returned no results"
|
||||
logger.warning(f"{query_generation_error}, retrying...")
|
||||
except Exception as e:
|
||||
query_generation_error = str(e)
|
||||
logger.warning(f"Error executing SQL query: {e}, retrying...")
|
||||
|
||||
# fix sql and try one more time if sql query didn't work out
|
||||
# if the result is still empty after this, the kg probably doesn't have the answer,
|
||||
# so we update the strategy to simple and address this in the answer generation
|
||||
if query_generation_error is not None:
|
||||
sql_fix_prompt = (
|
||||
SIMPLE_SQL_ERROR_FIX_PROMPT.replace(
|
||||
"---table_description---",
|
||||
(
|
||||
ENTITY_TABLE_DESCRIPTION
|
||||
if state.query_type
|
||||
== KGRelationshipDetection.NO_RELATIONSHIPS.value
|
||||
else RELATIONSHIP_TABLE_DESCRIPTION
|
||||
),
|
||||
)
|
||||
.replace("---entity_types---", entities_types_str)
|
||||
.replace("---relationship_types---", relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
.replace("---sql_statement---", sql_statement)
|
||||
.replace("---error_message---", query_generation_error)
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
msg = [HumanMessage(content=sql_fix_prompt)]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
result = db_session.execute(text(sql_statement))
|
||||
# Handle scalar results (like COUNT)
|
||||
if sql_statement.upper().startswith("SELECT COUNT"):
|
||||
scalar_result = result.scalar()
|
||||
query_results = (
|
||||
[{"count": int(scalar_result)}]
|
||||
if scalar_result is not None
|
||||
else []
|
||||
)
|
||||
else:
|
||||
# Handle regular row results
|
||||
rows = result.fetchall()
|
||||
query_results = [dict(row._mapping) for row in rows]
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
|
||||
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
)
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
sql_statement = sql_statement.replace(
|
||||
"relationship_table", rel_temp_view
|
||||
)
|
||||
sql_statement = sql_statement.replace("entity_table", ent_temp_view)
|
||||
|
||||
reasoning = (
|
||||
cleaned_response.split("<reasoning>")[1]
|
||||
.strip()
|
||||
.split("</reasoning>")[0]
|
||||
)
|
||||
|
||||
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing SQL query even after retry: {e}")
|
||||
# TODO: raise error on frontend
|
||||
logger.error(f"Error executing SQL query: {e}")
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
raise e
|
||||
raise
|
||||
|
||||
source_document_results = None
|
||||
|
||||
if source_documents_sql is not None and source_documents_sql != sql_statement:
|
||||
|
||||
# check source document sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(
|
||||
source_documents_sql, rel_temp_view, ent_temp_view
|
||||
)
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
|
||||
try:
|
||||
result = db_session.execute(text(source_documents_sql))
|
||||
rows = result.fetchall()
|
||||
@@ -491,28 +512,16 @@ def generate_simple_sql(
|
||||
for source_document_result in query_source_document_results
|
||||
]
|
||||
except Exception as e:
|
||||
# TODO: raise error on frontend
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
# TODO: raise warning on frontend
|
||||
logger.error(f"Error executing Individualized SQL query: {e}")
|
||||
|
||||
elif state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
|
||||
# source documents should be returned for entity queries
|
||||
source_document_results = [
|
||||
x["source_document"] for x in query_results if "source_document" in x
|
||||
]
|
||||
else:
|
||||
|
||||
if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
|
||||
# source documents should be returned for entity queries
|
||||
source_document_results = [
|
||||
x["source_document"]
|
||||
for x in query_results
|
||||
if "source_document" in x
|
||||
]
|
||||
|
||||
else:
|
||||
source_document_results = None
|
||||
source_document_results = None
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
@@ -528,21 +537,25 @@ def generate_simple_sql(
|
||||
|
||||
main_sql_statement = sql_statement
|
||||
|
||||
if reasoning:
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
|
||||
if reasoning and state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, step_nr=_KG_STEP_NR, answer=reasoning
|
||||
)
|
||||
|
||||
if sql_statement_display:
|
||||
stream_write_step_answer_explicit(
|
||||
if sql_statement_display and state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer,
|
||||
step_nr=_KG_STEP_NR,
|
||||
answer=f" \n Generated SQL: {sql_statement_display}",
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
if state.individual_flow:
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# Update path if too many results are retrieved
|
||||
|
||||
if query_results and len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS:
|
||||
# Update path if too many, or no results were retrieved from sql
|
||||
if main_sql_statement and (
|
||||
not query_results or len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS
|
||||
):
|
||||
updated_strategy = KGAnswerStrategy.SIMPLE
|
||||
else:
|
||||
updated_strategy = None
|
||||
|
||||
@@ -34,7 +34,7 @@ def construct_deep_search_filters(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
|
||||
entities_types_str = state.entities_types_str
|
||||
entities = state.query_graph_entities_no_attributes
|
||||
@@ -155,7 +155,11 @@ def construct_deep_search_filters(
|
||||
|
||||
if div_con_structure:
|
||||
for entity_type in double_grounded_entity_types:
|
||||
if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
|
||||
# entity_type is guaranteed to have grounded_source_name
|
||||
if (
|
||||
cast(str, entity_type.grounded_source_name).lower()
|
||||
in div_con_structure[0].lower()
|
||||
):
|
||||
source_division = True
|
||||
break
|
||||
|
||||
|
||||
@@ -98,16 +98,17 @@ def process_individual_deep_search(
|
||||
kg_relationship_filters = None
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=research_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
if state.individual_flow:
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=research_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
|
||||
logger.debug(
|
||||
|
||||
@@ -7,9 +7,11 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
@@ -49,7 +51,7 @@ def filtered_search(
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
@@ -72,17 +74,18 @@ def filtered_search(
|
||||
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Conduct a filtered search",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
if state.individual_flow:
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Conduct a filtered search",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
@@ -165,11 +168,12 @@ def filtered_search(
|
||||
|
||||
step_answer = "Filtered search is complete."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=filtered_search_answer,
|
||||
|
||||
@@ -5,9 +5,11 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
@@ -41,11 +43,12 @@ def consolidate_individual_deep_search(
|
||||
|
||||
step_answer = "All research is complete. Consolidating results..."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=consolidated_research_object_results_str,
|
||||
|
||||
@@ -4,9 +4,11 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
stream_kg_search_close_step_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
|
||||
@@ -66,28 +68,26 @@ def process_kg_only_answers(
|
||||
|
||||
# we use this stream write explicitly
|
||||
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Formatted References",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
query_results_list = []
|
||||
if state.individual_flow:
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Formatted References",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if query_results:
|
||||
for query_result in query_results:
|
||||
query_results_list.append(
|
||||
str(query_result).replace("::", ":: ").capitalize()
|
||||
)
|
||||
query_results_data_str = "\n".join(
|
||||
str(query_result).replace("::", ":: ").capitalize()
|
||||
for query_result in query_results
|
||||
)
|
||||
else:
|
||||
raise ValueError("No query results were found")
|
||||
|
||||
query_results_data_str = "\n".join(query_results_list)
|
||||
logger.warning("No query results were found")
|
||||
query_results_data_str = "(No query results were found)"
|
||||
|
||||
source_reference_result_str = _get_formated_source_reference_results(
|
||||
source_document_results
|
||||
@@ -99,9 +99,12 @@ def process_kg_only_answers(
|
||||
"No further research is needed, the answer is derived from the knowledge graph."
|
||||
)
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_answer_explicit(
|
||||
writer, step_nr=_KG_STEP_NR, answer=step_answer
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ResultsDataUpdate(
|
||||
query_results_data_str=query_results_data_str,
|
||||
|
||||
@@ -7,14 +7,17 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.access.access import get_acl_for_user
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_kg_search_close_steps,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import MainOutput
|
||||
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -42,7 +45,7 @@ logger = setup_logger()
|
||||
|
||||
def generate_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> MainOutput:
|
||||
) -> FinalAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
@@ -50,7 +53,9 @@ def generate_answer(
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
question = state.question
|
||||
|
||||
final_answer: str | None = None
|
||||
|
||||
user = (
|
||||
graph_config.tooling.search_tool.user
|
||||
@@ -69,7 +74,8 @@ def generate_answer(
|
||||
|
||||
# DECLARE STEPS DONE
|
||||
|
||||
stream_write_close_steps(writer)
|
||||
if state.individual_flow:
|
||||
stream_write_kg_search_close_steps(writer)
|
||||
|
||||
## MAIN ANSWER
|
||||
|
||||
@@ -128,16 +134,17 @@ def generate_answer(
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
if state.individual_flow:
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# continue with the answer generation
|
||||
|
||||
@@ -206,24 +213,40 @@ def generate_answer(
|
||||
)
|
||||
]
|
||||
try:
|
||||
run_with_timeout(
|
||||
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
prompt=msg,
|
||||
event_name="initial_agent_answer",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
if state.individual_flow:
|
||||
|
||||
stream_results: tuple[list[str], list[float]] = run_with_timeout(
|
||||
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=msg,
|
||||
event_name="initial_agent_answer",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
|
||||
),
|
||||
)
|
||||
final_answer = "".join(stream_results[0])
|
||||
else:
|
||||
final_answer = get_answer_from_llm(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=output_format_prompt,
|
||||
stream=False,
|
||||
json_string_flag=False,
|
||||
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
return MainOutput(
|
||||
return FinalAnswerUpdate(
|
||||
final_answer=final_answer,
|
||||
retrieved_documents=answer_generation_documents.context_documents,
|
||||
step_results=[],
|
||||
remarks=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
|
||||
@@ -48,6 +48,8 @@ def log_data(
|
||||
)
|
||||
|
||||
return MainOutput(
|
||||
final_answer=state.final_answer,
|
||||
retrieved_documents=state.retrieved_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
|
||||
@@ -120,7 +120,7 @@ class ResearchObjectOutput(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
class ERTExtractionUpdate(LoggerUpdate):
|
||||
class EntityRelationshipExtractionUpdate(LoggerUpdate):
|
||||
entities_types_str: str = ""
|
||||
relationship_types_str: str = ""
|
||||
extracted_entities_w_attributes: list[str] = []
|
||||
@@ -144,7 +144,13 @@ class ResearchObjectUpdate(LoggerUpdate):
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
question: str
|
||||
individual_flow: bool = True # used for UI display purposes
|
||||
|
||||
|
||||
class FinalAnswerUpdate(LoggerUpdate):
|
||||
final_answer: str | None = None
|
||||
retrieved_documents: list[InferenceSection] | None = None
|
||||
|
||||
|
||||
## Graph State
|
||||
@@ -154,7 +160,7 @@ class MainState(
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
ERTExtractionUpdate,
|
||||
EntityRelationshipExtractionUpdate,
|
||||
AnalysisUpdate,
|
||||
SQLSimpleGenerationUpdate,
|
||||
ResultsDataUpdate,
|
||||
@@ -162,6 +168,7 @@ class MainState(
|
||||
DeepSearchFilterUpdate,
|
||||
ResearchObjectUpdate,
|
||||
ConsolidatedResearchUpdate,
|
||||
FinalAnswerUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -169,6 +176,8 @@ class MainState(
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
final_answer: str | None
|
||||
retrieved_documents: list[InferenceSection] | None
|
||||
|
||||
|
||||
class ResearchObjectInput(LoggerUpdate):
|
||||
@@ -179,3 +188,4 @@ class ResearchObjectInput(LoggerUpdate):
|
||||
source_division: bool | None
|
||||
source_entity_filters: list[str] | None
|
||||
segment_type: str
|
||||
individual_flow: bool = True # used for UI display purposes
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from onyx.agents.agent_search.kb_search.models import KGSteps
|
||||
|
||||
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
1: KGSteps(
|
||||
description="Analyzing the question...",
|
||||
activities=[
|
||||
@@ -27,3 +27,7 @@ STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
description="Conducting further research on source documents...", activities=[]
|
||||
),
|
||||
}
|
||||
|
||||
BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
1: KGSteps(description="Conducting a standard search...", activities=[]),
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.models import Persona
|
||||
@@ -72,6 +73,7 @@ class GraphSearchConfig(BaseModel):
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
kg_config_settings: KGConfigSettings = KGConfigSettings()
|
||||
research_type: ResearchType = ResearchType.THOUGHTFUL
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -271,7 +271,10 @@ def choose_tool(
|
||||
should_stream_answer
|
||||
and not agent_config.behavior.skip_gen_ai_answer_generation,
|
||||
writer,
|
||||
)
|
||||
).ai_message_chunk
|
||||
|
||||
if tool_message is None:
|
||||
raise ValueError("No tool message emitted by LLM")
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
|
||||
@@ -4,6 +4,7 @@ from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
@@ -21,6 +22,7 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -62,7 +64,9 @@ def basic_use_tool_response(
|
||||
for section in dedupe_documents(search_response_summary.top_sections)[0]
|
||||
]
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
new_tool_call_chunk = BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=AIMessageChunk(content=""), full_answer=None
|
||||
)
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
@@ -80,4 +84,9 @@ def basic_use_tool_response(
|
||||
displayed_search_results=initial_search_results or final_search_results,
|
||||
)
|
||||
|
||||
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
|
||||
return BasicOutput(
|
||||
tool_call_chunk=new_tool_call_chunk.ai_message_chunk,
|
||||
full_answer=new_tool_call_chunk.full_answer,
|
||||
cited_references=new_tool_call_chunk.cited_references,
|
||||
retrieved_documents=new_tool_call_chunk.retrieved_documents,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,8 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
|
||||
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
@@ -41,6 +43,8 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GraphInput = BasicInput | MainInput | DCMainInput | KBMainInput | DRMainInput
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
@@ -90,7 +94,7 @@ def _parse_agent_event(
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
graph_input: GraphInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -104,7 +108,7 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
input: GraphInput,
|
||||
) -> AnswerStream:
|
||||
|
||||
for event in manage_sync_streaming(
|
||||
@@ -154,7 +158,24 @@ def run_kb_graph(
|
||||
) -> AnswerStream:
|
||||
graph = kb_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = KBMainInput(log_messages=[])
|
||||
input = KBMainInput(
|
||||
log_messages=[], question=config.inputs.prompt_builder.raw_user_query
|
||||
)
|
||||
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.inputs.prompt_builder.raw_user_query},
|
||||
)
|
||||
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dr_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = dr_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = DRMainInput(log_messages=[])
|
||||
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
# match ```json{...}``` or ```{...}```
|
||||
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
||||
|
||||
|
||||
def stream_llm_answer(
|
||||
@@ -66,3 +82,111 @@ def stream_llm_answer(
|
||||
response.append(content)
|
||||
|
||||
return response, dispatch_timings
|
||||
|
||||
|
||||
def invoke_llm_json(
|
||||
llm: LLM,
|
||||
prompt: LanguageModelInput,
|
||||
schema: Type[SchemaType],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> SchemaType:
|
||||
"""
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
|
||||
# check if the model supports response_format: json_schema
|
||||
supports_json = "response_format" in (
|
||||
get_supported_openai_params(llm.config.model_name, llm.config.model_provider)
|
||||
or []
|
||||
) and supports_response_schema(llm.config.model_name, llm.config.model_provider)
|
||||
|
||||
response_content = str(
|
||||
llm.invoke(
|
||||
prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
**cast(
|
||||
dict, {"structured_response_format": schema} if supports_json else {}
|
||||
),
|
||||
).content
|
||||
)
|
||||
|
||||
if not supports_json:
|
||||
# remove newlines as they often lead to json decoding errors
|
||||
response_content = response_content.replace("\n", " ")
|
||||
# hope the prompt is structured in a way a json is outputted...
|
||||
json_block_match = JSON_PATTERN.search(response_content)
|
||||
if json_block_match:
|
||||
response_content = json_block_match.group(1)
|
||||
else:
|
||||
first_bracket = response_content.find("{")
|
||||
last_bracket = response_content.rfind("}")
|
||||
response_content = response_content[first_bracket : last_bracket + 1]
|
||||
|
||||
return schema.model_validate_json(response_content)
|
||||
|
||||
|
||||
def get_answer_from_llm(
|
||||
llm: LLM,
|
||||
prompt: str,
|
||||
timeout: int = 25,
|
||||
timeout_override: int = 5,
|
||||
max_tokens: int = 500,
|
||||
stream: bool = False,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
agent_answer_level: int = 0,
|
||||
agent_answer_question_num: int = 0,
|
||||
agent_answer_type: Literal[
|
||||
"agent_sub_answer", "agent_level_answer"
|
||||
] = "agent_level_answer",
|
||||
json_string_flag: bool = False,
|
||||
) -> str:
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=prompt,
|
||||
)
|
||||
]
|
||||
|
||||
if stream:
|
||||
# TODO - adjust for new UI. This is currently not working for current UI/Basic Search
|
||||
stream_response, _ = run_with_timeout(
|
||||
timeout,
|
||||
lambda: stream_llm_answer(
|
||||
llm=llm,
|
||||
prompt=msg,
|
||||
event_name="sub_answers",
|
||||
writer=writer,
|
||||
agent_answer_level=agent_answer_level,
|
||||
agent_answer_question_num=agent_answer_question_num,
|
||||
agent_answer_type=agent_answer_type,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
content = "".join(stream_response)
|
||||
else:
|
||||
llm_response = run_with_timeout(
|
||||
timeout,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
content = str(llm_response.content)
|
||||
|
||||
cleaned_response = content
|
||||
if json_string_flag:
|
||||
cleaned_response = (
|
||||
str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
return cleaned_response
|
||||
|
||||
39
backend/onyx/agents/agent_search/utils.py
Normal file
39
backend/onyx/agents/agent_search/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def create_citation_format_list(
|
||||
document_citations: list[InferenceSection],
|
||||
) -> list[dict[str, Any]]:
|
||||
citation_list: list[dict[str, Any]] = []
|
||||
for document_citation in document_citations:
|
||||
document_citation_dict = {
|
||||
"link": "",
|
||||
"blurb": document_citation.center_chunk.blurb,
|
||||
"content": document_citation.center_chunk.content,
|
||||
"metadata": document_citation.center_chunk.metadata,
|
||||
"updated_at": str(document_citation.center_chunk.updated_at),
|
||||
"document_id": document_citation.center_chunk.document_id,
|
||||
"source_type": "file",
|
||||
"source_links": document_citation.center_chunk.source_links,
|
||||
"match_highlights": document_citation.center_chunk.match_highlights,
|
||||
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
|
||||
}
|
||||
|
||||
citation_list.append(document_citation_dict)
|
||||
|
||||
return citation_list
|
||||
|
||||
|
||||
def create_question_prompt(
|
||||
system_prompt: str | None, human_prompt: str
|
||||
) -> list[BaseMessage]:
|
||||
return [
|
||||
SystemMessage(content=system_prompt or ""),
|
||||
HumanMessage(content=human_prompt),
|
||||
]
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
@@ -12,7 +13,7 @@ from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_agent_search_graph
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dr_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
@@ -27,6 +28,7 @@ from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import Persona
|
||||
@@ -124,6 +126,9 @@ class Answer:
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
|
||||
kg_config_settings=get_kg_config_settings(),
|
||||
research_type=(
|
||||
ResearchType.DEEP if use_agentic_search else ResearchType.THOUGHTFUL
|
||||
),
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
@@ -138,12 +143,16 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
if self.graph_config.behavior.use_agentic_search and (
|
||||
# TODO: add toggle in UI with customizable TimeBudget
|
||||
if (
|
||||
self.graph_config.inputs.persona
|
||||
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
and self.graph_config.inputs.persona.name.startswith("KG Beta")
|
||||
and self.graph_config.inputs.persona.name.startswith(
|
||||
TMP_DRALPHA_PERSONA_NAME
|
||||
)
|
||||
):
|
||||
run_langgraph = run_kb_graph
|
||||
run_langgraph = run_dr_graph
|
||||
|
||||
elif self.graph_config.behavior.use_agentic_search:
|
||||
run_langgraph = run_agent_search_graph
|
||||
elif (
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
@@ -401,7 +402,7 @@ def process_kg_commands(
|
||||
) -> None:
|
||||
# Temporarily, until we have a draft UI for the KG Operations/Management
|
||||
# TODO: move to api endpoint once we get frontend
|
||||
if not persona_name.startswith("KG Beta"):
|
||||
if not persona_name.startswith(TMP_DRALPHA_PERSONA_NAME):
|
||||
return
|
||||
|
||||
kg_config_settings = get_kg_config_settings()
|
||||
|
||||
@@ -54,6 +54,7 @@ from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
@@ -845,6 +846,18 @@ def stream_chat_message_objects(
|
||||
error: str | None,
|
||||
tool_call: ToolCall | None,
|
||||
) -> ChatMessage:
|
||||
|
||||
is_kg_beta = parent_message.chat_session.persona.name.startswith(
|
||||
TMP_DRALPHA_PERSONA_NAME
|
||||
)
|
||||
is_basic_search = tool_call and tool_call.tool_name == SearchTool._NAME
|
||||
is_agentic_overwrite = new_msg_req.use_agentic_search and not (
|
||||
is_kg_beta and is_basic_search
|
||||
)
|
||||
|
||||
if is_kg_beta:
|
||||
is_agentic_overwrite = False
|
||||
|
||||
return create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=(
|
||||
@@ -867,7 +880,7 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
is_agentic=new_msg_req.use_agentic_search,
|
||||
is_agentic=is_agentic_overwrite,
|
||||
)
|
||||
|
||||
partial_response = create_response
|
||||
|
||||
@@ -3,6 +3,7 @@ import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -138,6 +139,8 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Onyx APIs without specifying a source type
|
||||
@@ -512,3 +515,57 @@ else:
|
||||
class OnyxCallTypes(str, Enum):
|
||||
FIREFLIES = "FIREFLIES"
|
||||
GONG = "GONG"
|
||||
|
||||
|
||||
# TODO: this should be stored likely in database
|
||||
DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
# Special case, document passed in via Onyx APIs without specifying a source type
|
||||
DocumentSource.INGESTION_API: "ingestion_api",
|
||||
DocumentSource.SLACK: "slack channels",
|
||||
DocumentSource.WEB: "web pages",
|
||||
DocumentSource.GOOGLE_DRIVE: "google drive documents (docs, sheets, etc.)",
|
||||
DocumentSource.GMAIL: "email messages",
|
||||
DocumentSource.REQUESTTRACKER: "requesttracker",
|
||||
DocumentSource.GITHUB: "github data",
|
||||
DocumentSource.GITBOOK: "gitbook data",
|
||||
DocumentSource.GITLAB: "gitlab data",
|
||||
DocumentSource.GURU: "guru data",
|
||||
DocumentSource.BOOKSTACK: "bookstack data",
|
||||
DocumentSource.CONFLUENCE: "confluence data (pages, spaces, etc.)",
|
||||
DocumentSource.JIRA: "jira data (issues, tickets, projects, etc.)",
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
DocumentSource.ZULIP: "zulip data",
|
||||
DocumentSource.LINEAR: "linear data - project management tool, including tickets etc.",
|
||||
DocumentSource.HUBSPOT: "hubspot data - CRM and marketing automation data",
|
||||
DocumentSource.DOCUMENT360: "document360 data",
|
||||
DocumentSource.GONG: "gong - call transcripts",
|
||||
DocumentSource.GOOGLE_SITES: "google_sites - websites",
|
||||
DocumentSource.ZENDESK: "zendesk - customer support data",
|
||||
DocumentSource.LOOPIO: "loopio - rfp data",
|
||||
DocumentSource.DROPBOX: "dropbox - files",
|
||||
DocumentSource.SHAREPOINT: "sharepoint - files",
|
||||
DocumentSource.TEAMS: "teams - chat and collaboration",
|
||||
DocumentSource.SALESFORCE: "salesforce - CRM data",
|
||||
DocumentSource.DISCOURSE: "discourse - discussion forums",
|
||||
DocumentSource.AXERO: "axero - employee engagement data",
|
||||
DocumentSource.CLICKUP: "clickup - project management tool",
|
||||
DocumentSource.MEDIAWIKI: "mediawiki - wiki data",
|
||||
DocumentSource.WIKIPEDIA: "wikipedia - encyclopedia data",
|
||||
DocumentSource.ASANA: "asana",
|
||||
DocumentSource.S3: "s3",
|
||||
DocumentSource.R2: "r2",
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: "google_cloud_storage - cloud storage",
|
||||
DocumentSource.OCI_STORAGE: "oci_storage - cloud storage",
|
||||
DocumentSource.XENFORO: "xenforo - forum data",
|
||||
DocumentSource.DISCORD: "discord - chat and collaboration",
|
||||
DocumentSource.FRESHDESK: "freshdesk - customer support data",
|
||||
DocumentSource.FIREFLIES: "fireflies - call transcripts",
|
||||
DocumentSource.EGNYTE: "egnyte - files",
|
||||
DocumentSource.AIRTABLE: "airtable - database",
|
||||
DocumentSource.HIGHSPOT: "highspot - CRM data",
|
||||
DocumentSource.IMAP: "imap - email data",
|
||||
}
|
||||
|
||||
@@ -140,3 +140,5 @@ KG_MAX_SEARCH_DOCUMENTS: int = int(os.environ.get("KG_MAX_SEARCH_DOCUMENTS", "15
|
||||
KG_MAX_DECOMPOSITION_SEGMENTS: int = int(
|
||||
os.environ.get("KG_MAX_DECOMPOSITION_SEGMENTS", "10")
|
||||
)
|
||||
KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
|
||||
to answer questions"
|
||||
|
||||
0
backend/onyx/configs/research_configs.py
Normal file
0
backend/onyx/configs/research_configs.py
Normal file
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
@@ -23,12 +22,12 @@ from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetr
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_citation_format_list
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDocs
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc as ServerSearchDoc
|
||||
@@ -1111,27 +1110,6 @@ def log_agent_sub_question_results(
|
||||
primary_message_id: int | None,
|
||||
sub_question_answer_results: list[SubQuestionAnswerResults],
|
||||
) -> None:
|
||||
def _create_citation_format_list(
|
||||
document_citations: list[InferenceSection],
|
||||
) -> list[dict[str, Any]]:
|
||||
citation_list: list[dict[str, Any]] = []
|
||||
for document_citation in document_citations:
|
||||
document_citation_dict = {
|
||||
"link": "",
|
||||
"blurb": document_citation.center_chunk.blurb,
|
||||
"content": document_citation.center_chunk.content,
|
||||
"metadata": document_citation.center_chunk.metadata,
|
||||
"updated_at": str(document_citation.center_chunk.updated_at),
|
||||
"document_id": document_citation.center_chunk.document_id,
|
||||
"source_type": "file",
|
||||
"source_links": document_citation.center_chunk.source_links,
|
||||
"match_highlights": document_citation.center_chunk.match_highlights,
|
||||
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
|
||||
}
|
||||
|
||||
citation_list.append(document_citation_dict)
|
||||
|
||||
return citation_list
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
@@ -1141,7 +1119,7 @@ def log_agent_sub_question_results(
|
||||
]
|
||||
sub_question = sub_question_answer_result.question
|
||||
sub_answer = sub_question_answer_result.answer
|
||||
sub_document_results = _create_citation_format_list(
|
||||
sub_document_results = create_citation_format_list(
|
||||
sub_question_answer_result.context_documents
|
||||
)
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -677,8 +678,8 @@ class KGEntityType(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
grounded_source_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
grounded_source_name: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=False
|
||||
)
|
||||
|
||||
entity_values: Mapped[list[str]] = mapped_column(
|
||||
@@ -2145,6 +2146,11 @@ class ChatMessage(Base):
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
|
||||
research_type: Mapped[ResearchType] = mapped_column(
|
||||
Enum(ResearchType, native_enum=False), nullable=True
|
||||
)
|
||||
research_plan: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
"""For organizing chat sessions"""
|
||||
@@ -3343,3 +3349,40 @@ class TenantAnonymousUserPath(Base):
|
||||
anonymous_user_path: Mapped[str] = mapped_column(
|
||||
String, nullable=False, unique=True
|
||||
)
|
||||
|
||||
|
||||
class ResearchAgentIteration(Base):
|
||||
__tablename__ = "research_agent_iteration"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
primary_question_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="CASCADE")
|
||||
)
|
||||
iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
|
||||
purpose: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
reasoning: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
|
||||
class ResearchAgentIterationSubStep(Base):
|
||||
__tablename__ = "research_agent_iteration_sub_step"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
primary_question_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="CASCADE")
|
||||
)
|
||||
parent_question_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
)
|
||||
iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
iteration_sub_step_nr: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
|
||||
sub_step_instructions: Mapped[str] = mapped_column(String, nullable=True)
|
||||
sub_step_tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), nullable=True)
|
||||
reasoning: Mapped[str] = mapped_column(String, nullable=True)
|
||||
sub_answer: Mapped[str] = mapped_column(String, nullable=True)
|
||||
cited_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
claims: Mapped[list[str]] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
additional_data: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
|
||||
@@ -16,7 +16,8 @@ from onyx.db.models import User
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.tools.built_in_tools import get_search_tool
|
||||
from onyx.tools.built_in_tools import get_builtin_tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.errors import EERequiredError
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
@@ -49,9 +50,7 @@ def create_slack_channel_persona(
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
search_tool = get_search_tool(db_session)
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool not found")
|
||||
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
|
||||
|
||||
# create/update persona associated with the Slack channel
|
||||
persona_name = _build_persona_name(channel_name)
|
||||
|
||||
@@ -47,15 +47,15 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_classification_extraction_instructions() -> (
|
||||
dict[str, dict[str, KGEntityTypeInstructions]]
|
||||
dict[str | None, dict[str, KGEntityTypeInstructions]]
|
||||
):
|
||||
"""
|
||||
Prepare the classification instructions for the given source.
|
||||
"""
|
||||
|
||||
classification_instructions_dict: dict[str, dict[str, KGEntityTypeInstructions]] = (
|
||||
{}
|
||||
)
|
||||
classification_instructions_dict: dict[
|
||||
str | None, dict[str, KGEntityTypeInstructions]
|
||||
] = {}
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entity_types = get_entity_types(db_session, active=True)
|
||||
|
||||
@@ -32,9 +32,7 @@ def format_entity_id_for_models(entity_id_name: str) -> str:
|
||||
separator = entity_type = ""
|
||||
|
||||
formatted_entity_type = entity_type.strip().upper()
|
||||
formatted_entity_name = (
|
||||
entity_name.strip().replace('"', "").replace("'", "").title()
|
||||
)
|
||||
formatted_entity_name = entity_name.strip().replace('"', "").replace("'", "")
|
||||
|
||||
return f"{formatted_entity_type}{separator}{formatted_entity_name}"
|
||||
|
||||
|
||||
1118
backend/onyx/prompts/dr_prompts.py
Normal file
1118
backend/onyx/prompts/dr_prompts.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -669,8 +669,8 @@ that should be used to analyze each object/each source (or 'the object' that fit
|
||||
}}
|
||||
|
||||
Do not include any other text or explanations.
|
||||
|
||||
"""
|
||||
|
||||
SOURCE_DETECTION_PROMPT = f"""
|
||||
You are an expert in generating, understanding and analyzing SQL statements.
|
||||
|
||||
@@ -773,11 +773,29 @@ Please structure your answer using <reasoning>, </reasoning>,<sql>, </sql> start
|
||||
""".strip()
|
||||
|
||||
|
||||
SIMPLE_SQL_PROMPT = f"""
|
||||
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
|
||||
between TWO ENTITIES. The table has the following structure:
|
||||
ENTITY_TABLE_DESCRIPTION = f"""\
|
||||
- Table name: entity_table
|
||||
- Columns:
|
||||
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
|
||||
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
|
||||
- entity_type (str): the type of the entity [example: ACCOUNT].
|
||||
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
|
||||
- source_document (str): the id of the document that contains the entity. Note that the combination of \
|
||||
id_name and source_document IS UNIQUE!
|
||||
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
|
||||
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
|
||||
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
|
||||
the entity type may also often be referred to.
|
||||
{SEPARATOR_LINE}
|
||||
---entity_types---
|
||||
{SEPARATOR_LINE}
|
||||
"""
|
||||
|
||||
RELATIONSHIP_TABLE_DESCRIPTION = f"""\
|
||||
- Table name: relationship_table
|
||||
- Columns:
|
||||
- relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \
|
||||
@@ -803,17 +821,27 @@ id_name and source_document IS UNIQUE!
|
||||
|
||||
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
|
||||
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
|
||||
their values, if provided.
|
||||
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
|
||||
the entity type may also often be referred to.
|
||||
{SEPARATOR_LINE}
|
||||
---entity_types---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>:
|
||||
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>.
|
||||
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by '::<entity_name>' \
|
||||
in the relationship id as shown above.
|
||||
{SEPARATOR_LINE}
|
||||
---relationship_types---
|
||||
{SEPARATOR_LINE}
|
||||
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by ':<entity_name>' in the \
|
||||
relationship id as shown above..
|
||||
"""
|
||||
|
||||
|
||||
SIMPLE_SQL_PROMPT = f"""
|
||||
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
|
||||
between TWO ENTITIES. The table has the following structure:
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
{RELATIONSHIP_TABLE_DESCRIPTION}
|
||||
|
||||
Here is the question you are supposed to translate into a SQL statement:
|
||||
{SEPARATOR_LINE}
|
||||
@@ -936,7 +964,7 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
|
||||
<sql>[the SQL statement that you generate to satisfy the task]</sql>
|
||||
""".strip()
|
||||
|
||||
|
||||
# TODO: remove following before merging after enough testing
|
||||
SIMPLE_SQL_CORRECTION_PROMPT = f"""
|
||||
You are an expert in reviewing and fixing SQL statements.
|
||||
|
||||
@@ -949,7 +977,7 @@ Guidance:
|
||||
SELECT statement as well! And it needs to be in the EXACT FORM! So if a \
|
||||
conversion took place, make sure to include the conversion in the SELECT and the ORDER BY clause!
|
||||
- never should 'source_document' be in the SELECT clause! Remove if present!
|
||||
- if there are joins, they must be on entities, never sour ce documents
|
||||
- if there are joins, they must be on entities, never source documents
|
||||
- if there are joins, consider the possibility that the second entity does not exist for all examples.\
|
||||
Therefore consider using LEFT joins (or RIGHT joins) as appropriate.
|
||||
|
||||
@@ -969,26 +997,7 @@ You are an expert in generating a SQL statement that only uses ONE TABLE that ca
|
||||
and their attributes and other data. The table has the following structure:
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
- Table name: entity_table
|
||||
- Columns:
|
||||
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
|
||||
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
|
||||
- entity_type (str): the type of the entity [example: ACCOUNT].
|
||||
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
|
||||
- source_document (str): the id of the document that contains the entity. Note that the combination of \
|
||||
id_name and source_document IS UNIQUE!
|
||||
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
|
||||
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
|
||||
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
|
||||
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
|
||||
the entity type may also often be referred to.
|
||||
{SEPARATOR_LINE}
|
||||
---entity_types---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
{ENTITY_TABLE_DESCRIPTION}
|
||||
|
||||
Here is the question you are supposed to translate into a SQL statement:
|
||||
{SEPARATOR_LINE}
|
||||
@@ -1077,33 +1086,55 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
|
||||
<sql>[the SQL statement that you generate to satisfy the task]</sql>
|
||||
""".strip()
|
||||
|
||||
SIMPLE_SQL_ERROR_FIX_PROMPT = f"""
|
||||
You are an expert at fixing SQL statements. You will be provided with a SQL statement that aims to address \
|
||||
a question, but it contains an error. Your task is to fix the SQL statement, based on the error message.
|
||||
|
||||
SQL_AGGREGATION_REMOVAL_PROMPT = f"""
|
||||
You are a SQL expert. You were provided with a SQL statement that returns an aggregation, and you are \
|
||||
tasked to show the underlying objects that were aggregated. For this you need to remove the aggregate functions \
|
||||
from the SQL statement in the correct way.
|
||||
Here is the description of the table that the SQL statement is supposed to use:
|
||||
---table_description---
|
||||
|
||||
Additional rules:
|
||||
- if you see a 'select count(*)', you should NOT convert \
|
||||
that to 'select *...', but rather return the corresponding id_name, entity_type_id_name, name, and document_id. \
|
||||
As in: 'select <table, if necessary>.id_name, <table, if necessary>.entity_type_id_name, \
|
||||
<table, if necessary>.name, <table, if necessary>.document_id ...'. \
|
||||
The id_name is always the primary index, and those should be returned, along with the type (entity_type_id_name), \
|
||||
the name (name) of the objects, and the document_id (document_id) of the object.
|
||||
- Add a limit of 30 to the select statement.
|
||||
- Don't change anything else.
|
||||
- The final select statement needs obviously to be a valid SQL statement.
|
||||
Here is the question you are supposed to translate into a SQL statement:
|
||||
{SEPARATOR_LINE}
|
||||
---question---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here is the SQL statement you are supposed to remove the aggregate functions from:
|
||||
Here is the SQL statement that you should fix:
|
||||
{SEPARATOR_LINE}
|
||||
---sql_statement---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here is the error message that was returned:
|
||||
{SEPARATOR_LINE}
|
||||
---error_message---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Note that in the case the error states the sql statement did not return any results, it is possible that the \
|
||||
sql statement is correct, but the question is not addressable with the information in the knowledge graph. \
|
||||
If you are absolutely certain that is the case, you may return the original sql statement.
|
||||
|
||||
Here are a couple common errors that you may encounter:
|
||||
- source_document is in the SELECT clause -> remove it
|
||||
- columns used in ORDER BY must also appear in the SELECT DISTINCT clause
|
||||
- consider carefully the type of the columns you are using, especially for attributes. You may need to cast them
|
||||
- dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like the attributes! \
|
||||
So please use that format, particularly if you use data comparisons (>, <, ...)
|
||||
- attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \
|
||||
"attributes ->> '<attribute>' = '<attribute value>'" (or "attributes ? '<attribute>'" to check for existence).
|
||||
- if you are using joins and the sql returned no joins, make sure you are using the appropriate join type (LEFT, RIGHT, etc.) \
|
||||
it is possible that the second entity does not exist for all examples.
|
||||
- (ignore if using entity_table) if using the relationship_table and the sql returned no results, make sure you are \
|
||||
selecting the correct column! Use the available relationship types to determine whether to use the source or target entity.
|
||||
|
||||
APPROACH:
|
||||
Please think through this step by step. Please also bear in mind that the sql statement is written in postgres syntax.
|
||||
|
||||
Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---.
|
||||
|
||||
Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> start and end tags as in:
|
||||
|
||||
<reasoning>[your short step-by step thinking]</reasoning>
|
||||
<sql>[the SQL statement without the aggregate functions]</sql>
|
||||
""".strip()
|
||||
<reasoning>[think through the logic but do so extremely briefly! Not more than 3-4 sentences.]</reasoning>
|
||||
<sql>[the SQL statement that you generate to satisfy the task]</sql>
|
||||
"""
|
||||
|
||||
|
||||
SEARCH_FILTER_CONSTRUCTION_PROMPT = f"""
|
||||
|
||||
43
backend/onyx/prompts/prompt_template.py
Normal file
43
backend/onyx/prompts/prompt_template.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import re
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
"""
|
||||
A class for building prompt templates with placeholders.
|
||||
Useful when building templates with json schemas, as {} will not work with f-strings.
|
||||
Unlike string.replace, this class will raise an error if the fields are missing.
|
||||
"""
|
||||
|
||||
DEFAULT_PATTERN = r"---([a-zA-Z0-9_]+)---"
|
||||
|
||||
def __init__(self, template: str, pattern: str = DEFAULT_PATTERN):
|
||||
self._pattern_str = pattern
|
||||
self._pattern = re.compile(pattern)
|
||||
self._template = template
|
||||
self._fields: set[str] = set(self._pattern.findall(template))
|
||||
|
||||
def build(self, **kwargs: str) -> str:
|
||||
"""
|
||||
Build the prompt template with the given fields.
|
||||
Will raise an error if the fields are missing.
|
||||
Will ignore fields that are not in the template.
|
||||
"""
|
||||
missing = self._fields - set(kwargs.keys())
|
||||
if missing:
|
||||
raise ValueError(f"Missing required fields: {missing}.")
|
||||
return self._replace_fields(kwargs)
|
||||
|
||||
def partial_build(self, **kwargs: str) -> "PromptTemplate":
|
||||
"""
|
||||
Returns another PromptTemplate with the given fields replaced.
|
||||
Will ignore fields that are not in the template.
|
||||
"""
|
||||
new_template = self._replace_fields(kwargs)
|
||||
return PromptTemplate(new_template, self._pattern_str)
|
||||
|
||||
def _replace_fields(self, field_vals: dict[str, str]) -> str:
|
||||
def repl(match: re.Match) -> str:
|
||||
key = match.group(1)
|
||||
return field_vals.get(key, match.group(0))
|
||||
|
||||
return self._pattern.sub(repl, self._template)
|
||||
@@ -3,6 +3,8 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
@@ -31,12 +33,13 @@ from onyx.server.kg.models import KGConfig
|
||||
from onyx.server.kg.models import KGConfig as KGConfigAPIModel
|
||||
from onyx.server.kg.models import SourceAndEntityTypeView
|
||||
from onyx.server.kg.models import SourceStatistics
|
||||
from onyx.tools.built_in_tools import get_search_tool
|
||||
from onyx.tools.built_in_tools import get_builtin_tool
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
_KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
|
||||
to answer questions"
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/kg")
|
||||
|
||||
|
||||
@@ -95,12 +98,9 @@ def enable_or_disable_kg(
|
||||
enable_kg(enable_req=req)
|
||||
populate_missing_default_entity_types__commit(db_session=db_session)
|
||||
|
||||
# Create or restore KG Beta persona
|
||||
|
||||
# Get the search tool
|
||||
search_tool = get_search_tool(db_session=db_session)
|
||||
if not search_tool:
|
||||
raise RuntimeError("SearchTool not found in the database.")
|
||||
# Get the search and knowledge graph tools
|
||||
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
|
||||
kg_tool = get_builtin_tool(db_session=db_session, tool_type=KnowledgeGraphTool)
|
||||
|
||||
# Check if we have a previously created persona
|
||||
kg_config_settings = get_kg_config_settings()
|
||||
@@ -132,8 +132,8 @@ def enable_or_disable_kg(
|
||||
is_public = len(user_ids) == 0
|
||||
|
||||
persona_request = PersonaUpsertRequest(
|
||||
name="KG Beta",
|
||||
description=_KG_BETA_ASSISTANT_DESCRIPTION,
|
||||
name=TMP_DRALPHA_PERSONA_NAME,
|
||||
description=KG_BETA_ASSISTANT_DESCRIPTION,
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
@@ -145,7 +145,7 @@ def enable_or_disable_kg(
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
prompt_ids=[0],
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -17,6 +17,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
|
||||
from onyx.tools.tool_implementations.internet_search.providers import (
|
||||
get_available_providers,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -63,6 +66,15 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
if (bool(get_available_providers()))
|
||||
else []
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=KnowledgeGraphTool,
|
||||
description=(
|
||||
"The Knowledge Graph Search Action allows the assistant to search the knowledge graph for information."
|
||||
"This tool should only be used by the Deep Research Agent, not via tool calling."
|
||||
),
|
||||
in_code_tool_id=KnowledgeGraphTool.__name__,
|
||||
display_name=KnowledgeGraphTool._DISPLAY_NAME,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -106,27 +118,37 @@ def load_builtin_tools(db_session: Session) -> None:
|
||||
logger.notice("All built-in tools are loaded/verified.")
|
||||
|
||||
|
||||
def get_search_tool(db_session: Session) -> ToolDBModel | None:
|
||||
def get_builtin_tool(
|
||||
db_session: Session,
|
||||
tool_type: Type[
|
||||
SearchTool | ImageGenerationTool | InternetSearchTool | KnowledgeGraphTool
|
||||
],
|
||||
) -> ToolDBModel:
|
||||
"""
|
||||
Retrieves for the SearchTool from the BUILT_IN_TOOLS list.
|
||||
Retrieves a built-in tool from the database based on the tool type.
|
||||
"""
|
||||
search_tool_id = next(
|
||||
tool_id = next(
|
||||
(
|
||||
tool["in_code_tool_id"]
|
||||
for tool in BUILT_IN_TOOLS
|
||||
if tool["cls"].__name__ == SearchTool.__name__
|
||||
if tool["cls"].__name__ == tool_type.__name__
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not search_tool_id:
|
||||
raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")
|
||||
if not tool_id:
|
||||
raise RuntimeError(
|
||||
f"Tool type {tool_type.__name__} not found in the BUILT_IN_TOOLS list."
|
||||
)
|
||||
|
||||
search_tool = db_session.execute(
|
||||
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
|
||||
db_tool = db_session.execute(
|
||||
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == tool_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
return search_tool
|
||||
if not db_tool:
|
||||
raise RuntimeError(f"Tool type {tool_type.__name__} not found in the database.")
|
||||
|
||||
return db_tool
|
||||
|
||||
|
||||
def auto_add_search_tool_to_personas(db_session: Session) -> None:
|
||||
@@ -136,10 +158,7 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None:
|
||||
Persona objects that were created before the concept of Tools were added.
|
||||
"""
|
||||
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
|
||||
search_tool = get_search_tool(db_session)
|
||||
|
||||
if not search_tool:
|
||||
raise RuntimeError("SearchTool not found in the database.")
|
||||
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
|
||||
|
||||
# Fetch all Personas that need the SearchTool added
|
||||
personas_to_update = (
|
||||
|
||||
@@ -20,6 +20,11 @@ OVERRIDE_T = TypeVar("OVERRIDE_T")
|
||||
|
||||
|
||||
class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
|
||||
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
@@ -41,6 +42,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import compute_all_tool_tokens
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
@@ -265,6 +269,14 @@ def construct_tools(
|
||||
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
|
||||
)
|
||||
|
||||
# Handle KG Tool
|
||||
elif tool_cls.__name__ == KnowledgeGraphTool.__name__:
|
||||
if persona.name != TMP_DRALPHA_PERSONA_NAME:
|
||||
raise ValueError(
|
||||
f"Knowledge Graph Tool should only be used by the '{TMP_DRALPHA_PERSONA_NAME}' Agent."
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [KnowledgeGraphTool()]
|
||||
|
||||
# Handle custom tools
|
||||
elif db_tool_model.openapi_schema:
|
||||
if not custom_tool_config:
|
||||
|
||||
@@ -17,6 +17,8 @@ from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
@@ -77,6 +79,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
class CustomTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
@@ -86,6 +89,7 @@ class CustomTool(BaseTool):
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
self._user_oauth_token = user_oauth_token
|
||||
self._id = id
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self._description = self._method_spec.summary
|
||||
@@ -107,6 +111,10 @@ class CustomTool(BaseTool):
|
||||
if self._user_oauth_token:
|
||||
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
@@ -382,11 +390,27 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
openapi_schema_str = json.dumps(openapi_schema)
|
||||
|
||||
with get_session_with_current_tenant() as temp_db_session:
|
||||
tools = get_tools(temp_db_session)
|
||||
tool_id: int | None = None
|
||||
for tool in tools:
|
||||
if tool.openapi_schema and (
|
||||
json.dumps(tool.openapi_schema) == openapi_schema_str
|
||||
):
|
||||
tool_id = tool.id
|
||||
break
|
||||
if not tool_id:
|
||||
raise ValueError(f"Tool with openapi_schema {openapi_schema_str} not found")
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
method_spec,
|
||||
url,
|
||||
custom_headers,
|
||||
id=tool_id,
|
||||
method_spec=method_spec,
|
||||
base_url=url,
|
||||
custom_headers=custom_headers,
|
||||
user_oauth_token=user_oauth_token,
|
||||
)
|
||||
for method_spec in method_specs
|
||||
|
||||
@@ -13,6 +13,8 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
@@ -112,6 +114,22 @@ class ImageGenerationTool(Tool[None]):
|
||||
self.additional_headers = additional_headers
|
||||
self.output_format = output_format
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == ImageGenerationTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError(
|
||||
"Image Generation tool not found. This should never happen."
|
||||
)
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
@@ -143,8 +144,23 @@ class InternetSearchTool(Tool[None]):
|
||||
)
|
||||
)
|
||||
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == InternetSearchTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError(
|
||||
"Internet Search tool not found. This should never happen."
|
||||
)
|
||||
self._id = tool_id
|
||||
|
||||
"""For explicit tool calling"""
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
QUERY_FIELD = "query"
|
||||
|
||||
|
||||
class KnowledgeGraphTool(Tool[None]):
|
||||
_NAME = "run_kg_search"
|
||||
_DESCRIPTION = "Search the knowledge graph for information. Never call this tool."
|
||||
_DISPLAY_NAME = "Knowledge Graph Search"
|
||||
|
||||
def __init__(self) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == KnowledgeGraphTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError(
|
||||
"Knowledge Graph tool not found. This should never happen."
|
||||
)
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
QUERY_FIELD: {
|
||||
"type": "string",
|
||||
"description": "What to search for",
|
||||
},
|
||||
},
|
||||
"required": [QUERY_FIELD],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def run(
|
||||
self, override_kwargs: None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
raise ValueError(
|
||||
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
|
||||
"not via tool calling."
|
||||
)
|
||||
@@ -34,6 +34,7 @@ from onyx.context.search.models import UserFileFilters
|
||||
from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.pipeline import section_relevance_list_impl
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
@@ -162,6 +163,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
tool_id: int | None = (
|
||||
db_session.query(ToolDBModel.id)
|
||||
.filter(ToolDBModel.in_code_tool_id == SearchTool.__name__)
|
||||
.scalar()
|
||||
)
|
||||
if not tool_id:
|
||||
raise ValueError("Search tool not found. This should never happen.")
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
Reference in New Issue
Block a user