Compare commits

...

92 Commits

Author SHA1 Message Date
Rei Meguro
ff0c78bb02 fix: aggregation + searched for xyz ui bugfixes 2025-08-13 01:40:58 +09:00
Rei Meguro
73cecf8c05 addresssing a few todos 2025-08-12 23:30:46 +09:00
Rei Meguro
516ae99225 answer and claims 2025-08-12 23:07:40 +09:00
Rei Meguro
1d7ec49e55 correct db sessions 2025-08-12 23:01:42 +09:00
Rei Meguro
de82ad97e0 custom tools 2025-08-12 16:25:19 +09:00
Rei Meguro
7b37e72b9d almost working custom tools 2025-08-11 02:10:48 +09:00
Rei Meguro
09d672ff22 more cleanup for tools 2025-08-10 23:36:24 +09:00
Rei Meguro
b028b25737 rename folder 2025-08-10 21:59:27 +09:00
Rei Meguro
07768d5484 feat: initial custom tool support prep 2025-08-10 18:59:46 +09:00
Rei Meguro
5ca8ca2b1e mypy and proper id implementation 2025-08-10 17:07:41 +09:00
Rei Meguro
62872e58ae cleanup 2025-08-10 16:20:59 +09:00
joachim-danswer
5f66a27c67 ResearchType vs DRTimeBudget 2025-08-08 16:54:34 -07:00
joachim-danswer
c21fa21958 initial decision using tool-calling if tool-calling LLM 2025-08-08 16:37:32 -07:00
joachim-danswer
cd6577c3ca tool_id for custom tools 2025-08-08 12:57:58 -07:00
Rei Meguro
16406f0ebd kg citations 2025-08-08 08:44:25 -07:00
Rei Meguro
4ae5bb1e6b fix: iteration citation replacement 2025-08-08 08:44:25 -07:00
Rei Meguro
b0c95ec876 fix: constants 2025-08-08 08:44:25 -07:00
Rei Meguro
397d30c802 better prompt templating 2025-08-08 08:44:25 -07:00
joachim-danswer
f13b08b461 persistence 2025-08-08 08:44:25 -07:00
Rei Meguro
e66245ec13 properly merge inference section contents 2025-08-08 08:44:25 -07:00
Rei Meguro
c64c6368c1 feat: kg tool proper implementation 2025-08-08 08:44:25 -07:00
joachim-danswer
b2fe55c8f8 more DR updates 2025-08-08 08:44:25 -07:00
joachim-danswer
2b661441d7 nits 2025-08-08 08:44:25 -07:00
joachim-danswer
f83f06228b reworked 'fast' search 2025-08-08 08:44:25 -07:00
joachim-danswer
fabfa8d166 query rejection step 2025-08-08 08:44:25 -07:00
joachim-danswer
994e7f7666 active_source_description 2025-08-08 08:44:25 -07:00
Rei Meguro
c81a7e1ef2 mypy fix 2025-08-08 08:44:25 -07:00
joachim-danswer
1d7d2f06d8 time filter and source prediction 2025-08-08 08:44:25 -07:00
joachim-danswer
916d6cb119 base search in DR refactoring 2025-08-08 08:44:25 -07:00
Rei Meguro
6d3542ded1 fix error overwrite 2025-08-08 08:44:25 -07:00
Rei Meguro
e5dbfc34c3 faster relationship sql generation 2025-08-08 08:44:25 -07:00
Rei Meguro
1aad7f44d2 add back kg 2025-08-08 08:44:25 -07:00
Rei Meguro
a0d6d0b922 cleanup 2025-08-08 08:44:25 -07:00
joachim-danswer
588023a1f6 state updates for internal search 2025-08-08 08:44:25 -07:00
joachim-danswer
e4c2427728 merging of new citation handling and sending back by Closer 2025-08-08 08:44:25 -07:00
joachim-danswer
bf77da26fc closer can suggest more research 2025-08-08 08:44:25 -07:00
Rei Meguro
abfecde097 citation improvements with answer claim structure 2025-08-08 08:44:25 -07:00
Rei Meguro
3f4936ad0a cleanup 2025-08-08 08:44:25 -07:00
joachim-danswer
3b8d16a136 claim improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
322e8668da claim start 2025-08-08 08:44:24 -07:00
Rei Meguro
d1dcad60d6 prompt improvements 2025-08-08 08:44:24 -07:00
Rei Meguro
7b3bdbdf83 fix: mypy 2025-08-08 08:44:24 -07:00
Rei Meguro
8b09fb0cef better clarification (still need prompt work) + prompt template fix 2025-08-08 08:44:24 -07:00
Rei Meguro
a2dd1bbf4f cleanup 2025-08-08 08:44:24 -07:00
Rei Meguro
828231815a fix: mypy 2025-08-08 08:44:24 -07:00
joachim-danswer
d48cbc2b79 custom tools 2025-08-08 08:44:24 -07:00
joachim-danswer
991bd4f8bf separation of tools 2025-08-08 08:44:24 -07:00
Rei Meguro
74418b84a2 consolidate user feedback 2025-08-08 08:44:24 -07:00
joachim-danswer
df1c40c791 prompt spellings 2025-08-08 08:44:24 -07:00
Rei Meguro
c253844500 minor cleanups + mypy fix 2025-08-08 08:44:24 -07:00
joachim-danswer
e972fb3e07 adding current time to prompts 2025-08-08 08:44:24 -07:00
joachim-danswer
726211c27d internet search improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
c0435ddfd6 parallelized internet search 2025-08-08 08:44:24 -07:00
joachim-danswer
48dc934c35 internet search 2025-08-08 08:44:24 -07:00
Rei Meguro
3a575a92d5 fix: incorrect citations 2025-08-08 08:44:24 -07:00
Rei Meguro
de4a9e4687 mypy + rename vars for clarity 2025-08-08 08:44:24 -07:00
joachim-danswer
c330152417 nit 2025-08-08 08:44:24 -07:00
joachim-danswer
dca39f27a6 nit 2025-08-08 08:44:24 -07:00
joachim-danswer
d3cc27846a multi-search for Thoughtful 2025-08-08 08:44:24 -07:00
Rei Meguro
fedc665b88 kg bugfix + fix sql on error + slightly improved dr user feedback prompt 2025-08-08 08:44:24 -07:00
Rei Meguro
614672f357 fix: chat history + question passed to closer 2025-08-08 08:44:24 -07:00
Rei Meguro
6aca9ee005 fix docstring + move shared vars to constants.py 2025-08-08 08:44:24 -07:00
joachim-danswer
f9f64fb1a5 cleaning up of isolating feedback generation 2025-08-08 08:44:24 -07:00
joachim-danswer
4a63e631cd rough - included clarification
TODO: clean up!
2025-08-08 08:44:24 -07:00
Rei Meguro
3d5586d623 feat: make kg query part of state, rather than config 2025-08-08 08:44:24 -07:00
joachim-danswer
6c4eb17b5d prompt improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
0917d9acd3 nits 2025-08-08 08:44:24 -07:00
Rei Meguro
89ea0f8d48 fix: wrong indentation 2025-08-08 08:44:24 -07:00
Rei Meguro
31ae6f1eb1 aggregate context improvements (no duplicates) 2025-08-08 08:44:24 -07:00
Rei Meguro
1b8d246afb feat: preparation for parallel search 2025-08-08 08:44:24 -07:00
Rei Meguro
05e55559d8 feat: citation improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
241b8d062c adding final references 2025-08-08 08:44:24 -07:00
Rei Meguro
6359d2f2d6 formatting 2025-08-08 08:44:24 -07:00
Rei Meguro
83325f9012 feat: previous chat context 2025-08-08 08:44:24 -07:00
joachim-danswer
0b26ed602d updates - KG search w/ citations 2025-08-08 08:44:24 -07:00
Rei Meguro
2b69d1ba52 more minor prompt improvements 2025-08-08 08:44:24 -07:00
Rei Meguro
27cd1d44dc feat: small prompt improvements 2025-08-08 08:44:24 -07:00
Rei Meguro
b5ddf31742 sligtly better planner prompt 2025-08-08 08:44:24 -07:00
Rei Meguro
ce1c80148b final answer streaming 2025-08-08 08:44:24 -07:00
Rei Meguro
bb95c46015 feat: structured response 2025-08-08 08:44:24 -07:00
Rei Meguro
e8a593c315 mypy + better typing 2025-08-08 08:44:24 -07:00
joachim-danswer
bb1b12988c improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
72bbcabedf improved DR 2025-08-08 08:44:24 -07:00
Rei Meguro
2ee98ba795 plan of record fix 2025-08-08 08:44:24 -07:00
Rei Meguro
594bbdb167 greptile + evan comments 2025-08-08 08:44:24 -07:00
Rei Meguro
d5c67b6f50 mypy + typing next_step and plan_of_records 2025-08-08 08:44:24 -07:00
joachim-danswer
9c7638ceba iteration prep 2025-08-08 08:44:24 -07:00
joachim-danswer
b1488ddccc update to KG Beta 2025-08-08 08:44:24 -07:00
joachim-danswer
d9a9818b9a nit 2025-08-08 08:44:24 -07:00
joachim-danswer
4bd3b8b0bb nit 2025-08-08 08:44:24 -07:00
joachim-danswer
da3979fc41 is_agentic_overwrite 2025-08-08 08:44:24 -07:00
joachim-danswer
ffed8b4300 orchestration base 2025-08-08 08:44:24 -07:00
78 changed files with 5480 additions and 368 deletions

View File

@@ -0,0 +1,91 @@
"""add research agent database tables and chat message research fields
Revision ID: 5ae8240accb3
Revises: 62c3a055a141
Create Date: 2025-08-06 14:29:24.691388
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "5ae8240accb3"
down_revision = "62c3a055a141"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add research-agent fields to chat_message and create the two iteration tables."""
    # Add research_type and research_plan columns to chat_message table
    op.add_column(
        "chat_message",
        sa.Column("research_type", sa.String(), nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column("research_plan", postgresql.JSONB(), nullable=True),
    )
    # Create research_agent_iteration table (one row per orchestrator iteration)
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            # cascade: iterations are deleted along with their chat message
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
    # Create research_agent_iteration_sub_step table
    # (one row per sub-step; may reference a parent sub-step in the same table)
    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "parent_question_id",
            sa.Integer(),
            # self-referencing FK: sub-steps can be nested under other sub-steps
            sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
            nullable=True,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column(
            "sub_step_tool_id",
            sa.Integer(),
            # no ondelete here: deleting a tool is restricted while referenced
            sa.ForeignKey("tool.id"),
            nullable=True,
        ),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
def downgrade() -> None:
    """Revert the research-agent schema changes made in upgrade()."""
    # Drop tables in reverse order (sub_step has a self-referencing FK)
    op.drop_table("research_agent_iteration_sub_step")
    op.drop_table("research_agent_iteration")
    # Remove columns from chat_message table
    op.drop_column("chat_message", "research_plan")
    op.drop_column("chat_message", "research_type")

View File

@@ -0,0 +1,12 @@
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
class BasicSearchProcessedStreamResults(BaseModel):
    """Aggregated result of processing a basic-search LLM stream."""

    # accumulated tool-call chunk; empty content when no tool call occurred
    # NOTE(review): a shared default instance — confirm pydantic copies it per model
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    # the fully assembled streamed answer, if one was produced
    full_answer: str | None = None
    # sections actually cited by the answer
    cited_references: list[InferenceSection] = []
    # all documents returned by retrieval
    retrieved_documents: list[LlmDoc] = []

View File

@@ -6,6 +6,9 @@ from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
@@ -18,11 +21,15 @@ class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
unused: bool = True
query_override: str | None = None
## Graph Output State
class BasicOutput(TypedDict):
    """Final output of the basic graph run."""

    # accumulated tool-call chunk from the LLM stream
    tool_call_chunk: AIMessageChunk
    # fully assembled answer text, if any
    full_answer: str | None
    # sections cited in the answer, if any
    cited_references: list[InferenceSection] | None
    # documents returned by retrieval, if any
    retrieved_documents: list[LlmDoc] | None
## Graph State

View File

@@ -5,6 +5,7 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
@@ -24,7 +25,7 @@ def process_llm_stream(
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
@@ -61,4 +62,6 @@ def process_llm_stream(
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)

View File

@@ -0,0 +1,49 @@
from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    """Route execution based on the most recently chosen tool.

    The last entry of state.tools_used is either a DRPath value or a generic
    (custom) tool name; unknown names fall back to the GENERIC_TOOL path.
    """
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")

    latest_choice = state.tools_used[-1]
    try:
        chosen_path = DRPath(latest_choice)
    except ValueError:
        # not a known DRPath value -> must be a custom tool name
        chosen_path = DRPath.GENERIC_TOOL

    # terminal state
    if chosen_path == DRPath.END:
        return END

    # clarification only happens before iteration starts
    if chosen_path == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")

    # search-style tools need at least one query; otherwise hand off to the closer
    query_driven_paths = {
        DRPath.INTERNAL_SEARCH,
        DRPath.INTERNET_SEARCH,
        DRPath.KNOWLEDGE_GRAPH,
    }
    if chosen_path in query_driven_paths and not state.query_list:
        return DRPath.CLOSER

    return chosen_path
def completeness_router(state: MainState) -> DRPath | str:
    """After the closer runs: loop back to the orchestrator if it was requested,
    otherwise end the graph."""
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")
    return (
        DRPath.ORCHESTRATOR
        if state.tools_used[-1] == DRPath.ORCHESTRATOR.value
        else END
    )

View File

@@ -0,0 +1,29 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# number of prior exchanges included in prompts
MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)

# cap on concurrent search questions per iteration (presumably — confirm at call site)
MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)

# markers embedded in assistant messages so later turns can recognize them
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "HIGH_LEVEL PLAN:"

# relative time-budget cost of invoking each tool path once
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.INTERNET_SEARCH: 1.5,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}

# total time budget available per research type (same cost units as above)
DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 6.0,
}

View File

@@ -0,0 +1,114 @@
from datetime import datetime
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    """Select and partially build the orchestration prompt for a purpose.

    Args:
        purpose: which orchestration step the prompt is for.
        research_type: THOUGHTFUL (single-pass iterative) or DEEP (planned).
        available_tools: tools keyed by their LLM-facing name.
        entity_types_string: KG entity types; required when the KG tool is available.
        relationship_types_string: KG relationship types; required with the KG tool.
        reasoning_result: prior reasoning output, if any.
        tool_calls_string: rendered prior tool calls, if any.

    Returns:
        A PromptTemplate with tool/KG/time placeholders already substituted.

    Raises:
        ValueError: for unsupported purpose/research_type combinations, or when
            the KG tool is available but its type strings are missing.
    """
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())

    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )

    # collect the pairwise hints that help the LLM tell similar tools apart
    tool_differentiations: list[str] = []
    for tool_1 in available_tools:
        for tool_2 in available_tools:
            if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS:
                tool_differentiations.append(
                    TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
                )
    tool_differentiation_hint_string = (
        "\n".join(tool_differentiations) or "(No differentiating hints available)"
    )

    # TODO: add tool delineation pairs for custom tools as well
    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )

    if DRPath.KNOWLEDGE_GRAPH.value in available_tools:
        if not entity_types_string or not relationship_types_string:
            raise ValueError(
                "Entity types and relationship types must be provided if the Knowledge Graph is used."
            )
        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string,
            possible_relationships=relationship_types_string,
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"

    # pick the base template; several purposes only exist for one research type.
    # (Error messages previously referenced the old "FAST time budget" naming;
    # they now match the ResearchType enum actually being checked.)
    if purpose == DRPromptPurpose.PLAN:
        if research_type == ResearchType.THOUGHTFUL:
            # an upfront plan is only generated for DEEP research
            raise ValueError(
                "plan generation is not supported for the THOUGHTFUL research type"
            )
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
        else:
            raise ValueError(
                "reasoning is not separately required for the DEEP research type"
            )
    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
    elif purpose == DRPromptPurpose.NEXT_STEP:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
        else:
            base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
    elif purpose == DRPromptPurpose.CLARIFICATION:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError(
                "clarification is not supported for the THOUGHTFUL research type"
            )
        base_template = GET_CLARIFICATION_PROMPT
    else:
        # unreachable for valid DRPromptPurpose values; kept for mypy exhaustiveness
        raise ValueError(f"Invalid purpose: {purpose}")

    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )

View File

@@ -0,0 +1,20 @@
from enum import Enum
class ResearchType(str, Enum):
    """Research type options for agent search operations"""

    # BASIC = "BASIC"
    THOUGHTFUL = "THOUGHTFUL"  # lighter, iterative flow (smaller time budget)
    DEEP = "DEEP"  # full planned deep-research flow
class DRPath(str, Enum):
    """Node names / routing targets in the deep-research LangGraph."""

    CLARIFIER = "CLARIFIER"  # entry node; may ask the user for clarification
    ORCHESTRATOR = "ORCHESTRATOR"  # decides the next tool each iteration
    INTERNAL_SEARCH = "INTERNAL_SEARCH"  # basic document search sub-agent
    GENERIC_TOOL = "GENERIC_TOOL"  # catch-all path for custom tools
    KNOWLEDGE_GRAPH = "KNOWLEDGE_GRAPH"  # KG search sub-agent
    INTERNET_SEARCH = "INTERNET_SEARCH"  # internet search sub-agent
    CLOSER = "CLOSER"  # composes the final answer
    END = "END"  # terminal routing value

View File

@@ -0,0 +1,73 @@
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_graph_builder import (
dr_is_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
logger = setup_logger()
def dr_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.
    """
    graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###
    graph.add_node(DRPath.CLARIFIER, clarifier)
    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    # each sub-agent is its own compiled sub-graph, registered under its DRPath
    sub_agent_builders = [
        (DRPath.INTERNAL_SEARCH, dr_basic_search_graph_builder),
        (DRPath.KNOWLEDGE_GRAPH, dr_kg_search_graph_builder),
        (DRPath.INTERNET_SEARCH, dr_is_graph_builder),
        (DRPath.GENERIC_TOOL, dr_custom_tool_graph_builder),
    ]
    for sub_agent_path, build_sub_agent in sub_agent_builders:
        graph.add_node(sub_agent_path, build_sub_agent().compile())

    graph.add_node(DRPath.CLOSER, closer)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    # every sub-agent reports back to the orchestrator
    for sub_agent_path, _ in sub_agent_builders:
        graph.add_edge(start_key=sub_agent_path, end_key=DRPath.ORCHESTRATOR)

    # the closer either ends the run or sends it back for more research
    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)

    return graph

View File

@@ -0,0 +1,115 @@
from enum import Enum
from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
    """A single next step chosen by the orchestrator: a tool and its questions."""

    tool: str
    questions: list[str]


class OrchestratorDecisonsNoPlan(BaseModel):
    """Orchestrator decision without a plan revision.

    NOTE(review): the 'Decisons' misspelling is part of the public class name;
    renaming would break callers.
    """

    reasoning: str
    next_step: OrchestratorStep


class OrchestrationPlan(BaseModel):
    """High-level research plan plus the reasoning behind it."""

    reasoning: str
    plan: str


class ClarificationGenerationResponse(BaseModel):
    """LLM response stating whether (and what) to ask the user for clarification."""

    clarification_needed: bool
    clarification_question: str


class QueryEvaluationResponse(BaseModel):
    """LLM verdict on whether a user query should be answered at all."""

    reasoning: str
    query_permitted: bool


class OrchestrationClarificationInfo(BaseModel):
    """A clarification question and, once received, the user's response."""

    clarification_question: str
    clarification_response: str | None = None


class SearchAnswer(BaseModel):
    """Structured sub-agent answer with optional supporting claims."""

    reasoning: str
    answer: str
    claims: list[str] | None = None


class TestInfoCompleteResponse(BaseModel):
    """Completeness check: whether gathered info suffices, and what gaps remain."""

    reasoning: str
    complete: bool
    gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
    """Metadata wrapper the orchestrator uses to describe and route one tool."""

    tool_id: int
    name: str
    llm_path: str  # the path for the LLM to refer by
    path: DRPath  # the actual path in the graph
    description: str
    metadata: dict[str, str]
    cost: float  # relative time-budget cost of one invocation
    tool_object: Tool | None = None  # None for CLOSER

    class Config:
        # Tool is not a pydantic model, so permit it as a field type
        arbitrary_types_allowed = True
class IterationInstructions(BaseModel):
    """Plan/purpose handed to sub-agents for one orchestrator iteration."""

    iteration_nr: int
    plan: str | None
    reasoning: str
    purpose: str


class IterationAnswer(BaseModel):
    """Result of one tool call within an iteration (one parallel branch)."""

    tool: str
    tool_id: int
    iteration_nr: int
    parallelization_nr: int  # index of this call among the iteration's parallel calls
    question: str
    reasoning: str | None
    answer: str
    cited_documents: dict[int, InferenceSection]  # citation number -> cited section
    background_info: str | None = None
    claims: list[str] | None = None
    additional_data: dict[str, str] | None = None


class AggregatedDRContext(BaseModel):
    """All iteration results merged into one context for the closer."""

    context: str
    cited_documents: list[InferenceSection]
    global_iteration_responses: list[IterationAnswer]
class ResearchType(str, Enum):
    """Time budget options for agent search operations"""

    # NOTE(review): duplicates (and conflicts with) ResearchType in dr/enums.py,
    # which defines THOUGHTFUL/DEEP — this FAST/SHALLOW/DEEP version looks stale.
    # Confirm no callers import it from here before removing.
    FAST = "fast"
    SHALLOW = "shallow"
    DEEP = "deep"
class DRPromptPurpose(str, Enum):
    """Which orchestration prompt a caller is requesting."""

    PLAN = "PLAN"  # upfront high-level plan (DEEP only)
    NEXT_STEP = "NEXT_STEP"  # iterative tool-choice decision
    NEXT_STEP_REASONING = "NEXT_STEP_REASONING"  # separate reasoning pass (THOUGHTFUL only)
    NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"  # purpose statement for the next step
    CLARIFICATION = "CLARIFICATION"  # user-clarification request (DEEP only)


class BaseSearchProcessingResponse(BaseModel):
    """Preprocessed search request: source filters, rewritten query, time filter."""

    specified_source_types: list[str]
    rewritten_query: str
    time_filter: str

View File

@@ -0,0 +1,461 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
from onyx.agents.agent_search.dr.constants import CLARIFICATION_REQUEST_PREFIX
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.dr_prompt_builder import (
get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import ClarificationGenerationResponse
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import AgentAnswerPiece
from onyx.configs.constants import DocumentSourceDescription
from onyx.configs.constants import MessageType
from onyx.db.connector import fetch_unique_document_sources
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _format_tool_name(tool_name: str) -> str:
"""Convert tool name to LLM-friendly format."""
name = tool_name.replace(" ", "_")
# take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
return name.upper()
def _get_available_tools(
    graph_config: GraphConfig, kg_enabled: bool
) -> dict[str, OrchestratorTool]:
    """Build the orchestrator's tool registry, keyed by LLM-facing tool name.

    Unsupported tool types are skipped with a warning; the CLOSER pseudo-tool
    is always appended. Raises ValueError if the KG tool is present without
    the internal search tool.
    """
    available_tools: dict[str, OrchestratorTool] = {}
    for tool in graph_config.tooling.tools:
        # default: treat as a generic custom tool; specialized below per type
        tool_info = OrchestratorTool(
            tool_id=tool.id,
            name=tool.name,
            llm_path=_format_tool_name(tool.name),
            path=DRPath.GENERIC_TOOL,
            description=tool.description,
            metadata={},
            cost=1.0,
            tool_object=tool,
        )
        # NOTE(review): isinstance order matters if these tool classes share a
        # base — the more specific types are tested before SearchTool; confirm.
        if isinstance(tool, CustomTool):
            # tool_info.metadata["summary_signature"] = CUSTOM_TOOL_RESPONSE_ID
            pass
        elif isinstance(tool, InternetSearchTool):
            # tool_info.metadata["summary_signature"] = (
            #     INTERNET_SEARCH_RESPONSE_SUMMARY_ID
            # )
            tool_info.llm_path = DRPath.INTERNET_SEARCH.value
            tool_info.path = DRPath.INTERNET_SEARCH
        elif isinstance(tool, SearchTool):
            # tool_info.metadata["summary_signature"] = SEARCH_RESPONSE_SUMMARY_ID
            tool_info.llm_path = DRPath.INTERNAL_SEARCH.value
            tool_info.path = DRPath.INTERNAL_SEARCH
        elif isinstance(tool, KnowledgeGraphTool):
            if not kg_enabled:
                logger.warning("KG must be enabled to use KG search tool, skipping")
                continue
            tool_info.llm_path = DRPath.KNOWLEDGE_GRAPH.value
            tool_info.path = DRPath.KNOWLEDGE_GRAPH
        else:
            logger.warning(f"Tool {tool.name} ({type(tool)}) is not supported")
            continue
        # prefer the curated DR description for known paths over the tool's own
        tool_info.description = TOOL_DESCRIPTION.get(tool_info.path, tool.description)
        tool_info.cost = AVERAGE_TOOL_COSTS[tool_info.path]
        # TODO: handle custom tools with same name as other tools (e.g., CLOSER)
        available_tools[tool_info.llm_path] = tool_info
    # make sure KG isn't enabled without internal search
    if (
        DRPath.KNOWLEDGE_GRAPH.value in available_tools
        and DRPath.INTERNAL_SEARCH.value not in available_tools
    ):
        raise ValueError(
            "The Knowledge Graph is not supported without internal search tool"
        )
    # add CLOSER tool, which is always available
    available_tools[DRPath.CLOSER.value] = OrchestratorTool(
        tool_id=-1,  # sentinel: the closer is not a real DB-registered tool
        name="closer",
        llm_path=DRPath.CLOSER.value,
        path=DRPath.CLOSER,
        description=TOOL_DESCRIPTION[DRPath.CLOSER],
        metadata={},
        cost=0.0,
        tool_object=None,
    )
    return available_tools
def _get_existing_clarification_request(
    graph_config: GraphConfig,
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
    """
    Returns the clarification info, original question, and updated chat history if
    a clarification request and response exists, otherwise returns None.
    """
    # check for clarification request and response in message history:
    # the previous turn must be an assistant message carrying the marker prefix
    previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
    if (
        len(previous_raw_messages) < 2
        or previous_raw_messages[-1].message_type != MessageType.ASSISTANT
        or CLARIFICATION_REQUEST_PREFIX not in previous_raw_messages[-1].message
    ):
        return None
    # get the clarification request and response
    previous_messages = graph_config.inputs.prompt_builder.message_history
    last_message = previous_raw_messages[-1].message
    clarification = OrchestrationClarificationInfo(
        # everything after the marker is the question that was asked
        clarification_question=last_message.split(CLARIFICATION_REQUEST_PREFIX, 1)[
            1
        ].strip(),
        # the current user message is treated as the answer to that question
        clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
    )
    # fallbacks if no earlier human message is found below
    original_question = graph_config.inputs.prompt_builder.raw_user_query
    chat_history_string = "(No chat history yet available)"
    # get the original user query and chat history string before the original query
    # e.g., if history = [user query, assistant clarification request, user clarification response],
    # previous_messages = [user query, assistant clarification request], we want the user query
    for i, message in enumerate(reversed(previous_messages), 1):
        if (
            isinstance(message, HumanMessage)
            and message.content
            and isinstance(message.content, str)
        ):
            original_question = message.content
            # history up to (but excluding) the original question
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history[:-i],
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )
            break
    return clarification, original_question, chat_history_string
# OpenAI-style function spec for a single catch-all tool definition.
# NOTE(review): presumably used to force a tool-calling LLM into one generic
# call instead of exposing the real tool set — usage is outside this view; confirm.
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
    "type": "function",
    "function": {
        "name": "run_any_knowledge_retrieval_and_any_action_tool",
        "description": "Use this tool to get any external information \
that is relevant to the question, or for any action to be taken.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "The request to be made to the tool",
                },
            },
            "required": ["request"],
        },
    },
}
def clarifier(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
    """
    Entry node of the DR graph: decide whether the question needs external
    knowledge and/or a clarification from the user.

    Flow (as implemented below):
      1. Gather the available tools, KG entity/relationship types, and the
         active document sources needed by downstream nodes.
      2. Ask the LLM whether it can answer without any external tool. Two
         variants: a plain-prompt decision for non-tool-calling LLMs, and an
         artificial catch-all tool for tool-calling LLMs. If no external
         knowledge is needed, route to END.
      3. For research types other than THOUGHTFUL, reuse an existing
         clarification exchange from the chat history, or ask the LLM whether
         a clarification question should be posed (routing to END while the
         user's response is awaited).

    Raises:
        ValueError: if no active source types are found in the DB.
    """
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    use_tool_calling_llm = graph_config.tooling.using_tool_calling_llm
    original_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type
    # get the connected tools and format for the Deep Research flow
    kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
    available_tools = _get_available_tools(graph_config, kg_enabled)
    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)
    db_session = graph_config.persistence.db_session
    active_source_types = fetch_unique_document_sources(db_session)
    if not active_source_types:
        raise ValueError("No active source types found")
    active_source_types_descriptions = [
        DocumentSourceDescription[source_type] for source_type in active_source_types
    ]
    # Quick evaluation whether the query should be answered or rejected
    # (rejection step currently disabled)
    # query_evaluation_prompt = QUERY_EVALUATION_PROMPT.replace(
    #     "---query---", original_question
    # )
    # try:
    #     evaluation_response = invoke_llm_json(
    #         llm=graph_config.tooling.primary_llm,
    #         prompt=query_evaluation_prompt,
    #         schema=QueryEvaluationResponse,
    #         timeout_override=10,
    #         max_tokens=100,
    #     )
    # except Exception as e:
    #     logger.error(f"Error in clarification generation: {e}")
    #     raise e
    # if not evaluation_response.query_permitted:
    #     rejection_prompt = QUERY_REJECTION_PROMPT.replace(
    #         "---query---", original_question
    #     ).replace("---reasoning---", evaluation_response.reasoning)
    #     _ = run_with_timeout(
    #         80,
    #         lambda: stream_llm_answer(
    #             llm=graph_config.tooling.primary_llm,
    #             prompt=rejection_prompt,
    #             event_name="basic_response",
    #             writer=writer,
    #             agent_answer_level=0,
    #             agent_answer_question_num=0,
    #             agent_answer_type="agent_level_answer",
    #             timeout_override=60,
    #             max_tokens=None,
    #         ),
    #     )
    #     logger.info(
    #         f"""Rejected query: {original_question}, \
    #         Rejection reason: {evaluation_response.reasoning}"""
    #     )
    #     return OrchestrationUpdate(
    #         original_question=original_question,
    #         chat_history_string="",
    #         tools_used=[DRPath.END.value],
    #         query_list=[],
    #     )
    # get clarification (unless time budget is FAST)
    # Verification of whether the LLM can answer the question without any external tool or knowledge
    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )
    if not use_tool_calling_llm:
        decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
            question=original_question, chat_history_string=chat_history_string
        )
        initial_decision_tokens, _ = run_with_timeout(
            80,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING, decision_prompt
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                max_tokens=None,
            ),
        )
        initial_decision_str = cast(str, merge_content(*initial_decision_tokens))
        # Any non-whitespace output means the LLM answered directly
        # (presumably; the streamed tokens went out as the basic_response),
        # so no external tools are needed — route to END.
        if len(initial_decision_str.replace(" ", "")) > 0:
            return OrchestrationUpdate(
                original_question=original_question,
                chat_history_string="",
                tools_used=[DRPath.END.value],
                query_list=[],
            )
    else:
        decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
            question=original_question, chat_history_string=chat_history_string
        )
        # Offer the artificial catch-all tool; the model calling it signals
        # that external information/action is required.
        stream = graph_config.tooling.primary_llm.stream(
            prompt=create_question_prompt(
                EVAL_SYSTEM_PROMPT_W_TOOL_CALLING, decision_prompt
            ),
            tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
            tool_choice=(None),
            structured_response_format=graph_config.inputs.structured_response_format,
        )
        tool_message = process_llm_stream(
            stream,
            True,
            writer,
        ).ai_message_chunk
        # No tool call -> the LLM answered on its own; route to END.
        if tool_message is None or len(tool_message.tool_calls) == 0:
            return OrchestrationUpdate(
                original_question=original_question,
                chat_history_string="",
                tools_used=[DRPath.END.value],
                query_list=[],
            )
    # Continue, as external knowledge is required.
    clarification = None
    if research_type != ResearchType.THOUGHTFUL:
        result = _get_existing_clarification_request(graph_config)
        if result is not None:
            # the user already answered a clarification request in a prior turn
            clarification, original_question, chat_history_string = result
        else:
            # generate clarification questions if needed
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history,
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )
            base_clarification_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.CLARIFICATION,
                research_type,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )
            clarification_prompt = base_clarification_prompt.build(
                question=original_question,
                chat_history_string=chat_history_string,
            )
            try:
                clarification_response = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=clarification_prompt,
                    schema=ClarificationGenerationResponse,
                    timeout_override=25,
                    # max_tokens=1500,
                )
            except Exception as e:
                logger.error(f"Error in clarification generation: {e}")
                raise e
            if (
                clarification_response.clarification_needed
                and clarification_response.clarification_question
            ):
                # pose the clarification question to the user
                clarification = OrchestrationClarificationInfo(
                    clarification_question=clarification_response.clarification_question,
                    clarification_response=None,
                )
                write_custom_event(
                    "basic_response",
                    AgentAnswerPiece(
                        answer_piece=(
                            f"{CLARIFICATION_REQUEST_PREFIX} "
                            f"{clarification.clarification_question}\n\n"
                        ),
                        level=0,
                        level_question_num=0,
                        answer_type="agent_level_answer",
                    ),
                    writer,
                )
    else:
        chat_history_string = (
            get_chat_history_string(
                graph_config.inputs.prompt_builder.message_history,
                MAX_CHAT_HISTORY_MESSAGES,
            )
            or "(No chat history yet available)"
        )
    # A pending (unanswered) clarification ends this turn; otherwise proceed
    # to the orchestrator.
    if (
        clarification
        and clarification.clarification_question
        and clarification.clarification_response is None
    ):
        next_tool = DRPath.END.value
    else:
        next_tool = DRPath.ORCHESTRATOR.value
    return OrchestrationUpdate(
        original_question=original_question,
        chat_history_string=chat_history_string,
        tools_used=[next_tool],
        query_list=[],
        iteration_nr=0,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="clarifier",
                node_start_time=node_start_time,
            )
        ],
        clarification=clarification,
        available_tools=available_tools,
        active_source_types=active_source_types,
        active_source_types_descriptions="\n".join(active_source_types_descriptions),
    )

View File

@@ -0,0 +1,364 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
from onyx.agents.agent_search.dr.constants import HIGH_LEVEL_PLAN_PREFIX
from onyx.agents.agent_search.dr.dr_prompt_builder import (
get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorDecisonsNoPlan
from onyx.agents.agent_search.dr.states import IterationInstructions
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import create_tool_call_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
from onyx.utils.logger import setup_logger
logger = setup_logger()
def orchestrator(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
    """
    LangGraph node to decide the next step in the DR process.

    THOUGHTFUL: on iterations after the first, stream a stop/continue
    reasoning step and return early (routing to the CLOSER) if it contains
    SUFFICIENT_INFORMATION_STRING; otherwise ask the LLM for the next
    tool + questions.
    DEEP: generate a plan of record on the first iteration, then ask the LLM
    for the next step against that plan (gaps reported by the closer are fed
    into the decision prompt).
    In both paths the chosen tool's cost is deducted from the remaining time
    budget, and a short "purpose" explanation is streamed before returning.
    When the budget is exhausted, the default route (CLOSER) is kept.
    """
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.original_question
    if not question:
        raise ValueError("Question is required for orchestrator")
    plan_of_record = state.plan_of_record
    clarification = state.clarification
    iteration_nr = state.iteration_nr + 1
    research_type = graph_config.behavior.research_type
    remaining_time_budget = state.remaining_time_budget
    chat_history_string = state.chat_history_string or "(No chat history yet available)"
    answer_history_string = (
        aggregate_context(state.iteration_responses, include_documents=True).context
        or "(No answer history yet available)"
    )
    available_tools = state.available_tools or {}
    # "tool: question" lines for every sub-question asked so far
    questions = [
        f"{iteration_response.tool}: {iteration_response.question}"
        for iteration_response in state.iteration_responses
        if len(iteration_response.question) > 0
    ]
    question_history_string = (
        "\n".join(f" - {question}" for question in questions)
        if questions
        else "(No question history yet available)"
    )
    prompt_question = get_prompt_question(question, clarification)
    gaps_str = (
        ("\n - " + "\n - ".join(state.gaps))
        if state.gaps
        else "(No explicit gaps were pointed out so far)"
    )
    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)
    # default to closer
    next_tool = DRPath.CLOSER.value
    query_list = ["Answer the question with the information you have."]
    decision_prompt = None
    reasoning_result = "(No reasoning result provided yet.)"
    tool_calls_string = "(No tool calls provided yet.)"
    if research_type == ResearchType.THOUGHTFUL:
        if iteration_nr == 1:
            remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]
        elif iteration_nr > 1:
            # for each iteration past the first one, we need to see whether we
            # have enough information to answer the question.
            # if we do, we can stop the iteration and return the answer.
            # if we do not, we need to continue the iteration.
            base_reasoning_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.NEXT_STEP_REASONING,
                ResearchType.THOUGHTFUL,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )
            write_custom_event(
                "basic_response",
                AgentAnswerPiece(
                    answer_piece="\n\n\nREASONING TO STOP/CONTINUE:\n\n\n",
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )
            reasoning_prompt = base_reasoning_prompt.build(
                question=question,
                chat_history_string=chat_history_string,
                answer_history_string=answer_history_string,
                iteration_nr=str(iteration_nr),
                remaining_time_budget=str(remaining_time_budget),
            )
            # NOTE(review): this initial value is immediately overwritten below
            reasoning_tokens: list[str] = [""]
            reasoning_tokens, _ = run_with_timeout(
                80,
                lambda: stream_llm_answer(
                    llm=graph_config.tooling.primary_llm,
                    prompt=reasoning_prompt,
                    event_name="basic_response",
                    writer=writer,
                    agent_answer_level=0,
                    agent_answer_question_num=0,
                    agent_answer_type="agent_level_answer",
                    timeout_override=60,
                    # max_tokens=None,
                ),
            )
            reasoning_result = cast(str, merge_content(*reasoning_tokens))
            # enough information gathered -> route straight to the closer
            if SUFFICIENT_INFORMATION_STRING in reasoning_result:
                return OrchestrationUpdate(
                    tools_used=[DRPath.CLOSER.value],
                    query_list=[],
                    iteration_nr=iteration_nr,
                    log_messages=[
                        get_langgraph_node_log_string(
                            graph_component="main",
                            node_name="orchestrator",
                            node_start_time=node_start_time,
                        )
                    ],
                    clarification=clarification,
                    plan_of_record=plan_of_record,
                    remaining_time_budget=remaining_time_budget,
                )
        base_decision_prompt = get_dr_prompt_orchestration_templates(
            DRPromptPurpose.NEXT_STEP,
            ResearchType.THOUGHTFUL,
            entity_types_string=all_entity_types,
            relationship_types_string=all_relationship_types,
            available_tools=available_tools,
        )
        decision_prompt = base_decision_prompt.build(
            question=question,
            chat_history_string=chat_history_string,
            answer_history_string=answer_history_string,
            iteration_nr=str(iteration_nr),
            remaining_time_budget=str(remaining_time_budget),
            reasoning_result=reasoning_result,
        )
        if remaining_time_budget > 0:
            try:
                orchestrator_action = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=decision_prompt,
                    schema=OrchestratorDecisonsNoPlan,
                    timeout_override=35,
                    # max_tokens=2500,
                )
                next_step = orchestrator_action.next_step
                next_tool = next_step.tool
                query_list = [q for q in (next_step.questions or [])]
                tool_calls_string = create_tool_call_string(next_tool, query_list)
            except Exception as e:
                logger.error(f"Error in approach extraction: {e}")
                raise e
            # charge the chosen tool's cost against the time budget
            remaining_time_budget -= available_tools[next_tool].cost
    else:
        if iteration_nr == 1 and not plan_of_record:
            # by default, we start a new iteration, but if there is a feedback request,
            # we start a new iteration 0 again (set a bit later)
            remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]
            base_plan_prompt = get_dr_prompt_orchestration_templates(
                DRPromptPurpose.PLAN,
                ResearchType.DEEP,
                entity_types_string=all_entity_types,
                relationship_types_string=all_relationship_types,
                available_tools=available_tools,
            )
            plan_generation_prompt = base_plan_prompt.build(
                question=prompt_question,
                chat_history_string=chat_history_string,
            )
            try:
                plan_of_record = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=plan_generation_prompt,
                    schema=OrchestrationPlan,
                    timeout_override=25,
                    # max_tokens=3000,
                )
            except Exception as e:
                logger.error(f"Error in plan generation: {e}")
                raise
            # stream the generated high-level plan to the user
            write_custom_event(
                "basic_response",
                AgentAnswerPiece(
                    answer_piece=f"{HIGH_LEVEL_PLAN_PREFIX} {plan_of_record.plan}\n\n",
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )
        if not plan_of_record:
            raise ValueError(
                "Plan information is required for iterative decision making"
            )
        base_decision_prompt = get_dr_prompt_orchestration_templates(
            DRPromptPurpose.NEXT_STEP,
            ResearchType.DEEP,
            entity_types_string=all_entity_types,
            relationship_types_string=all_relationship_types,
            available_tools=available_tools,
        )
        decision_prompt = base_decision_prompt.build(
            answer_history_string=answer_history_string,
            question_history_string=question_history_string,
            question=prompt_question,
            iteration_nr=str(iteration_nr),
            current_plan_of_record_string=plan_of_record.plan,
            chat_history_string=chat_history_string,
            remaining_time_budget=str(remaining_time_budget),
            gaps=gaps_str,
        )
        if remaining_time_budget > 0:
            try:
                orchestrator_action = invoke_llm_json(
                    llm=graph_config.tooling.primary_llm,
                    prompt=decision_prompt,
                    schema=OrchestratorDecisonsNoPlan,
                    timeout_override=15,
                    # max_tokens=1500,
                )
                next_step = orchestrator_action.next_step
                next_tool = next_step.tool
                query_list = [q for q in (next_step.questions or [])]
                reasoning_result = orchestrator_action.reasoning
                tool_calls_string = create_tool_call_string(next_tool, query_list)
            except Exception as e:
                logger.error(f"Error in approach extraction: {e}")
                raise e
            # charge the chosen tool's cost against the time budget
            remaining_time_budget -= available_tools[next_tool].cost
        else:
            reasoning_result = "Time to wrap up."
    base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
        DRPromptPurpose.NEXT_STEP_PURPOSE,
        ResearchType.DEEP,
        entity_types_string=all_entity_types,
        relationship_types_string=all_relationship_types,
        available_tools=available_tools,
    )
    orchestration_next_step_purpose_prompt = base_next_step_purpose_prompt.build(
        question=prompt_question,
        reasoning_result=reasoning_result,
        tool_calls=tool_calls_string,
    )
    purpose_tokens: list[str] = [""]
    # Write short purpose
    write_custom_event(
        "basic_response",
        AgentAnswerPiece(
            answer_piece=f"\n\n\nITERATION {iteration_nr}:\n\n\n",
            level=0,
            level_question_num=0,
            answer_type="agent_level_answer",
        ),
        writer,
    )
    try:
        purpose_tokens, _ = run_with_timeout(
            80,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=orchestration_next_step_purpose_prompt,
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                # max_tokens=None,
            ),
        )
    except Exception as e:
        logger.error(f"Error in orchestration next step purpose: {e}")
        raise e
    purpose = cast(str, merge_content(*purpose_tokens))
    return OrchestrationUpdate(
        tools_used=[next_tool],
        query_list=query_list or [],
        iteration_nr=iteration_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="orchestrator",
                node_start_time=node_start_time,
            )
        ],
        clarification=clarification,
        plan_of_record=plan_of_record,
        remaining_time_budget=remaining_time_budget,
        iteration_instructions=[
            IterationInstructions(
                iteration_nr=iteration_nr,
                plan=plan_of_record.plan if plan_of_record else None,
                reasoning=reasoning_result,
                purpose=purpose,
            )
        ],
    )

View File

@@ -0,0 +1,270 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_citation_format_list
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.db.models import ChatMessage
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def save_iteration(
    state: MainState, graph_config: GraphConfig, aggregated_context: AggregatedDRContext
) -> None:
    """
    Persist the research-agent run for the current chat message.

    Attaches the research type and the parsed plan to the ChatMessage row,
    then saves one ResearchAgentIteration per iteration instruction and one
    ResearchAgentIterationSubStep per sub-answer, committing at the end.

    Args:
        state: the main graph state (plan of record, iteration instructions).
        graph_config: provides the message id and the DB session.
        aggregated_context: holds the global iteration responses to persist.

    Raises:
        ValueError: if no ChatMessage exists for the configured message id.
    """
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type
    db_session = graph_config.persistence.db_session
    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)
    # update chat message and save iterations
    chat_message = (
        db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
    )
    if not chat_message:
        raise ValueError("Chat message with id not found")  # should never happen
    chat_message.research_type = research_type
    chat_message.research_plan = plan_of_record_dict
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
            created_at=datetime.now(),
        )
        db_session.add(research_agent_iteration_step)
    for iteration_answer in aggregated_context.global_iteration_responses:
        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            parent_question_id=None,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            # list(...) instead of a pass-through comprehension
            cited_doc_results=create_citation_format_list(
                list(iteration_answer.cited_documents.values())
            ),
            additional_data=iteration_answer.additional_data,
            created_at=datetime.now(),
        )
        db_session.add(research_agent_iteration_sub_step)
    db_session.commit()
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    For DEEP research, first checks (up to MAX_NUM_CLOSER_SUGGESTIONS times)
    whether the gathered information suffices; if not, routes back to the
    orchestrator with the identified gaps. Otherwise streams the cited
    documents and the final answer, persists the iteration history, and
    returns the FinalUpdate.

    Raises:
        ValueError: if the original question is missing, or if streaming the
            final answer fails (chained to the underlying error).
    """
    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")
    research_type = graph_config.behavior.research_type
    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)
    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )
    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )
    iteration_responses_string = aggregated_context.context
    all_cited_documents = aggregated_context.cited_documents
    num_closer_suggestions = state.num_closer_suggestions
    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )
        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=test_info_complete_prompt,
            schema=TestInfoCompleteResponse,
            timeout_override=40,
            # max_tokens=1000,
        )
        # Information still incomplete: hand control back to the orchestrator
        # with the identified gaps instead of answering now.
        if not test_info_complete_json.complete:
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
            )
    # Stream out docs - TODO: Improve this with new frontend
    write_custom_event(
        "tool_response",
        ExtendedToolResponse(
            id=SEARCH_RESPONSE_SUMMARY_ID,
            response=SearchResponseSummary(
                rephrased_query=base_question,
                top_sections=all_cited_documents,
                predicted_flow=QueryFlow.QUESTION_ANSWER,
                predicted_search=SearchType.KEYWORD,  # unused
                final_filters=IndexFilters(access_control_list=None),  # unused
                recency_bias_multiplier=1.0,  # unused
            ),
            level=0,
            level_question_num=0,  # 0, 0 is the base question
        ),
        writer,
    )
    # Change status from "searching for" to "searched for"
    write_custom_event(
        "tool_response",
        ToolCallFinalResult(
            tool_name=SearchTool._NAME,
            tool_args={"query": base_question},
            tool_result=[],  # unused
        ),
        writer,
    )
    # Generate final answer
    write_custom_event(
        "basic_response",
        AgentAnswerPiece(
            answer_piece="\n\n\nFINAL ANSWER:\n\n\n",
            level=0,
            level_question_num=0,
            answer_type="agent_level_answer",
        ),
        writer,
    )
    final_answer_prompt = FINAL_ANSWER_PROMPT.build(
        base_question=prompt_question,
        iteration_responses_string=iteration_responses_string,
        chat_history_string=chat_history_string,
    )
    try:
        # TODO: fix citations for non-document returning tools (right now, it cites iteration nr)
        streamed_output = run_with_timeout(
            240,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=final_answer_prompt,
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=60,
                # max_tokens=None,
            ),
        )
        final_answer = "".join(streamed_output[0])
    except Exception as e:
        # chain the cause so the original traceback is preserved
        raise ValueError(f"Error in consolidate_research: {e}") from e
    dispatch_main_answer_stop_info(level=0, writer=writer)
    # Log the research agent steps
    save_iteration(state, graph_config, aggregated_context)
    return FinalUpdate(
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="closer",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,71 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
    """Base state update carrying accumulated per-node log messages."""
    # Annotated with the `add` reducer so lists concatenate across node updates.
    log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
    """State update emitted by the orchestration nodes (clarifier/orchestrator/closer)."""
    # Original user question (may be replaced by the pre-clarification question).
    original_question: str | None = None
    chat_history_string: str | None = None
    # Tool/path names selected so far; appended via the `add` reducer.
    tools_used: Annotated[list[str], add] = []
    # Sub-questions to hand to the next tool.
    query_list: list[str] = []
    iteration_nr: int = 0
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    clarification: OrchestrationClarificationInfo | None = None
    available_tools: dict[str, OrchestratorTool] | None = None
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    active_source_types: list[DocumentSource] | None = None
    # Newline-joined human-readable descriptions of the active source types.
    active_source_types_descriptions: str | None = None
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
class AnswerUpdate(LoggerUpdate):
    """State update carrying the answers produced by sub-agent iterations."""
    # Concatenated across node updates via the `add` reducer.
    iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
    """Terminal state update holding the final answer and its cited documents."""
    final_answer: str | None = None
    all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
    """Input state of the main DR graph (currently just the core state)."""
    pass
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    OrchestrationUpdate,
    AnswerUpdate,
    FinalUpdate,
):
    """Full graph state: the input state plus every node-update mixin."""
    pass
## Graph Output State
class MainOutput(TypedDict):
    """Output of the main DR graph: logs, final answer, and cited documents."""
    log_messages: list[str]
    final_answer: str | None
    all_cited_documents: list[InferenceSection]

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.
    """
    start_time = datetime.now()
    logger.debug(
        f"Search start for Basic Search {state.iteration_nr} at {datetime.now()}"
    )
    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="branching",
        node_start_time=start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,254 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.states import AnswerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> AnswerUpdate:
    """
    LangGraph node to perform a standard (internal) search as part of the DR process.

    Flow:
      1. Quick LLM call rewrites the branch question and extracts any implied
         source-type and start-date filters.
      2. The configured SearchTool is run with those filters (using a fresh DB
         session so parallel branches do not share one).
      3. For DEEP research the branch question is answered from the retrieved
         documents by the LLM and citations are mapped back to documents; for
         other research types the raw top documents are passed through with an
         empty answer.

    Raises:
        ValueError: if the branch question or available tools are missing, or
            the resolved search tool does not match the configured one.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)
    # sanity check: the tool the orchestrator selected must be the one the
    # graph was configured with
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")
    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )
    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
    )
    try:
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=base_search_processing_prompt,
            schema=BaseSearchProcessingResponse,
            timeout_override=5,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        raise e
    rewritten_query = search_processing.rewritten_query
    implied_start_date = search_processing.time_filter
    # validate the LLM-provided time filter; anything not in strict
    # YYYY-MM-DD format is silently ignored
    implied_time_filter = None
    if implied_start_date:
        date_pattern = r"^\d{4}-\d{2}-\d{2}$"
        if re.match(date_pattern, implied_start_date):
            implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
    # validate the LLM-provided source types as well: an unknown value would
    # otherwise raise ValueError inside DocumentSource(...) and crash the branch
    valid_source_types: list[DocumentSource] = []
    for source_type in search_processing.specified_source_types:
        try:
            valid_source_types.append(DocumentSource(source_type))
        except ValueError:
            logger.warning(f"Ignoring unknown source type: {source_type}")
    # None means "no source-type filter"
    specified_source_types: list[DocumentSource] | None = valid_source_types or None
    logger.debug(
        f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    retrieved_docs: list[InferenceSection] = []
    callback_container: list[list[InferenceSection]] = []
    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=rewritten_query,
            document_sources=specified_source_types,
            time_filter=implied_time_filter,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections
                break
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)
    document_texts = "\n\n".join(document_texts_list)
    logger.debug(
        f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    if research_type == ResearchType.DEEP:
        # answer the branch question from the retrieved documents
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=branch_query,
            base_question=base_question,
            document_text=document_texts,
        )
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=search_prompt,
            schema=SearchAnswer,
            timeout_override=40,
        )
        logger.debug(
            f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )
        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
            # guard against hallucinated citation numbers: a value above the
            # doc count would raise IndexError, and a value <= 0 would silently
            # cite the wrong document via negative indexing
            if 0 < citation_number <= len(retrieved_docs)
        }
    else:
        # non-DEEP research: pass the raw top documents through unanswered
        answer_string = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
        reasoning = ""
    return AnswerUpdate(
        iteration_responses=[
            IterationAnswer(
                tool=search_tool_info.llm_path,
                tool_id=search_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,44 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph reducer node for the basic-search sub-agent: keep only the
    branch answers produced during the current iteration and promote them to
    iteration-level responses.
    """
    started_at = datetime.now()
    this_iteration = state.iteration_nr
    kept = [
        resp
        for resp in state.branch_iteration_responses
        if resp.iteration_nr == this_iteration
    ]
    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="consolidation",
        node_start_time=started_at,
    )
    return SubAgentUpdate(iteration_responses=kept, log_messages=[log_entry])

View File

@@ -0,0 +1,27 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan each sub-query out to a parallel "act" node, capped at the DR limit."""
    branches: list[Send | Hashable] = []
    for branch_idx, sub_query in enumerate(state.query_list[:MAX_DR_PARALLEL_SEARCH]):
        payload = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=branch_idx,
            branch_question=sub_query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        branches.append(Send("act", payload))
    return branches

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Basic (internal) Search Sub-Agent:
    branch -> parallel act -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
    # nodes
    for node_name, node_fn in (
        ("branch", basic_search_branch),
        ("act", basic_search),
        ("reducer", is_reducer),
    ):
        graph.add_node(node_name, node_fn)
    # edges
    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge("act", "reducer")
    graph.add_edge("reducer", END)
    return graph

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph entry node for the custom-tool sub-agent: records the start of
    the iteration before fan-out to the tool-call branches.
    """
    started_at = datetime.now()
    logger.debug(
        f"Search start for Generic Tool {state.iteration_nr} at {datetime.now()}"
    )
    log_entry = get_langgraph_node_log_string(
        graph_component="custom_tool",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,152 @@
import json
from datetime import datetime
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.states import AnswerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> AnswerUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.

    Two-stage flow:
      1. Obtain tool-call arguments — from the tool-calling LLM when available,
         falling back to the non-tool-calling argument extraction otherwise.
      2. Run the custom tool, render its result as text, and ask the LLM to
         summarize it into the branch answer.

    Raises:
        ValueError: if required state is missing, no tool arguments could be
            obtained, or the tool produced no response summary.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # the most recently selected tool is the one this branch must execute
    custom_tool_info = state.available_tools[state.tools_used[-1]]
    custom_tool_name = custom_tool_info.llm_path
    custom_tool = cast(CustomTool, custom_tool_info.tool_object)
    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    logger.debug(
        f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_USE_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_response="(No tool response yet. You need to call the tool to answer the question.)",
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke(
            tool_use_prompt,
            tools=[custom_tool.tool_definition()],
            tool_choice="required",
            timeout_override=40,
        )
        # make sure we got a tool call
        # exactly one tool call is expected; anything else falls through to
        # the non-tool-calling extraction below
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")
    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = custom_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )
    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    # run the tool
    # only the first response carrying the summary id is consumed
    response_summary: CustomToolCallSummary | None = None
    for tool_response in custom_tool.run(**tool_args):
        if tool_response.id == CUSTOM_TOOL_RESPONSE_ID:
            response_summary = cast(CustomToolCallSummary, tool_response.response)
            break
    if not response_summary:
        raise ValueError("Custom tool did not return a valid response summary")
    # summarise tool result
    # render the result as text so the summarization LLM can consume it;
    # file-typed results are referenced by their file ids rather than inlined
    if response_summary.response_type == "json":
        tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    elif response_summary.response_type in {"image", "csv"}:
        tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
    else:
        tool_result_str = str(response_summary.tool_result)
    tool_str = (
        f"Tool used: {custom_tool_name}\n"
        f"Description: {custom_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )
    # second LLM pass: turn the raw tool output into the branch answer
    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke(
            tool_summary_prompt, timeout_override=40
        ).content
    ).strip()
    logger.debug(
        f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    return AnswerUpdate(
        iteration_responses=[
            IterationAnswer(
                tool=custom_tool_name,
                tool_id=custom_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,44 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph reducer node for the custom-tool sub-agent: promote the branch
    answers from the current iteration to iteration-level responses.
    """
    started_at = datetime.now()
    current = state.iteration_nr
    current_responses = [
        answer
        for answer in state.branch_iteration_responses
        if answer.iteration_nr == current
    ]
    return SubAgentUpdate(
        iteration_responses=current_responses,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Route the first sub-query to the "act" node (no parallel calls yet)."""
    sends: list[Send | Hashable] = []
    # no parallel call for now — only the first query is dispatched
    for branch_idx, sub_query in enumerate(state.query_list[:1]):
        sends.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=branch_idx,
                    branch_question=sub_query,
                    context="",
                    active_source_types=state.active_source_types,
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                ),
            )
        )
    return sends

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Generic (custom) Tool Sub-Agent:
    branch -> act -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
    # nodes
    for node_name, node_fn in (
        ("branch", custom_tool_branch),
        ("act", custom_tool_act),
        ("reducer", custom_tool_reducer),
    ):
        graph.add_node(node_name, node_fn)
    # edges
    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge("act", "reducer")
    graph.add_edge("reducer", END)
    return graph

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph entry node for the internet-search sub-agent: records the start
    of the iteration before fan-out to the search branches.
    """
    started_at = datetime.now()
    logger.debug(
        f"Search start for Internet Search {state.iteration_nr} at {datetime.now()}"
    )
    log_entry = get_langgraph_node_log_string(
        graph_component="internet_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,194 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.utils.logger import setup_logger
logger = setup_logger()
def internet_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform an internet search as part of the DR process.

    Runs the configured InternetSearchTool for the branch question. For DEEP
    research the retrieved pages are summarized by the LLM and the answer's
    citations are mapped back to the retrieved documents; otherwise the raw
    top documents are passed through with an empty answer.

    Raises:
        ValueError: if the branch question, persona, available tools, or the
            internet search provider is missing.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type
    logger.debug(
        f"Search start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    is_tool_info = state.available_tools[state.tools_used[-1]]
    internet_search_tool = cast(InternetSearchTool, is_tool_info.tool_object)
    if internet_search_tool.provider is None:
        raise ValueError(
            "internet_search_tool.provider is not set. This should not happen."
        )
    # Update search parameters
    # NOTE(review): this mutates the shared tool instance; all branches write
    # the same constants so this looks benign today — confirm before making
    # these values branch-specific.
    internet_search_tool.max_chunks = 10
    internet_search_tool.provider.num_results = 10
    retrieved_docs: list[InferenceSection] = []
    for tool_response in internet_search_tool.run(internet_search_query=search_query):
        # get retrieved docs to send to the rest of the graph
        if tool_response.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
            response = cast(SearchResponseSummary, tool_response.response)
            retrieved_docs = response.top_sections
            break
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)
    document_texts = "\n\n".join(document_texts_list)
    logger.debug(
        f"Search end/LLM start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    if research_type == ResearchType.DEEP:
        # answer the branch question from the retrieved pages
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=search_query,
            base_question=base_question,
            document_text=document_texts,
        )
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=search_prompt,
            schema=SearchAnswer,
            timeout_override=40,
        )
        logger.debug(
            f"LLM/all done for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )
        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
            # guard against hallucinated citation numbers: a value above the
            # doc count would raise IndexError, and a value <= 0 would silently
            # cite the wrong document via negative indexing
            if 0 < citation_number <= len(retrieved_docs)
        }
    else:
        # non-DEEP research: pass the raw top documents through unanswered
        answer_string = ""
        claims = []
        reasoning = ""
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,44 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph reducer node for the internet-search sub-agent: promote the
    branch answers from the current iteration to iteration-level responses.
    """
    started_at = datetime.now()
    current = state.iteration_nr
    current_responses = []
    for answer in state.branch_iteration_responses:
        if answer.iteration_nr == current:
            current_responses.append(answer)
    return SubAgentUpdate(
        iteration_responses=current_responses,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="consolidation",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -0,0 +1,26 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan each sub-query out to a parallel "act" node, capped at the DR limit."""
    sends: list[Send | Hashable] = []
    for branch_idx, sub_query in enumerate(state.query_list[:MAX_DR_PARALLEL_SEARCH]):
        payload = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=branch_idx,
            branch_question=sub_query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        sends.append(Send("act", payload))
    return sends

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_1_branch import (
is_branch,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_2_act import (
internet_search,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_is_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Internet Search Sub-Agent:
    branch -> parallel act -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
    # nodes
    for node_name, node_fn in (
        ("branch", is_branch),
        ("act", internet_search),
        ("reducer", is_reducer),
    ):
        graph.add_node(node_name, node_fn)
    # edges
    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge("act", "reducer")
    graph.add_edge("reducer", END)
    return graph

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph entry node for the KG-search sub-agent: records the start of the
    iteration before fan-out to the search branches.
    """
    started_at = datetime.now()
    logger.debug(f"Search start for KG Search {state.iteration_nr} at {datetime.now()}")
    log_entry = get_langgraph_node_log_string(
        graph_component="kg_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,96 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.

    Compiles and invokes the knowledge-graph sub-graph for the branch
    question, then maps document citations in the answer back to the
    retrieved documents. If the answer carries no citations (it came from the
    KG itself rather than a document), all retrieved documents are cited.

    Raises:
        ValueError: if the branch question or available tools are missing.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")
    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    kg_tool_info = state.available_tools[state.tools_used[-1]]
    kb_graph = kb_graph_builder().compile()
    kb_results = kb_graph.invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )
    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
    (
        citation_numbers,
        answer_string,
        claims,
    ) = extract_document_citations(answer_string, claims)
    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
    cited_documents = {
        citation_number: retrieved_docs[citation_number - 1]
        for citation_number in citation_numbers
        # guard BOTH bounds: a citation <= 0 would otherwise silently pick the
        # wrong document via Python's negative indexing
        if 0 < citation_number <= len(retrieved_docs)
    }
    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,44 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph reducer node for the KG-search sub-agent: promote the branch
    answers from the current iteration to iteration-level responses.
    """
    started_at = datetime.now()
    this_iteration = state.iteration_nr
    kept = [
        resp
        for resp in state.branch_iteration_responses
        if resp.iteration_nr == this_iteration
    ]
    log_entry = get_langgraph_node_log_string(
        graph_component="kg_search",
        node_name="consolidation",
        node_start_time=started_at,
    )
    return SubAgentUpdate(iteration_responses=kept, log_messages=[log_entry])

View File

@@ -0,0 +1,25 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan the sub-agent's queries out to the "act" node, one Send per query."""
    dispatches: list[Send | Hashable] = []
    # no parallel search for now: only the first query is dispatched
    for branch_index, branch_query in enumerate(state.query_list[:1]):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=branch_index,
            branch_question=branch_query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_kg_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for KG Search Sub-Agent

    Wires branch -> (conditional fan-out) -> act -> reducer.
    """
    kg_graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # nodes
    kg_graph.add_node("branch", kg_search_branch)
    kg_graph.add_node("act", kg_search)
    kg_graph.add_node("reducer", kg_search_reducer)

    # edges: entry into branching, then the router decides the "act" fan-out
    kg_graph.add_edge(start_key=START, end_key="branch")
    kg_graph.add_conditional_edges("branch", branching_router)
    kg_graph.add_edge(start_key="act", end_key="reducer")
    kg_graph.add_edge(start_key="reducer", end_key=END)

    return kg_graph

View File

@@ -0,0 +1,42 @@
from operator import add
from typing import Annotated
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
class SubAgentUpdate(LoggerUpdate):
    """Iteration-level output of a sub-agent; responses accumulate via `add`."""

    # answers produced across all branches of one iteration
    iteration_responses: Annotated[list[IterationAnswer], add] = []


class BranchUpdate(LoggerUpdate):
    """Branch-level output of a single parallelized sub-agent branch."""

    # answers produced by one branch; merged across branches via `add`
    branch_iteration_responses: Annotated[list[IterationAnswer], add] = []


class SubAgentInput(LoggerUpdate):
    """Input state handed to a sub-agent for one DR iteration."""

    # index of the current DR iteration
    iteration_nr: int = 0
    # queries the sub-agent should answer this iteration
    query_list: list[str] = []
    context: str | None = None
    # source types the search is restricted to, if any
    active_source_types: list[DocumentSource] | None = None
    # names of tools already invoked earlier in the run (accumulates via `add`)
    tools_used: Annotated[list[str], add] = []
    # tool name -> tool descriptor available to the orchestrator
    available_tools: dict[str, OrchestratorTool] | None = None


class SubAgentMainState(
    # This includes the core state
    SubAgentInput,
    SubAgentUpdate,
    BranchUpdate,
):
    pass


class BranchInput(SubAgentInput):
    """Per-branch input: one question plus its parallelization slot."""

    # position of this branch within the iteration's parallel fan-out
    parallelization_nr: int = 0
    branch_question: str | None = None


class CustomToolBranchInput(LoggerUpdate):
    """Input for a custom-tool branch; carries the tool descriptor to invoke."""

    tool_info: OrchestratorTool

View File

@@ -0,0 +1,229 @@
import re
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.context.search.models import InferenceSection
CITATION_PREFIX = "CITE:"

# Matches single citations like [1] as well as grouped ones like [1, 2, 3].
_CITATION_PATTERN = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")


def extract_document_citations(
    answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
    """
    Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
    as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
    etc., to help with citation deduplication later on.
    """
    seen_numbers: set[int] = set()

    def _rewrite(match: re.Match[str]) -> str:
        # a grouped citation like [1, 2, 3] expands to [CITE:1][CITE:2][CITE:3]
        numbers = [int(part) for part in match.group(1).split(",")]
        seen_numbers.update(numbers)
        return "".join(f"[{CITATION_PREFIX}{number}]" for number in numbers)

    rewritten_answer = _CITATION_PATTERN.sub(_rewrite, answer)
    rewritten_claims = [_CITATION_PATTERN.sub(_rewrite, claim) for claim in claims]
    return list(seen_numbers), rewritten_answer, rewritten_claims
def aggregate_context(
    iteration_responses: list[IterationAnswer], include_documents: bool = False
) -> AggregatedDRContext:
    """
    Converts the iteration response into a single string with unified citations.
    For example,
    it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
    it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
    Output:
    it 1: the answer is x [1][2].
    it 2: blah blah [2][3]
    [1]: doc_xyz
    [2]: doc_abc
    [3]: doc_pqr
    """
    # dedupe and merge inference section contents
    # NOTE: processes responses in deterministic (iteration, parallelization)
    # order so that global citation numbers are stable across calls.
    unrolled_inference_sections: list[InferenceSection] = []
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        for cited_doc in iteration_response.cited_documents.values():
            unrolled_inference_sections.append(cited_doc)
            # NOTE(review): this mutates the shared InferenceSection in place;
            # callers holding a reference will see the cleared score.
            cited_doc.center_chunk.score = None  # None means maintain order
    global_documents = dedup_inference_section_list(unrolled_inference_sections)
    # document_id -> 1-based global citation number, in deduped order
    global_citations = {
        doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
    }

    # build output string
    output_strings: list[str] = []
    global_iteration_responses: list[IterationAnswer] = []
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        # add basic iteration info
        output_strings.append(
            f"Iteration: {iteration_response.iteration_nr}, "
            f"Question {iteration_response.parallelization_nr}"
        )
        output_strings.append(f"Tool: {iteration_response.tool}")
        output_strings.append(f"Question: {iteration_response.question}")

        # get answer and claims with global citations
        answer_str = iteration_response.answer
        claims = iteration_response.claims or []
        iteration_citations: list[int] = []
        for local_number, cited_doc in iteration_response.cited_documents.items():
            global_number = global_citations[cited_doc.center_chunk.document_id]
            # translate local citations to global citations
            # (local markers carry the CITATION_PREFIX sentinel inserted by
            # extract_document_citations, so plain "[n]" text is never touched)
            answer_str = answer_str.replace(
                f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
            )
            claims = [
                claim.replace(
                    f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
                )
                for claim in claims
            ]
            iteration_citations.append(global_number)

        # add answer, claims, and citation info
        if answer_str:
            output_strings.append(f"Answer: {answer_str}")
        if claims:
            # NOTE(review): the trailing `or "No claims provided"` is dead code —
            # the left operand always starts with the non-empty "Claims: " string.
            output_strings.append(
                "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
                or "No claims provided"
            )
        if not answer_str and not claims:
            # no textual answer: list the cited documents instead
            output_strings.append(
                "Retrieved documents: "
                + (
                    "".join(
                        f"[{global_number}]"
                        for global_number in sorted(iteration_citations)
                    )
                    or "No documents retrieved"
                )
            )
        output_strings.append("\n---\n")

        # save global iteration response
        # (a re-keyed copy of the response with globally-numbered citations)
        global_iteration_responses.append(
            IterationAnswer(
                tool=iteration_response.tool,
                tool_id=iteration_response.tool_id,
                iteration_nr=iteration_response.iteration_nr,
                parallelization_nr=iteration_response.parallelization_nr,
                question=iteration_response.question,
                reasoning=iteration_response.reasoning,
                answer=answer_str,
                cited_documents={
                    global_citations[doc.center_chunk.document_id]: doc
                    for doc in iteration_response.cited_documents.values()
                },
                background_info=iteration_response.background_info,
                claims=claims,
                additional_data=iteration_response.additional_data,
            )
        )

    # add document contents if requested
    if include_documents:
        if global_documents:
            output_strings.append("Cited document contents:")
        for doc in global_documents:
            output_strings.append(
                build_document_context(
                    doc, global_citations[doc.center_chunk.document_id]
                )
            )
        output_strings.append("\n---\n")

    return AggregatedDRContext(
        context="\n".join(output_strings),
        cited_documents=global_documents,
        global_iteration_responses=global_iteration_responses,
    )
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
    """
    Get the chat history (up to max_messages USER/ASSISTANT pairs) as a string.

    If older messages were dropped by the truncation, the result is prefixed
    with a "..." line to signal the elision.
    """
    # get past max_messages USER, ASSISTANT message pairs
    past_messages = chat_history[-max_messages * 2 :]

    # BUGFIX: the previous conditional expression's else-branch swallowed the
    # joined history via implicit string-literal concatenation ("" "\n".join(...)),
    # so a truncated history returned only "...\n" and dropped every message.
    truncation_prefix = (
        "...\n" if len(chat_history) > len(past_messages) else ""
    )
    history_str = "\n".join(
        ("user" if isinstance(msg, HumanMessage) else "you")
        + f": {str(msg.content).strip()}"
        for msg in past_messages
    )
    return truncation_prefix + history_str
def get_prompt_question(
question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
if clarification:
clarification_question = clarification.clarification_question
clarification_response = clarification.clarification_response
return (
f"Initial User Question: {question}\n"
f"(Clarification Question: {clarification_question}\n"
f"User Response: {clarification_response})"
)
return question
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
    """
    Create a string representation of the tool call.
    """
    # NOTE: join() leaves the first question without the " - " marker,
    # reproducing the established formatting.
    joined_questions = "\n - ".join(query_list)
    return "Tool: " + tool_name + "\n\nQuestions:\n" + joined_questions
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
    """Convert a numbered plan string ("1. do x 2) do y") into {"1": "do x", ...}."""
    if not plan_text:
        return {}
    # Split on markers like "1." or "2)", keeping the markers as tokens.
    # tokens[0] is whatever precedes the first marker; after that the list
    # alternates marker, body, marker, body, ...
    tokens = re.split(r"(\d+[.)])", plan_text)
    markers = tokens[1::2]
    bodies = tokens[2::2]
    plan: dict[str, str] = {}
    for marker, body in zip(markers, bodies):
        content = body.strip()
        if content:  # skip markers with no text after them
            plan[marker.rstrip(".)")] = content
    return plan

View File

@@ -6,7 +6,12 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
from onyx.agents.agent_search.kb_search.step_definitions import (
BASIC_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.kb_search.step_definitions import (
KG_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
@@ -95,14 +100,14 @@ def create_minimal_connected_query_graph(
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
def stream_write_step_description(
def stream_write_kg_search_description(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=STEP_DESCRIPTIONS[step_nr].description,
sub_question=KG_SEARCH_STEP_DESCRIPTIONS[step_nr].description,
level=level,
level_question_num=step_nr,
),
@@ -113,10 +118,12 @@ def stream_write_step_description(
sleep(0.2)
def stream_write_step_activities(
def stream_write_kg_search_activities(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
for activity_nr, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities):
for activity_nr, activity in enumerate(
KG_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
):
write_custom_event(
"subqueries",
SubQueryPiece(
@@ -129,23 +136,25 @@ def stream_write_step_activities(
)
def stream_write_step_activity_explicit(
writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
def stream_write_basic_search_activities(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
for activity in STEP_DESCRIPTIONS[step_nr].activities:
for activity_nr, activity in enumerate(
BASIC_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
):
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=activity,
level=level,
level_question_num=step_nr,
query_id=query_id,
query_id=activity_nr + 1,
),
writer,
)
def stream_write_step_answer_explicit(
def stream_write_kg_search_answer_explicit(
writer: StreamWriter, step_nr: int, answer: str, level: int = 0
) -> None:
write_custom_event(
@@ -160,8 +169,8 @@ def stream_write_step_answer_explicit(
)
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
for step_nr, step_detail in STEP_DESCRIPTIONS.items():
def stream_write_kg_search_structure(writer: StreamWriter, level: int = 0) -> None:
for step_nr, step_detail in KG_SEARCH_STEP_DESCRIPTIONS.items():
write_custom_event(
"decomp_qs",
@@ -173,7 +182,7 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
writer,
)
for step_nr in STEP_DESCRIPTIONS.keys():
for step_nr in KG_SEARCH_STEP_DESCRIPTIONS.keys():
write_custom_event(
"stream_finished",
@@ -195,7 +204,40 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
write_custom_event("stream_finished", stop_event, writer)
def stream_close_step_answer(
def stream_write_basic_search_structure(writer: StreamWriter, level: int = 0) -> None:
for step_nr, step_detail in BASIC_SEARCH_STEP_DESCRIPTIONS.items():
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=step_detail.description,
level=level,
level_question_num=step_nr,
),
writer,
)
for step_nr in BASIC_SEARCH_STEP_DESCRIPTIONS:
write_custom_event(
"stream_finished",
StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=level,
level_question_num=step_nr,
),
writer,
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
def stream_kg_search_close_step_answer(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
stop_event = StreamStopInfo(
@@ -207,7 +249,7 @@ def stream_close_step_answer(
write_custom_event("stream_finished", stop_event, writer)
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
def stream_write_kg_search_close_steps(writer: StreamWriter, level: int = 0) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
@@ -355,7 +397,7 @@ def get_near_empty_step_results(
Get near-empty step results from a list of step results.
"""
return SubQuestionAnswerResults(
question=STEP_DESCRIPTIONS[step_number].description,
question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
question_id="0_" + str(step_number),
answer=step_answer,
verified_high_quality=True,

View File

@@ -7,17 +7,23 @@ from langgraph.types import StreamWriter
from pydantic import ValidationError
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_activities,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_structure,
)
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -42,7 +48,7 @@ logger = setup_logger()
def extract_ert(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
) -> EntityRelationshipExtractionUpdate:
"""
LangGraph node to start the agentic search process.
"""
@@ -68,17 +74,17 @@ def extract_ert(
user_name = user_email.split("@")[0] or "unknown"
# first four lines duplicates from generate_initial_answer
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
today_date = datetime.now().strftime("%A, %Y-%m-%d")
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
# Stream structure of substeps out to the UI
stream_write_step_structure(writer)
if state.individual_flow:
# Stream structure of substeps out to the UI
stream_write_kg_search_structure(writer)
# Now specify core activities in the step (step 1)
stream_write_step_activities(writer, _KG_STEP_NR)
stream_write_kg_search_activities(writer, _KG_STEP_NR)
# Create temporary views. TODO: move into parallel step, if ultimately materialized
tenant_id = get_current_tenant_id()
@@ -240,12 +246,13 @@ def extract_ert(
step_answer = f"""Entities and relationships have been extracted from query - \n \
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
if state.individual_flow:
stream_write_kg_search_answer_explicit(writer, step_nr=1, answer=step_answer)
# Finish Step 1
stream_close_step_answer(writer, _KG_STEP_NR)
# Finish Step 1
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
return ERTExtractionUpdate(
return EntityRelationshipExtractionUpdate(
entities_types_str=all_entity_types,
relationship_types_str=all_relationship_types,
extracted_entities_w_attributes=entity_extraction_result.entities,

View File

@@ -9,10 +9,14 @@ from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_activities,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
@@ -141,7 +145,7 @@ def analyze(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
entities = (
state.extracted_entities_no_attributes
) # attribute knowledge is not required for this step
@@ -150,7 +154,8 @@ def analyze(
## STEP 2 - stream out goals
stream_write_step_activities(writer, _KG_STEP_NR)
if state.individual_flow:
stream_write_kg_search_activities(writer, _KG_STEP_NR)
# Continue with node
@@ -277,9 +282,12 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
else:
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, step_nr=_KG_STEP_NR, answer=step_answer
)
stream_close_step_answer(writer, _KG_STEP_NR)
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
# End node

View File

@@ -8,10 +8,14 @@ from langgraph.types import StreamWriter
from sqlalchemy import text
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_activities,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
@@ -33,8 +37,10 @@ from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT
from onyx.prompts.kg_prompts import ENTITY_TABLE_DESCRIPTION
from onyx.prompts.kg_prompts import RELATIONSHIP_TABLE_DESCRIPTION
from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_ERROR_FIX_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
from onyx.utils.logger import setup_logger
@@ -122,6 +128,22 @@ def _sql_is_aggregate_query(sql_statement: str) -> bool:
)
def _run_sql(
sql_statement: str, rel_temp_view: str, ent_temp_view: str
) -> list[dict[str, Any]]:
# check sql, just in case
_raise_error_if_sql_fails_problem_test(sql_statement, rel_temp_view, ent_temp_view)
with get_db_readonly_user_session_with_current_tenant() as db_session:
result = db_session.execute(text(sql_statement))
# Handle scalar results (like COUNT)
if sql_statement.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
return [{"count": int(scalar_result)}] if scalar_result is not None else []
# Handle regular row results
rows = result.fetchall()
return [dict(row._mapping) for row in rows]
def _get_source_documents(
sql_statement: str,
llm: LLM,
@@ -189,7 +211,7 @@ def generate_simple_sql(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
entities_types_str = state.entities_types_str
relationship_types_str = state.relationship_types_str
@@ -199,7 +221,6 @@ def generate_simple_sql(
raise ValueError("kg_doc_temp_view_name is not set")
if state.kg_rel_temp_view_name is None:
raise ValueError("kg_rel_temp_view_name is not set")
if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")
@@ -207,7 +228,8 @@ def generate_simple_sql(
## STEP 3 - articulate goals
stream_write_step_activities(writer, _KG_STEP_NR)
if state.individual_flow:
stream_write_kg_search_activities(writer, _KG_STEP_NR)
if graph_config.tooling.search_tool is None:
raise ValueError("Search tool is not set")
@@ -270,6 +292,12 @@ def generate_simple_sql(
)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
.replace(
"---query_entities_with_attributes---",
"\n".join(state.query_graph_entities_w_attributes),
)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
else:
simple_sql_prompt = (
@@ -289,8 +317,7 @@ def generate_simple_sql(
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
# prepare SQL query generation
# generate initial sql statement
msg = [
HumanMessage(
content=simple_sql_prompt,
@@ -298,7 +325,6 @@ def generate_simple_sql(
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
@@ -336,53 +362,6 @@ def generate_simple_sql(
)
raise e
if state.query_type == KGRelationshipDetection.RELATIONSHIPS.value:
# Correction if needed:
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
"---draft_sql---", sql_statement
)
msg = [
HumanMessage(
content=correction_prompt,
)
]
try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=1500,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
except Exception as e:
logger.error(
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
)
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
# display sql statement with view names replaced by general view names
sql_statement_display = sql_statement.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
@@ -437,51 +416,93 @@ def generate_simple_sql(
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
scalar_result = None
query_results = None
query_results = [] # if no results, will be empty (not None)
query_generation_error = None
# check sql, just in case
_raise_error_if_sql_fails_problem_test(
sql_statement, rel_temp_view, ent_temp_view
)
# run sql
try:
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
if not query_results:
query_generation_error = "SQL query returned no results"
logger.warning(f"{query_generation_error}, retrying...")
except Exception as e:
query_generation_error = str(e)
logger.warning(f"Error executing SQL query: {e}, retrying...")
# fix sql and try one more time if sql query didn't work out
# if the result is still empty after this, the kg probably doesn't have the answer,
# so we update the strategy to simple and address this in the answer generation
if query_generation_error is not None:
sql_fix_prompt = (
SIMPLE_SQL_ERROR_FIX_PROMPT.replace(
"---table_description---",
(
ENTITY_TABLE_DESCRIPTION
if state.query_type
== KGRelationshipDetection.NO_RELATIONSHIPS.value
else RELATIONSHIP_TABLE_DESCRIPTION
),
)
.replace("---entity_types---", entities_types_str)
.replace("---relationship_types---", relationship_types_str)
.replace("---question---", question)
.replace("---sql_statement---", sql_statement)
.replace("---error_message---", query_generation_error)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
msg = [HumanMessage(content=sql_fix_prompt)]
primary_llm = graph_config.tooling.primary_llm
with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(sql_statement))
# Handle scalar results (like COUNT)
if sql_statement.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
query_results = (
[{"count": int(scalar_result)}]
if scalar_result is not None
else []
)
else:
# Handle regular row results
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
sql_statement = sql_statement.replace(
"relationship_table", rel_temp_view
)
sql_statement = sql_statement.replace("entity_table", ent_temp_view)
reasoning = (
cleaned_response.split("<reasoning>")[1]
.strip()
.split("</reasoning>")[0]
)
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
except Exception as e:
logger.error(f"Error executing SQL query even after retry: {e}")
# TODO: raise error on frontend
logger.error(f"Error executing SQL query: {e}")
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
raise
source_document_results = None
if source_documents_sql is not None and source_documents_sql != sql_statement:
# check source document sql, just in case
_raise_error_if_sql_fails_problem_test(
source_documents_sql, rel_temp_view, ent_temp_view
)
with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(source_documents_sql))
rows = result.fetchall()
@@ -491,28 +512,16 @@ def generate_simple_sql(
for source_document_result in query_source_document_results
]
except Exception as e:
# TODO: raise error on frontend
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
# TODO: raise warning on frontend
logger.error(f"Error executing Individualized SQL query: {e}")
elif state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
# source documents should be returned for entity queries
source_document_results = [
x["source_document"] for x in query_results if "source_document" in x
]
else:
if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
# source documents should be returned for entity queries
source_document_results = [
x["source_document"]
for x in query_results
if "source_document" in x
]
else:
source_document_results = None
source_document_results = None
drop_views(
allowed_docs_view_name=doc_temp_view,
@@ -528,21 +537,25 @@ def generate_simple_sql(
main_sql_statement = sql_statement
if reasoning:
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
if reasoning and state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, step_nr=_KG_STEP_NR, answer=reasoning
)
if sql_statement_display:
stream_write_step_answer_explicit(
if sql_statement_display and state.individual_flow:
stream_write_kg_search_answer_explicit(
writer,
step_nr=_KG_STEP_NR,
answer=f" \n Generated SQL: {sql_statement_display}",
)
stream_close_step_answer(writer, _KG_STEP_NR)
if state.individual_flow:
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
# Update path if too many results are retrieved
if query_results and len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS:
# Update path if too many, or no results were retrieved from sql
if main_sql_statement and (
not query_results or len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS
):
updated_strategy = KGAnswerStrategy.SIMPLE
else:
updated_strategy = None

View File

@@ -34,7 +34,7 @@ def construct_deep_search_filters(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
entities_types_str = state.entities_types_str
entities = state.query_graph_entities_no_attributes
@@ -155,7 +155,11 @@ def construct_deep_search_filters(
if div_con_structure:
for entity_type in double_grounded_entity_types:
if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
# entity_type is guaranteed to have grounded_source_name
if (
cast(str, entity_type.grounded_source_name).lower()
in div_con_structure[0].lower()
):
source_division = True
break

View File

@@ -98,16 +98,17 @@ def process_individual_deep_search(
kg_relationship_filters = None
# Step 4 - stream out the research query
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
level=0,
level_question_num=_KG_STEP_NR,
query_id=research_nr + 1,
),
writer,
)
if state.individual_flow:
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
level=0,
level_question_num=_KG_STEP_NR,
query_id=research_nr + 1,
),
writer,
)
if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
logger.debug(

View File

@@ -7,9 +7,11 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
@@ -49,7 +51,7 @@ def filtered_search(
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
if not search_tool:
raise ValueError("search_tool is not provided")
@@ -72,17 +74,18 @@ def filtered_search(
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
# Step 4 - stream out the research query
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Conduct a filtered search",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
if state.individual_flow:
# Step 4 - stream out the research query
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Conduct a filtered search",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
retrieved_docs = cast(
list[InferenceSection],
@@ -165,11 +168,12 @@ def filtered_search(
step_answer = "Filtered search is complete."
stream_write_step_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=filtered_search_answer,

View File

@@ -5,9 +5,11 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
@@ -41,11 +43,12 @@ def consolidate_individual_deep_search(
step_answer = "All research is complete. Consolidating results..."
stream_write_step_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=consolidated_research_object_results_str,

View File

@@ -4,9 +4,11 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
@@ -66,28 +68,26 @@ def process_kg_only_answers(
# we use this stream write explicitly
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Formatted References",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
query_results_list = []
if state.individual_flow:
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Formatted References",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
if query_results:
for query_result in query_results:
query_results_list.append(
str(query_result).replace("::", ":: ").capitalize()
)
query_results_data_str = "\n".join(
str(query_result).replace("::", ":: ").capitalize()
for query_result in query_results
)
else:
raise ValueError("No query results were found")
query_results_data_str = "\n".join(query_results_list)
logger.warning("No query results were found")
query_results_data_str = "(No query results were found)"
source_reference_result_str = _get_formated_source_reference_results(
source_document_results
@@ -99,9 +99,12 @@ def process_kg_only_answers(
"No further research is needed, the answer is derived from the knowledge graph."
)
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, step_nr=_KG_STEP_NR, answer=step_answer
)
stream_close_step_answer(writer, _KG_STEP_NR)
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
return ResultsDataUpdate(
query_results_data_str=query_results_data_str,

View File

@@ -7,14 +7,17 @@ from langgraph.types import StreamWriter
from onyx.access.access import get_acl_for_user
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_close_steps,
)
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -42,7 +45,7 @@ logger = setup_logger()
def generate_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
) -> FinalAnswerUpdate:
"""
LangGraph node to start the agentic search process.
"""
@@ -50,7 +53,9 @@ def generate_answer(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
final_answer: str | None = None
user = (
graph_config.tooling.search_tool.user
@@ -69,7 +74,8 @@ def generate_answer(
# DECLARE STEPS DONE
stream_write_close_steps(writer)
if state.individual_flow:
stream_write_kg_search_close_steps(writer)
## MAIN ANSWER
@@ -128,16 +134,17 @@ def generate_answer(
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if state.individual_flow:
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
# continue with the answer generation
@@ -206,24 +213,40 @@ def generate_answer(
)
]
try:
run_with_timeout(
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=graph_config.tooling.fast_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
if state.individual_flow:
stream_results: tuple[list[str], list[float]] = run_with_timeout(
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
),
)
final_answer = "".join(stream_results[0])
else:
final_answer = get_answer_from_llm(
llm=graph_config.tooling.primary_llm,
prompt=output_format_prompt,
stream=False,
json_string_flag=False,
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
),
)
)
except Exception as e:
raise ValueError(f"Could not generate the answer. Error {e}")
return MainOutput(
return FinalAnswerUpdate(
final_answer=final_answer,
retrieved_documents=answer_generation_documents.context_documents,
step_results=[],
remarks=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -48,6 +48,8 @@ def log_data(
)
return MainOutput(
final_answer=state.final_answer,
retrieved_documents=state.retrieved_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -120,7 +120,7 @@ class ResearchObjectOutput(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
class ERTExtractionUpdate(LoggerUpdate):
class EntityRelationshipExtractionUpdate(LoggerUpdate):
entities_types_str: str = ""
relationship_types_str: str = ""
extracted_entities_w_attributes: list[str] = []
@@ -144,7 +144,13 @@ class ResearchObjectUpdate(LoggerUpdate):
## Graph Input State
class MainInput(CoreState):
pass
question: str
individual_flow: bool = True # used for UI display purposes
class FinalAnswerUpdate(LoggerUpdate):
final_answer: str | None = None
retrieved_documents: list[InferenceSection] | None = None
## Graph State
@@ -154,7 +160,7 @@ class MainState(
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
ERTExtractionUpdate,
EntityRelationshipExtractionUpdate,
AnalysisUpdate,
SQLSimpleGenerationUpdate,
ResultsDataUpdate,
@@ -162,6 +168,7 @@ class MainState(
DeepSearchFilterUpdate,
ResearchObjectUpdate,
ConsolidatedResearchUpdate,
FinalAnswerUpdate,
):
pass
@@ -169,6 +176,8 @@ class MainState(
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
retrieved_documents: list[InferenceSection] | None
class ResearchObjectInput(LoggerUpdate):
@@ -179,3 +188,4 @@ class ResearchObjectInput(LoggerUpdate):
source_division: bool | None
source_entity_filters: list[str] | None
segment_type: str
individual_flow: bool = True # used for UI display purposes

View File

@@ -1,6 +1,6 @@
from onyx.agents.agent_search.kb_search.models import KGSteps
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
1: KGSteps(
description="Analyzing the question...",
activities=[
@@ -27,3 +27,7 @@ STEP_DESCRIPTIONS: dict[int, KGSteps] = {
description="Conducting further research on source documents...", activities=[]
),
}
# NOTE(review): single display step used when the flow performs a plain
# ("basic") search — presumably the UI counterpart to
# KG_SEARCH_STEP_DESCRIPTIONS above; confirm against the streaming consumer.
BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
1: KGSteps(description="Conducting a standard search...", activities=[]),
}

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
@@ -72,6 +73,7 @@ class GraphSearchConfig(BaseModel):
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
kg_config_settings: KGConfigSettings = KGConfigSettings()
research_type: ResearchType = ResearchType.THOUGHTFUL
class GraphConfig(BaseModel):

View File

@@ -271,7 +271,10 @@ def choose_tool(
should_stream_answer
and not agent_config.behavior.skip_gen_ai_answer_generation,
writer,
)
).ai_message_chunk
if tool_message is None:
raise ValueError("No tool message emitted by LLM")
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:

View File

@@ -4,6 +4,7 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
@@ -21,6 +22,7 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -62,7 +64,9 @@ def basic_use_tool_response(
for section in dedupe_documents(search_response_summary.top_sections)[0]
]
new_tool_call_chunk = AIMessageChunk(content="")
new_tool_call_chunk = BasicSearchProcessedStreamResults(
ai_message_chunk=AIMessageChunk(content=""), full_answer=None
)
if not agent_config.behavior.skip_gen_ai_answer_generation:
stream = llm.stream(
prompt=new_prompt_builder.build(),
@@ -80,4 +84,9 @@ def basic_use_tool_response(
displayed_search_results=initial_search_results or final_search_results,
)
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
return BasicOutput(
tool_call_chunk=new_tool_call_chunk.ai_message_chunk,
full_answer=new_tool_call_chunk.full_answer,
cited_references=new_tool_call_chunk.cited_references,
retrieved_documents=new_tool_call_chunk.retrieved_documents,
)

View File

@@ -18,6 +18,8 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
)
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
@@ -41,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
GraphInput = BasicInput | MainInput | DCMainInput | KBMainInput | DRMainInput
_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -90,7 +94,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
graph_input: GraphInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -104,7 +108,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput | KBMainInput,
input: GraphInput,
) -> AnswerStream:
for event in manage_sync_streaming(
@@ -154,7 +158,24 @@ def run_kb_graph(
) -> AnswerStream:
graph = kb_graph_builder()
compiled_graph = graph.compile()
input = KBMainInput(log_messages=[])
input = KBMainInput(
log_messages=[], question=config.inputs.prompt_builder.raw_user_query
)
yield ToolCallKickoff(
tool_name="agent_search_0",
tool_args={"query": config.inputs.prompt_builder.raw_user_query},
)
yield from run_graph(compiled_graph, config, input)
def run_dr_graph(
config: GraphConfig,
) -> AnswerStream:
graph = dr_graph_builder()
compiled_graph = graph.compile()
input = DRMainInput(log_messages=[])
yield ToolCallKickoff(
tool_name="agent_search_0",

View File

@@ -1,12 +1,28 @@
import re
from datetime import datetime
from typing import cast
from typing import Literal
from typing import Type
from typing import TypeVar
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from litellm import get_supported_openai_params
from litellm import supports_response_schema
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.utils.threadpool_concurrency import run_with_timeout
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# match ```json{...}``` or ```{...}```
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
def stream_llm_answer(
@@ -66,3 +82,111 @@ def stream_llm_answer(
response.append(content)
return response, dispatch_timings
def invoke_llm_json(
    llm: LLM,
    prompt: LanguageModelInput,
    schema: Type[SchemaType],
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
    timeout_override: int | None = None,
    max_tokens: int | None = None,
) -> SchemaType:
    """Invoke an LLM, forcing it to respond in a specified JSON format if
    possible, and return an object of that schema.

    If the model/provider does not support native JSON-schema responses, the
    free-form answer is scrubbed and the JSON object is carved out of it
    before validation.
    """
    model = llm.config.model_name
    provider = llm.config.model_provider

    # Native structured output needs both the response_format parameter and
    # json-schema support from the provider/model combination.
    supported_params = get_supported_openai_params(model, provider) or []
    can_force_json = "response_format" in supported_params and supports_response_schema(
        model, provider
    )

    structured_kwargs = {"structured_response_format": schema} if can_force_json else {}
    raw_text = str(
        llm.invoke(
            prompt,
            tools=tools,
            tool_choice=tool_choice,
            timeout_override=timeout_override,
            max_tokens=max_tokens,
            **cast(dict, structured_kwargs),
        ).content
    )

    if not can_force_json:
        # Newlines frequently break json decoding of free-form answers.
        raw_text = raw_text.replace("\n", " ")
        # Prefer a fenced ```json ...``` block; otherwise fall back to the
        # outermost brace pair and hope the prompt elicited a json object.
        fenced = JSON_PATTERN.search(raw_text)
        if fenced:
            raw_text = fenced.group(1)
        else:
            start = raw_text.find("{")
            end = raw_text.rfind("}")
            raw_text = raw_text[start : end + 1]

    return schema.model_validate_json(raw_text)
def get_answer_from_llm(
    llm: LLM,
    prompt: str,
    timeout: int = 25,
    timeout_override: int = 5,
    max_tokens: int = 500,
    stream: bool = False,
    writer: StreamWriter = lambda _: None,
    agent_answer_level: int = 0,
    agent_answer_question_num: int = 0,
    agent_answer_type: Literal[
        "agent_sub_answer", "agent_level_answer"
    ] = "agent_level_answer",
    json_string_flag: bool = False,
) -> str:
    """Get an answer string from the LLM for a single human prompt.

    When ``stream`` is True the answer is also dispatched to ``writer`` via
    ``stream_llm_answer``; otherwise a blocking ``llm.invoke`` is used. Both
    paths are wrapped in ``run_with_timeout`` with ``timeout`` seconds.

    When ``json_string_flag`` is True, markdown json fences and newlines are
    stripped and the outermost ``{...}`` object is extracted from the answer.
    """
    msg = [
        HumanMessage(
            content=prompt,
        )
    ]

    if stream:
        # TODO - adjust for new UI. This is currently not working for current UI/Basic Search
        stream_response, _ = run_with_timeout(
            timeout,
            lambda: stream_llm_answer(
                llm=llm,
                prompt=msg,
                event_name="sub_answers",
                writer=writer,
                agent_answer_level=agent_answer_level,
                agent_answer_question_num=agent_answer_question_num,
                agent_answer_type=agent_answer_type,
                timeout_override=timeout_override,
                max_tokens=max_tokens,
            ),
        )
        content = "".join(stream_response)
    else:
        llm_response = run_with_timeout(
            timeout,
            llm.invoke,
            prompt=msg,
            timeout_override=timeout_override,
            max_tokens=max_tokens,
        )
        content = str(llm_response.content)

    cleaned_response = content
    if json_string_flag:
        cleaned_response = (
            str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        # Only slice when a well-ordered brace pair actually exists. The
        # previous unconditional slice turned "no braces found" (-1 from
        # find/rfind) into an accidental empty or garbled substring.
        if first_bracket != -1 and last_bracket > first_bracket:
            cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
    return cleaned_response

View File

@@ -0,0 +1,39 @@
from typing import Any
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.context.search.models import InferenceSection
def create_citation_format_list(
    document_citations: list[InferenceSection],
) -> list[dict[str, Any]]:
    """Convert inference sections into the flat citation-dict format used
    when persisting/logging cited documents."""
    return [
        {
            "link": "",
            "blurb": citation.center_chunk.blurb,
            "content": citation.center_chunk.content,
            "metadata": citation.center_chunk.metadata,
            "updated_at": str(citation.center_chunk.updated_at),
            "document_id": citation.center_chunk.document_id,
            "source_type": "file",
            "source_links": citation.center_chunk.source_links,
            "match_highlights": citation.center_chunk.match_highlights,
            "semantic_identifier": citation.center_chunk.semantic_identifier,
        }
        for citation in document_citations
    ]
def create_question_prompt(
    system_prompt: str | None, human_prompt: str
) -> list[BaseMessage]:
    """Build a two-message chat prompt: a (possibly empty) system message
    followed by the user question."""
    system_message = SystemMessage(content=system_prompt or "")
    question_message = HumanMessage(content=human_prompt)
    return [system_message, question_message]

View File

@@ -4,6 +4,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
@@ -12,7 +13,7 @@ from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.agents.agent_search.run_graph import run_dr_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
@@ -27,6 +28,7 @@ from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.models import RerankingDetails
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import Persona
@@ -124,6 +126,9 @@ class Answer:
allow_agent_reranking=allow_agent_reranking,
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
kg_config_settings=get_kg_config_settings(),
research_type=(
ResearchType.DEEP if use_agentic_search else ResearchType.THOUGHTFUL
),
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,
@@ -138,12 +143,16 @@ class Answer:
yield from self._processed_stream
return
if self.graph_config.behavior.use_agentic_search and (
# TODO: add toggle in UI with customizable TimeBudget
if (
self.graph_config.inputs.persona
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
and self.graph_config.inputs.persona.name.startswith("KG Beta")
and self.graph_config.inputs.persona.name.startswith(
TMP_DRALPHA_PERSONA_NAME
)
):
run_langgraph = run_kb_graph
run_langgraph = run_dr_graph
elif self.graph_config.behavior.use_agentic_search:
run_langgraph = run_agent_search_graph
elif (

View File

@@ -19,6 +19,7 @@ from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
@@ -401,7 +402,7 @@ def process_kg_commands(
) -> None:
# Temporarily, until we have a draft UI for the KG Operations/Management
# TODO: move to api endpoint once we get frontend
if not persona_name.startswith("KG Beta"):
if not persona_name.startswith(TMP_DRALPHA_PERSONA_NAME):
return
kg_config_settings = get_kg_config_settings()

View File

@@ -54,6 +54,7 @@ from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
@@ -845,6 +846,18 @@ def stream_chat_message_objects(
error: str | None,
tool_call: ToolCall | None,
) -> ChatMessage:
is_kg_beta = parent_message.chat_session.persona.name.startswith(
TMP_DRALPHA_PERSONA_NAME
)
is_basic_search = tool_call and tool_call.tool_name == SearchTool._NAME
is_agentic_overwrite = new_msg_req.use_agentic_search and not (
is_kg_beta and is_basic_search
)
if is_kg_beta:
is_agentic_overwrite = False
return create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=(
@@ -867,7 +880,7 @@ def stream_chat_message_objects(
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
is_agentic=is_agentic_overwrite,
)
partial_response = create_response

View File

@@ -3,6 +3,7 @@ import socket
from enum import auto
from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
@@ -138,6 +139,8 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
class DocumentSource(str, Enum):
# Special case, document passed in via Onyx APIs without specifying a source type
@@ -512,3 +515,57 @@ else:
class OnyxCallTypes(str, Enum):
FIREFLIES = "FIREFLIES"
GONG = "GONG"
# TODO: this should be stored likely in database
DocumentSourceDescription: dict[DocumentSource, str] = {
# Special case, document passed in via Onyx APIs without specifying a source type
DocumentSource.INGESTION_API: "ingestion_api",
DocumentSource.SLACK: "slack channels",
DocumentSource.WEB: "web pages",
DocumentSource.GOOGLE_DRIVE: "google drive documents (docs, sheets, etc.)",
DocumentSource.GMAIL: "email messages",
DocumentSource.REQUESTTRACKER: "requesttracker",
DocumentSource.GITHUB: "github data",
DocumentSource.GITBOOK: "gitbook data",
DocumentSource.GITLAB: "gitlab data",
DocumentSource.GURU: "guru data",
DocumentSource.BOOKSTACK: "bookstack data",
DocumentSource.CONFLUENCE: "confluence data (pages, spaces, etc.)",
DocumentSource.JIRA: "jira data (issues, tickets, projects, etc.)",
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",
DocumentSource.ZULIP: "zulip data",
DocumentSource.LINEAR: "linear data - project management tool, including tickets etc.",
DocumentSource.HUBSPOT: "hubspot data - CRM and marketing automation data",
DocumentSource.DOCUMENT360: "document360 data",
DocumentSource.GONG: "gong - call transcripts",
DocumentSource.GOOGLE_SITES: "google_sites - websites",
DocumentSource.ZENDESK: "zendesk - customer support data",
DocumentSource.LOOPIO: "loopio - rfp data",
DocumentSource.DROPBOX: "dropbox - files",
DocumentSource.SHAREPOINT: "sharepoint - files",
DocumentSource.TEAMS: "teams - chat and collaboration",
DocumentSource.SALESFORCE: "salesforce - CRM data",
DocumentSource.DISCOURSE: "discourse - discussion forums",
DocumentSource.AXERO: "axero - employee engagement data",
DocumentSource.CLICKUP: "clickup - project management tool",
DocumentSource.MEDIAWIKI: "mediawiki - wiki data",
DocumentSource.WIKIPEDIA: "wikipedia - encyclopedia data",
DocumentSource.ASANA: "asana",
DocumentSource.S3: "s3",
DocumentSource.R2: "r2",
DocumentSource.GOOGLE_CLOUD_STORAGE: "google_cloud_storage - cloud storage",
DocumentSource.OCI_STORAGE: "oci_storage - cloud storage",
DocumentSource.XENFORO: "xenforo - forum data",
DocumentSource.DISCORD: "discord - chat and collaboration",
DocumentSource.FRESHDESK: "freshdesk - customer support data",
DocumentSource.FIREFLIES: "fireflies - call transcripts",
DocumentSource.EGNYTE: "egnyte - files",
DocumentSource.AIRTABLE: "airtable - database",
DocumentSource.HIGHSPOT: "highspot - CRM data",
DocumentSource.IMAP: "imap - email data",
}

View File

@@ -140,3 +140,5 @@ KG_MAX_SEARCH_DOCUMENTS: int = int(os.environ.get("KG_MAX_SEARCH_DOCUMENTS", "15
KG_MAX_DECOMPOSITION_SEGMENTS: int = int(
os.environ.get("KG_MAX_DECOMPOSITION_SEGMENTS", "10")
)
KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
to answer questions"

View File

View File

@@ -1,7 +1,6 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from typing import Tuple
from uuid import UUID
@@ -23,12 +22,12 @@ from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetr
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.utils import create_citation_format_list
from onyx.auth.schemas import UserRole
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDocs
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc as ServerSearchDoc
@@ -1111,27 +1110,6 @@ def log_agent_sub_question_results(
primary_message_id: int | None,
sub_question_answer_results: list[SubQuestionAnswerResults],
) -> None:
def _create_citation_format_list(
document_citations: list[InferenceSection],
) -> list[dict[str, Any]]:
citation_list: list[dict[str, Any]] = []
for document_citation in document_citations:
document_citation_dict = {
"link": "",
"blurb": document_citation.center_chunk.blurb,
"content": document_citation.center_chunk.content,
"metadata": document_citation.center_chunk.metadata,
"updated_at": str(document_citation.center_chunk.updated_at),
"document_id": document_citation.center_chunk.document_id,
"source_type": "file",
"source_links": document_citation.center_chunk.source_links,
"match_highlights": document_citation.center_chunk.match_highlights,
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
}
citation_list.append(document_citation_dict)
return citation_list
now = datetime.now()
@@ -1141,7 +1119,7 @@ def log_agent_sub_question_results(
]
sub_question = sub_question_answer_result.question
sub_answer = sub_question_answer_result.answer
sub_document_results = _create_citation_format_list(
sub_document_results = create_citation_format_list(
sub_question_answer_result.context_documents
)

View File

@@ -82,6 +82,7 @@ from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
from onyx.agents.agent_search.dr.enums import ResearchType
logger = setup_logger()
@@ -677,8 +678,8 @@ class KGEntityType(Base):
DateTime(timezone=True), server_default=func.now()
)
grounded_source_name: Mapped[str] = mapped_column(
NullFilteredString, nullable=False, index=False
grounded_source_name: Mapped[str | None] = mapped_column(
NullFilteredString, nullable=True, index=False
)
entity_values: Mapped[list[str]] = mapped_column(
@@ -2145,6 +2146,11 @@ class ChatMessage(Base):
back_populates="chat_messages",
)
research_type: Mapped[ResearchType] = mapped_column(
Enum(ResearchType, native_enum=False), nullable=True
)
research_plan: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
class ChatFolder(Base):
"""For organizing chat sessions"""
@@ -3343,3 +3349,40 @@ class TenantAnonymousUserPath(Base):
anonymous_user_path: Mapped[str] = mapped_column(
String, nullable=False, unique=True
)
class ResearchAgentIteration(Base):
    """One top-level iteration of a Deep Research agent run.

    Each row records the purpose and reasoning of a single research
    iteration, keyed to the chat message that initiated the run.
    """

    __tablename__ = "research_agent_iteration"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Chat message that initiated the research run; rows are cascade-deleted
    # together with that message.
    primary_question_id: Mapped[int] = mapped_column(
        ForeignKey("chat_message.id", ondelete="CASCADE")
    )
    # Index of this iteration within the run — numbering base (0 or 1) is not
    # visible here; TODO confirm against the writer.
    iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
    # NOTE(review): no server_default here (unlike other models in this file
    # that use func.now()) — callers must populate created_at explicitly.
    created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
    # Free-text purpose of this iteration (optional).
    purpose: Mapped[str] = mapped_column(String, nullable=True)
    # Free-text reasoning the agent produced for this iteration (optional).
    reasoning: Mapped[str] = mapped_column(String, nullable=True)
class ResearchAgentIterationSubStep(Base):
    """A single sub-step (e.g. one tool invocation) within a research iteration.

    Sub-steps belong to a (primary_question_id, iteration_nr) pair and are
    ordered by iteration_sub_step_nr; they may nest via parent_question_id.
    """

    __tablename__ = "research_agent_iteration_sub_step"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Chat message that initiated the research run; rows are cascade-deleted
    # together with that message.
    primary_question_id: Mapped[int] = mapped_column(
        ForeignKey("chat_message.id", ondelete="CASCADE")
    )
    # Optional self-referential parent sub-step (for nested sub-questions);
    # cascade-deleted with its parent.
    parent_question_id: Mapped[int | None] = mapped_column(
        ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
        nullable=True,
    )
    # Iteration this sub-step belongs to.
    iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
    # Position of this sub-step within its iteration.
    iteration_sub_step_nr: Mapped[int] = mapped_column(Integer, nullable=False)
    # NOTE(review): no server_default — callers must set created_at explicitly.
    created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
    # Instructions given for this sub-step (optional free text).
    sub_step_instructions: Mapped[str] = mapped_column(String, nullable=True)
    # Tool used for this sub-step, if any.
    sub_step_tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), nullable=True)
    # Reasoning the agent produced for this sub-step (optional).
    reasoning: Mapped[str] = mapped_column(String, nullable=True)
    # Answer produced by the sub-step (optional).
    sub_answer: Mapped[str] = mapped_column(String, nullable=True)
    # NOTE(review): no explicit nullable= here, unlike the other JSONB columns
    # in this class — confirm whether cited_doc_results is meant to be NOT NULL.
    cited_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
    # Claims extracted during this sub-step (optional).
    claims: Mapped[list[str]] = mapped_column(postgresql.JSONB(), nullable=True)
    # Arbitrary extra data for the sub-step (optional).
    additional_data: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)

View File

@@ -16,7 +16,8 @@ from onyx.db.models import User
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.tools.built_in_tools import get_search_tool
from onyx.tools.built_in_tools import get_builtin_tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -49,9 +50,7 @@ def create_slack_channel_persona(
) -> Persona:
"""NOTE: does not commit changes"""
search_tool = get_search_tool(db_session)
if search_tool is None:
raise ValueError("Search tool not found")
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)

View File

@@ -47,15 +47,15 @@ logger = setup_logger()
def _get_classification_extraction_instructions() -> (
dict[str, dict[str, KGEntityTypeInstructions]]
dict[str | None, dict[str, KGEntityTypeInstructions]]
):
"""
Prepare the classification instructions for the given source.
"""
classification_instructions_dict: dict[str, dict[str, KGEntityTypeInstructions]] = (
{}
)
classification_instructions_dict: dict[
str | None, dict[str, KGEntityTypeInstructions]
] = {}
with get_session_with_current_tenant() as db_session:
entity_types = get_entity_types(db_session, active=True)

View File

@@ -32,9 +32,7 @@ def format_entity_id_for_models(entity_id_name: str) -> str:
separator = entity_type = ""
formatted_entity_type = entity_type.strip().upper()
formatted_entity_name = (
entity_name.strip().replace('"', "").replace("'", "").title()
)
formatted_entity_name = entity_name.strip().replace('"', "").replace("'", "")
return f"{formatted_entity_type}{separator}{formatted_entity_name}"

File diff suppressed because it is too large Load Diff

View File

@@ -669,8 +669,8 @@ that should be used to analyze each object/each source (or 'the object' that fit
}}
Do not include any other text or explanations.
"""
SOURCE_DETECTION_PROMPT = f"""
You are an expert in generating, understanding and analyzing SQL statements.
@@ -773,11 +773,29 @@ Please structure your answer using <reasoning>, </reasoning>,<sql>, </sql> start
""".strip()
SIMPLE_SQL_PROMPT = f"""
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
between TWO ENTITIES. The table has the following structure:
ENTITY_TABLE_DESCRIPTION = f"""\
- Table name: entity_table
- Columns:
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
- entity_type (str): the type of the entity [example: ACCOUNT].
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
- source_document (str): the id of the document that contains the entity. Note that the combination of \
id_name and source_document IS UNIQUE!
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
{SEPARATOR_LINE}
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
the entity type may also often be referred to.
{SEPARATOR_LINE}
---entity_types---
{SEPARATOR_LINE}
"""
RELATIONSHIP_TABLE_DESCRIPTION = f"""\
- Table name: relationship_table
- Columns:
- relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \
@@ -803,17 +821,27 @@ id_name and source_document IS UNIQUE!
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
their values, if provided.
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
the entity type may also often be referred to.
{SEPARATOR_LINE}
---entity_types---
{SEPARATOR_LINE}
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>:
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>.
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by '::<entity_name>' \
in the relationship id as shown above.
{SEPARATOR_LINE}
---relationship_types---
{SEPARATOR_LINE}
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by ':<entity_name>' in the \
relationship id as shown above..
"""
SIMPLE_SQL_PROMPT = f"""
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
between TWO ENTITIES. The table has the following structure:
{SEPARATOR_LINE}
{RELATIONSHIP_TABLE_DESCRIPTION}
Here is the question you are supposed to translate into a SQL statement:
{SEPARATOR_LINE}
@@ -936,7 +964,7 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
<sql>[the SQL statement that you generate to satisfy the task]</sql>
""".strip()
# TODO: remove following before merging after enough testing
SIMPLE_SQL_CORRECTION_PROMPT = f"""
You are an expert in reviewing and fixing SQL statements.
@@ -949,7 +977,7 @@ Guidance:
SELECT statement as well! And it needs to be in the EXACT FORM! So if a \
conversion took place, make sure to include the conversion in the SELECT and the ORDER BY clause!
- never should 'source_document' be in the SELECT clause! Remove if present!
- if there are joins, they must be on entities, never sour ce documents
- if there are joins, they must be on entities, never source documents
- if there are joins, consider the possibility that the second entity does not exist for all examples.\
Therefore consider using LEFT joins (or RIGHT joins) as appropriate.
@@ -969,26 +997,7 @@ You are an expert in generating a SQL statement that only uses ONE TABLE that ca
and their attributes and other data. The table has the following structure:
{SEPARATOR_LINE}
- Table name: entity_table
- Columns:
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
- entity_type (str): the type of the entity [example: ACCOUNT].
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
- source_document (str): the id of the document that contains the entity. Note that the combination of \
id_name and source_document IS UNIQUE!
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
{SEPARATOR_LINE}
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
the entity type may also often be referred to.
{SEPARATOR_LINE}
---entity_types---
{SEPARATOR_LINE}
{ENTITY_TABLE_DESCRIPTION}
Here is the question you are supposed to translate into a SQL statement:
{SEPARATOR_LINE}
@@ -1077,33 +1086,55 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
<sql>[the SQL statement that you generate to satisfy the task]</sql>
""".strip()
SIMPLE_SQL_ERROR_FIX_PROMPT = f"""
You are an expert at fixing SQL statements. You will be provided with a SQL statement that aims to address \
a question, but it contains an error. Your task is to fix the SQL statement, based on the error message.
SQL_AGGREGATION_REMOVAL_PROMPT = f"""
You are a SQL expert. You were provided with a SQL statement that returns an aggregation, and you are \
tasked to show the underlying objects that were aggregated. For this you need to remove the aggregate functions \
from the SQL statement in the correct way.
Here is the description of the table that the SQL statement is supposed to use:
---table_description---
Additional rules:
- if you see a 'select count(*)', you should NOT convert \
that to 'select *...', but rather return the corresponding id_name, entity_type_id_name, name, and document_id. \
As in: 'select <table, if necessary>.id_name, <table, if necessary>.entity_type_id_name, \
<table, if necessary>.name, <table, if necessary>.document_id ...'. \
The id_name is always the primary index, and those should be returned, along with the type (entity_type_id_name), \
the name (name) of the objects, and the document_id (document_id) of the object.
- Add a limit of 30 to the select statement.
- Don't change anything else.
- The final select statement needs obviously to be a valid SQL statement.
Here is the question you are supposed to translate into a SQL statement:
{SEPARATOR_LINE}
---question---
{SEPARATOR_LINE}
Here is the SQL statement you are supposed to remove the aggregate functions from:
Here is the SQL statement that you should fix:
{SEPARATOR_LINE}
---sql_statement---
{SEPARATOR_LINE}
Here is the error message that was returned:
{SEPARATOR_LINE}
---error_message---
{SEPARATOR_LINE}
Note that in the case the error states the sql statement did not return any results, it is possible that the \
sql statement is correct, but the question is not addressable with the information in the knowledge graph. \
If you are absolutely certain that is the case, you may return the original sql statement.
Here are a couple common errors that you may encounter:
- source_document is in the SELECT clause -> remove it
- columns used in ORDER BY must also appear in the SELECT DISTINCT clause
- consider carefully the type of the columns you are using, especially for attributes. You may need to cast them
- dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like the attributes! \
So please use that format, particularly if you use data comparisons (>, <, ...)
- attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \
"attributes ->> '<attribute>' = '<attribute value>'" (or "attributes ? '<attribute>'" to check for existence).
- if you are using joins and the sql returned no joins, make sure you are using the appropriate join type (LEFT, RIGHT, etc.) \
it is possible that the second entity does not exist for all examples.
- (ignore if using entity_table) if using the relationship_table and the sql returned no results, make sure you are \
selecting the correct column! Use the available relationship types to determine whether to use the source or target entity.
APPROACH:
Please think through this step by step. Please also bear in mind that the sql statement is written in postgres syntax.
Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---.
Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> start and end tags as in:
<reasoning>[your short step-by step thinking]</reasoning>
<sql>[the SQL statement without the aggregate functions]</sql>
""".strip()
<reasoning>[think through the logic but do so extremely briefly! Not more than 3-4 sentences.]</reasoning>
<sql>[the SQL statement that you generate to satisfy the task]</sql>
"""
SEARCH_FILTER_CONSTRUCTION_PROMPT = f"""

View File

@@ -0,0 +1,43 @@
import re
class PromptTemplate:
    """
    Template builder for prompts that use ``---field---`` style placeholders.

    Handy for templates containing JSON schemas, where literal ``{}`` braces
    make f-strings impractical. Unlike plain ``str.replace``, ``build`` raises
    if any placeholder in the template is left unfilled.
    """

    DEFAULT_PATTERN = r"---([a-zA-Z0-9_]+)---"

    def __init__(self, template: str, pattern: str = DEFAULT_PATTERN):
        self._pattern_str = pattern
        self._pattern = re.compile(pattern)
        self._template = template
        # Every placeholder name that appears in the template.
        self._fields: set[str] = set(self._pattern.findall(template))

    def build(self, **kwargs: str) -> str:
        """
        Fill in all placeholders and return the finished prompt string.

        Raises ValueError when any placeholder is missing from kwargs.
        Kwargs that do not correspond to a placeholder are ignored.
        """
        missing = self._fields.difference(kwargs)
        if missing:
            raise ValueError(f"Missing required fields: {missing}.")
        return self._replace_fields(kwargs)

    def partial_build(self, **kwargs: str) -> "PromptTemplate":
        """
        Substitute only the supplied placeholders, returning a new
        PromptTemplate with the remaining placeholders intact.
        Kwargs that do not correspond to a placeholder are ignored.
        """
        return PromptTemplate(self._replace_fields(kwargs), self._pattern_str)

    def _replace_fields(self, field_vals: dict[str, str]) -> str:
        # Placeholders without a supplied value are left untouched
        # (m.group(0) is the full ---name--- text).
        return self._pattern.sub(
            lambda m: field_vals.get(m.group(1), m.group(0)),
            self._template,
        )

View File

@@ -3,6 +3,8 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session
from onyx.db.entities import get_entity_stats_by_grounded_source_name
@@ -31,12 +33,13 @@ from onyx.server.kg.models import KGConfig
from onyx.server.kg.models import KGConfig as KGConfigAPIModel
from onyx.server.kg.models import SourceAndEntityTypeView
from onyx.server.kg.models import SourceStatistics
from onyx.tools.built_in_tools import get_search_tool
from onyx.tools.built_in_tools import get_builtin_tool
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
_KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
to answer questions"
admin_router = APIRouter(prefix="/admin/kg")
@@ -95,12 +98,9 @@ def enable_or_disable_kg(
enable_kg(enable_req=req)
populate_missing_default_entity_types__commit(db_session=db_session)
# Create or restore KG Beta persona
# Get the search tool
search_tool = get_search_tool(db_session=db_session)
if not search_tool:
raise RuntimeError("SearchTool not found in the database.")
# Get the search and knowledge graph tools
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
kg_tool = get_builtin_tool(db_session=db_session, tool_type=KnowledgeGraphTool)
# Check if we have a previously created persona
kg_config_settings = get_kg_config_settings()
@@ -132,8 +132,8 @@ def enable_or_disable_kg(
is_public = len(user_ids) == 0
persona_request = PersonaUpsertRequest(
name="KG Beta",
description=_KG_BETA_ASSISTANT_DESCRIPTION,
name=TMP_DRALPHA_PERSONA_NAME,
description=KG_BETA_ASSISTANT_DESCRIPTION,
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
datetime_aware=False,
@@ -145,7 +145,7 @@ def enable_or_disable_kg(
recency_bias=RecencyBiasSetting.NO_DECAY,
prompt_ids=[0],
document_set_ids=[],
tool_ids=[search_tool.id],
tool_ids=[search_tool.id, kg_tool.id],
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@@ -17,6 +17,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.internet_search.providers import (
get_available_providers,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
@@ -63,6 +66,15 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
if (bool(get_available_providers()))
else []
),
InCodeToolInfo(
    cls=KnowledgeGraphTool,
    description=(
        # Trailing space added: adjacent string literals concatenate with no
        # separator, so without it the text reads "...information.This tool...".
        "The Knowledge Graph Search Action allows the assistant to search the knowledge graph for information. "
        "This tool should only be used by the Deep Research Agent, not via tool calling."
    ),
    in_code_tool_id=KnowledgeGraphTool.__name__,
    display_name=KnowledgeGraphTool._DISPLAY_NAME,
),
]
@@ -106,27 +118,37 @@ def load_builtin_tools(db_session: Session) -> None:
logger.notice("All built-in tools are loaded/verified.")
def get_search_tool(db_session: Session) -> ToolDBModel | None:
def get_builtin_tool(
db_session: Session,
tool_type: Type[
SearchTool | ImageGenerationTool | InternetSearchTool | KnowledgeGraphTool
],
) -> ToolDBModel:
"""
Retrieves for the SearchTool from the BUILT_IN_TOOLS list.
Retrieves a built-in tool from the database based on the tool type.
"""
search_tool_id = next(
tool_id = next(
(
tool["in_code_tool_id"]
for tool in BUILT_IN_TOOLS
if tool["cls"].__name__ == SearchTool.__name__
if tool["cls"].__name__ == tool_type.__name__
),
None,
)
if not search_tool_id:
raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")
if not tool_id:
raise RuntimeError(
f"Tool type {tool_type.__name__} not found in the BUILT_IN_TOOLS list."
)
search_tool = db_session.execute(
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
db_tool = db_session.execute(
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == tool_id)
).scalar_one_or_none()
return search_tool
if not db_tool:
raise RuntimeError(f"Tool type {tool_type.__name__} not found in the database.")
return db_tool
def auto_add_search_tool_to_personas(db_session: Session) -> None:
@@ -136,10 +158,7 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None:
Persona objects that were created before the concept of Tools were added.
"""
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
search_tool = get_search_tool(db_session)
if not search_tool:
raise RuntimeError("SearchTool not found in the database.")
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
# Fetch all Personas that need the SearchTool added
personas_to_update = (

View File

@@ -20,6 +20,11 @@ OVERRIDE_T = TypeVar("OVERRIDE_T")
class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def id(self) -> int:
raise NotImplementedError
@property
@abc.abstractmethod
def name(self) -> str:

View File

@@ -16,6 +16,7 @@ from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
@@ -41,6 +42,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import compute_all_tool_tokens
from onyx.tools.utils import explicit_tool_calling_supported
@@ -265,6 +269,14 @@ def construct_tools(
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
)
# Handle KG Tool
elif tool_cls.__name__ == KnowledgeGraphTool.__name__:
if persona.name != TMP_DRALPHA_PERSONA_NAME:
raise ValueError(
f"Knowledge Graph Tool should only be used by the '{TMP_DRALPHA_PERSONA_NAME}' Agent."
)
tool_dict[db_tool_model.id] = [KnowledgeGraphTool()]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:

View File

@@ -17,6 +17,8 @@ from requests import JSONDecodeError
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tools
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
@@ -77,6 +79,7 @@ class CustomToolCallSummary(BaseModel):
class CustomTool(BaseTool):
def __init__(
self,
id: int,
method_spec: MethodSpec,
base_url: str,
custom_headers: list[HeaderItemDict] | None = None,
@@ -86,6 +89,7 @@ class CustomTool(BaseTool):
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()
self._user_oauth_token = user_oauth_token
self._id = id
self._name = self._method_spec.name
self._description = self._method_spec.summary
@@ -107,6 +111,10 @@ class CustomTool(BaseTool):
if self._user_oauth_token:
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._name
@@ -382,11 +390,27 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
openapi_schema_str = json.dumps(openapi_schema)
with get_session_with_current_tenant() as temp_db_session:
tools = get_tools(temp_db_session)
tool_id: int | None = None
for tool in tools:
if tool.openapi_schema and (
json.dumps(tool.openapi_schema) == openapi_schema_str
):
tool_id = tool.id
break
if not tool_id:
raise ValueError(f"Tool with openapi_schema {openapi_schema_str} not found")
return [
CustomTool(
method_spec,
url,
custom_headers,
id=tool_id,
method_spec=method_spec,
base_url=url,
custom_headers=custom_headers,
user_oauth_token=user_oauth_token,
)
for method_spec in method_specs

View File

@@ -13,6 +13,8 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Tool as ToolDBModel
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
@@ -112,6 +114,22 @@ class ImageGenerationTool(Tool[None]):
self.additional_headers = additional_headers
self.output_format = output_format
with get_session_with_current_tenant() as db_session:
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == ImageGenerationTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError(
"Image Generation tool not found. This should never happen."
)
self._id = tool_id
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME

View File

@@ -29,6 +29,7 @@ from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.db.search_settings import get_current_search_settings
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
@@ -143,8 +144,23 @@ class InternetSearchTool(Tool[None]):
)
)
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == InternetSearchTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError(
"Internet Search tool not found. This should never happen."
)
self._id = tool_id
"""For explicit tool calling"""
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME

View File

@@ -0,0 +1,118 @@
from collections.abc import Generator
from typing import Any
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Tool as ToolDBModel
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
QUERY_FIELD = "query"
class KnowledgeGraphTool(Tool[None]):
    """
    Registered placeholder for knowledge-graph search.

    The Deep Research Agent references this tool by its database id; every
    regular tool-calling entry point raises on purpose, because the tool must
    never be invoked through the standard tool-calling flow.
    """

    _NAME = "run_kg_search"
    _DESCRIPTION = "Search the knowledge graph for information. Never call this tool."
    _DISPLAY_NAME = "Knowledge Graph Search"

    def __init__(self) -> None:
        # Resolve this built-in tool's primary key from the database so that
        # `id` matches the persisted Tool row.
        with get_session_with_current_tenant() as db_session:
            db_id: int | None = (
                db_session.query(ToolDBModel.id)
                .filter(ToolDBModel.in_code_tool_id == KnowledgeGraphTool.__name__)
                .scalar()
            )
            if not db_id:
                raise ValueError(
                    "Knowledge Graph tool not found. This should never happen."
                )
            self._id = db_id

    @property
    def id(self) -> int:
        return self._id

    @property
    def name(self) -> str:
        return self._NAME

    @property
    def description(self) -> str:
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        return self._DISPLAY_NAME

    def _tool_calling_disallowed(self) -> ValueError:
        # Single source for the rejection raised by every entry point below.
        return ValueError(
            "KnowledgeGraphTool should only be used by the Deep Research Agent, "
            "not via tool calling."
        )

    def tool_definition(self) -> dict:
        query_schema = {
            "type": "string",
            "description": "What to search for",
        }
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {QUERY_FIELD: query_schema},
                    "required": [QUERY_FIELD],
                },
            },
        }

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        raise self._tool_calling_disallowed()

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        raise self._tool_calling_disallowed()

    def run(
        self, override_kwargs: None = None, **kwargs: str
    ) -> Generator[ToolResponse, None, None]:
        raise self._tool_calling_disallowed()

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        raise self._tool_calling_disallowed()

    def build_next_prompt(
        self,
        prompt_builder: AnswerPromptBuilder,
        tool_call_summary: ToolCallSummary,
        tool_responses: list[ToolResponse],
        using_tool_calling_llm: bool,
    ) -> AnswerPromptBuilder:
        raise self._tool_calling_disallowed()

View File

@@ -34,6 +34,7 @@ from onyx.context.search.models import UserFileFilters
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.pipeline import section_relevance_list_impl
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.db.models import User
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -162,6 +163,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
)
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == SearchTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError("Search tool not found. This should never happen.")
self._id = tool_id
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME