mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-08 16:32:43 +00:00
Compare commits
6 Commits
v3.1.2
...
hackathon_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2de6e4c5a | ||
|
|
728f727976 | ||
|
|
32220d08d9 | ||
|
|
29522d81f7 | ||
|
|
97de25ba35 | ||
|
|
45f6c28605 |
@@ -0,0 +1,53 @@
|
||||
"""add cheat_sheet_context to user
|
||||
|
||||
Revision ID: 12h73u00mcwb
|
||||
Revises: a4f23d6b71c8
|
||||
Create Date: 2025-11-13 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "12h73u00mcwb"
|
||||
down_revision = "a4f23d6b71c8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``cheat_sheet_context`` JSONB column to ``user`` and
    create the ``temporary_user_cheat_sheet_context`` staging table."""

    # Nullable so existing user rows need no backfill.
    op.add_column(
        "user",
        sa.Column("cheat_sheet_context", postgresql.JSONB(), nullable=True),
    )

    def _timestamp_column(name: str) -> sa.Column:
        # Server-side now() default so timestamps are set without app code.
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    op.create_table(
        "temporary_user_cheat_sheet_context",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("context", postgresql.JSONB(), nullable=False),
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
        sa.PrimaryKeyConstraint("id"),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the upgrade: drop the staging table first, then the user column."""
    op.drop_table("temporary_user_cheat_sheet_context")
    op.drop_column("user", "cheat_sheet_context")
|
||||
@@ -0,0 +1,97 @@
|
||||
"""add subscription tables
|
||||
|
||||
Revision ID: 13a84b9c2d5f
|
||||
Revises: 12h73u00mcwb
|
||||
Create Date: 2025-12-12 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "13a84b9c2d5f"
|
||||
down_revision = "12h73u00mcwb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the hackathon subscription tables.

    ``subscription_registrations`` holds what each user subscribed to;
    ``subscription_results`` holds the notifications produced for them.
    Both tables share an integer PK, a FK to ``user.id``, and server-side
    created/updated timestamps.
    """

    def _timestamp_column(name: str) -> sa.Column:
        # now() server default keeps timestamps maintained by the database.
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    def _constraints() -> list:
        # PK + user FK shared by both tables (fresh objects per table).
        return [
            sa.PrimaryKeyConstraint("id"),
            sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
        ]

    # Create subscription_registrations table
    op.create_table(
        "subscription_registrations",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("doc_extraction_contexts", postgresql.JSONB(), nullable=False),
        sa.Column("search_questions", postgresql.ARRAY(sa.String()), nullable=False),
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
        *_constraints(),
    )

    # Create subscription_results table
    op.create_table(
        "subscription_results",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("type", sa.String(), nullable=False),
        sa.Column("notifications", postgresql.JSONB(), nullable=False),
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
        *_constraints(),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the upgrade: drop results before registrations (reverse creation order)."""
    op.drop_table("subscription_results")
    op.drop_table("subscription_registrations")
|
||||
@@ -0,0 +1,64 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.graph import END
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.states import MainState
|
||||
|
||||
|
||||
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    """Route to the next graph node based on the most recent entry in
    ``state.tools_used`` (either a registered tool name or a DRPath value).
    """
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")

    # Either a generic tool name or a DRPath string.
    chosen = state.tools_used[-1]
    tools = state.available_tools

    # Thinking already happened; hand control back to the orchestrator.
    if chosen == DRPath.THINKING.value:
        return DRPath.ORCHESTRATOR
    if chosen == DRPath.END.value:
        return END
    if not tools:
        raise ValueError("No tool is available. This should not happen.")

    if chosen not in tools:
        # Not a registered tool: let the bookkeeping paths through,
        # otherwise fall back to the orchestrator.
        if chosen == DRPath.LOGGER.value:
            return DRPath.LOGGER
        if chosen == DRPath.CLOSER.value:
            return DRPath.CLOSER
        return DRPath.ORCHESTRATOR

    next_tool_path = tools[chosen].path

    # handle invalid paths
    if next_tool_path == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")

    # Query-driven tools cannot run without at least one pending query.
    query_driven_paths = (
        DRPath.INTERNAL_SEARCH,
        DRPath.WEB_SEARCH,
        DRPath.KNOWLEDGE_GRAPH,
        DRPath.IMAGE_GENERATION,
    )
    if next_tool_path in query_driven_paths and len(state.query_list) == 0:
        return DRPath.CLOSER

    return next_tool_path
|
||||
|
||||
|
||||
def completeness_router(state: MainState) -> DRPath | str:
    """After the closer: loop back to the orchestrator, or proceed to logging."""
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")

    # The closer's verdict is recorded as the most recent "tool" used.
    latest = state.tools_used[-1]
    return (
        DRPath.ORCHESTRATOR
        if latest == DRPath.ORCHESTRATOR.value
        else DRPath.LOGGER
    )
|
||||
31
backend/onyx/agents/agent_search/exploration/constants.py
Normal file
31
backend/onyx/agents/agent_search/exploration/constants.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType
|
||||
|
||||
# How many prior chat exchanges are supplied as conversational context.
MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)

# Maximum number of deep-research searches dispatched in parallel.
MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)

# Prefixes used to tag/recognize special assistant messages.
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"

# Relative per-call cost of each tool in abstract budget points — presumably
# consumed against DR_TIME_BUDGET_BY_TYPE below (confirm against orchestrator).
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.WEB_SEARCH: 1.5,
    DRPath.IMAGE_GENERATION: 3.0,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}

# Total budget (same abstract points) granted per research type.
DR_TIME_BUDGET_BY_TYPE: dict[ResearchType, float] = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 12.0,
    ResearchType.FAST: 0.5,
}
|
||||
@@ -0,0 +1,141 @@
|
||||
# Base system prompt; the ---placeholder--- markers are substituted by the
# prompt builder before use.
BASE_SYSTEM_MESSAGE_TEMPLATE = """
You are a helpful assistant that can answer questions and help with tasks.

You should answer the user's question based on their needs. Their needs and directions \
are specified as follows:

###
---user_prompt---
###

The current date is ---current_date---.


The answer process may be complex and may involve multiple tool calls. Here is \
a description of the tools that MAY be available to you throughout the conversation \
(note though that not all tools may be available at all times, depending on the context):

###
---available_tool_descriptions_str---
###
---cheat_sheet_string---

Here are some principal reminders about how to answer the user's question:
- you will derive the answer through conversational steps with the user.
---plan_instruction_insertion---
- the answer should only be generated using the tool responses and if applicable, \
retrieved documents and information.

"""
|
||||
|
||||
# Asks the model to produce an up-front research plan; the placeholder carries
# any user-specific planning instructions.
PLAN_PROMPT_TEMPLATE = """
Now you should create a plan for how to address the user's question.

###
---user_plan_instructions_prompt---
###

"""
|
||||
|
||||
# Iterative decision prompt: asks the model to pick the next tool and the
# questions/requests to send to it.
ORCHESTRATOR_PROMPT_TEMPLATE = """
You need to consider the conversation thus far and see what you want to do next in order \
to answer the user's question/task.

Particularly, you should consider the following:

- the ORIGINAL QUESTION
- the plan you had created early on
- the additional context provided in the system prompt, if applicable
- the questions generated so far and the corresponding answers you have received
- previous documented thinking processes, if any. In particular, if \
your previous step in the conversation was a thinking step, pay close attention to that \
one as it should provide you with clear guidance for what to do next.
- the tools you have available, and the instructions you have been given for each tool, \
including how many queries can be generated in each iteration.


Note:
- make sure that you don't repeat yourself. If you have already asked a question of the same tool, \
do not ask it again! New questions to the same tool must be substantially different from the previous ones.
- a previous question can however be asked of a DIFFERENT tool, if there is reason to believe that the \
new tool is suitable for the question.
- you must make sure that the tool has ALL RELEVANT context to answer the question/address the \
request you are posing to it.
- NEVER answer the question directly! If you do have all of the information available AND you \
do not want to request additional information or make checks, you need to call the CLOSER tool.


Your task is to select the next tool to call and the questions/requests to ask of that tool.

"""
|
||||
|
||||
# One-shot extraction prompt: pulls reusable facts/strategies out of a
# finished exchange. Doubled braces survive str.format as literal braces.
EXTRACTION_SYSTEM_PROMPT_TEMPLATE = """
You are an expert in identifying relevant information and strategies from extracted facts, thoughts, and \
answers provided to a user based on their question, that may be useful in INFORMING FUTURE ANSWER STRATEGIES.
As such, your extractions MUST be correct and broadly informative. You will receive the information from the user
in the next message.

Conceptual examples are:
- facts about the user, their team, or their companies that are helpful to provide context for future questions
- search strategies
- reasoning strategies

Please format the extractions as a json dictionary in this format:
{{
"user_information": {{<type of user information>: <the extracted information about the user>, ...}},
"company_information": {{<type of company information>: <the extracted information about the company>, ...}},
"search_strategies": {{<type of search strategy>: <the extracted information about the search strategy>, ...}},
"reasoning_strategies": {{<type of reasoning strategy>: <the extracted information about the reasoning strategy>, ...}},
}}

"""
|
||||
|
||||
# Knowledge-update prompt: merges newly extracted facts/strategies into the
# user's existing knowledge. Doubled braces survive str.format as literals.
EXTRACTION_SYSTEM_PROMPT = """
You are an expert in identifying relevant information and strategies from extracted facts, thoughts, and \
answers provided to a user based on their question, that may be useful in INFORMING FUTURE ANSWER STRATEGIES, and \
updating/extending the previous knowledge accordingly. As such, your updates MUST be correct and broadly informative.

You will receive the original knowledge and the new information from the user in the next message.

Conceptual examples are:
- facts about the user, their team, or their companies that are helpful to provide context for future questions
- search strategies
- reasoning strategies


Please format the extractions as a json dictionary in this format:
{{
"user": [<list of new information about the user, each formatted as a dictionary in this format:
{{'type': <essentially, a keyword for the type of information, like 'location', 'interest',... >,
'change_type': <'update', 'delete', or 'add'. 'update' should be selected if this type of information \
exists in the original knowledge but it should be extended/updated, 'delete' if the information in the \
original knowledge is called into question and no final determination can be made, 'add' for a new information type.>,
'information': <the actual information to be added, updated, or deleted. Do not do a rewrite, just the new \
information.>}}>,.. ],
"company": [<list of new information about the company, each formatted as a dictionary in the exact same \
format as above, except for the 'type' key, which should be company-specific instead of user-specific.>],
"search_strategy": [<list of new information about the search strategies, each formatted as a dictionary \
in the exact same format as above, except for the 'type' key, which should be search-strategy-specific.>],
"reasoning_strategy": [<list of new information about the reasoning strategies, each formatted as a \
dictionary in the exact same format as above, except for the 'type' key, which should be reasoning-strategy-specific.>],
}}

Note:
- make absolutely sure new information about the user is actually ABOUT THE USER WHO IS ASKING THE QUESTION!
- similarly, make sure that new information about the company is actually ABOUT THE COMPANY OF THE USER WHO ASKS THE QUESTION!
- only suggest updates if there is substantially new information that should extend the same type of information \
in the original knowledge.
- keep the information concise, to the point, and in a way that would be useful to provide context to future questions.

"""
|
||||
|
||||
# Consolidation prompt: folds new information into an existing context blob.
CONTEXT_UPDATE_SYSTEM_PROMPT = """
You are an expert in updating/modifying previous knowledge as new information becomes available.
Your task is to generate the updated information that consists of both the old and the new information.

You will receive the original context and the new information from the user in the next message.

Please respond with the consolidated information. Keep the information concise, to the point, and in a way \
that would be useful to provide context to future questions.

"""
|
||||
@@ -0,0 +1,112 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType
|
||||
from onyx.agents.agent_search.exploration.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.exploration.models import OrchestratorTool
|
||||
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
|
||||
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
|
||||
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
|
||||
|
||||
def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    """Pick the orchestration prompt template for *purpose*/*research_type*
    and partially bind the tool- and KG-related placeholders.

    Raises:
        ValueError: for purpose/research_type combinations without a template.
    """
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())
    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )

    # Pairwise hints distinguishing two tools; only pairs where both tools
    # are currently available are included.
    tool_differentiations: list[str] = [
        TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
        for tool_1 in available_tools
        for tool_2 in available_tools
        if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
    ]
    tool_differentiation_hint_string = (
        "\n".join(tool_differentiations) or "(No differentiating hints available)"
    )
    # TODO: add tool delineation pairs for custom tools as well

    # Example questions for each available tool, when on record.
    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )

    # Only render KG entity/relationship descriptions when the KG tool is
    # available and at least one type string was supplied.
    if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
        entity_types_string or relationship_types_string
    ):

        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string or "",
            possible_relationships=relationship_types_string or "",
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"

    if purpose == DRPromptPurpose.PLAN:
        if research_type == ResearchType.THOUGHTFUL:
            # NOTE(review): the condition checks THOUGHTFUL but the message
            # says FAST — confirm which research types exclude plan generation.
            raise ValueError("plan generation is not supported for FAST time budget")
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
        else:
            raise ValueError(
                "reasoning is not separately required for DEEP time budget"
            )

    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
        else:
            base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT

    elif purpose == DRPromptPurpose.CLARIFICATION:
        if research_type == ResearchType.THOUGHTFUL:
            # NOTE(review): message says FAST but the check is THOUGHTFUL —
            # confirm the intended restriction.
            raise ValueError("clarification is not supported for FAST time budget")
        base_template = GET_CLARIFICATION_PROMPT

    else:
        # for mypy, clearly a mypy bug
        raise ValueError(f"Invalid purpose: {purpose}")

    # Bind everything known now; remaining placeholders are filled later.
    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )
|
||||
33
backend/onyx/agents/agent_search/exploration/enums.py
Normal file
33
backend/onyx/agents/agent_search/exploration/enums.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ResearchType(str, Enum):
    """Research type options for agent search operations"""

    # BASIC = "BASIC"
    LEGACY_AGENTIC = "LEGACY_AGENTIC"  # only used for legacy agentic search migrations
    # THOUGHTFUL/DEEP/FAST differ in their time budgets — see
    # DR_TIME_BUDGET_BY_TYPE in exploration/constants.py.
    THOUGHTFUL = "THOUGHTFUL"
    DEEP = "DEEP"
    FAST = "FAST"
|
||||
|
||||
|
||||
class ResearchAnswerPurpose(str, Enum):
    """Research answer purpose options for agent search operations"""

    ANSWER = "ANSWER"  # a final answer to the user's question
    CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"  # asks the user to clarify first
|
||||
|
||||
|
||||
class DRPath(str, Enum):
    """Node identifiers for the deep-research graph.

    The values double as LangGraph node keys (see the graph builder) and as
    the entries in ``state.tools_used`` that the routers compare against.
    """

    CLARIFIER = "Clarifier"  # graph entry node; invalid as an iteration target
    ORCHESTRATOR = "Orchestrator"  # selects the next tool each iteration
    INTERNAL_SEARCH = "Internal Search"
    GENERIC_TOOL = "Generic Tool"
    KNOWLEDGE_GRAPH = "Knowledge Graph Search"
    WEB_SEARCH = "Web Search"
    IMAGE_GENERATION = "Image Generation"
    GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
    CLOSER = "Closer"
    THINKING = "Thinking"
    LOGGER = "Logger"  # final node before the graph terminates
    END = "End"  # translated to langgraph END by decision_router
|
||||
@@ -0,0 +1,88 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.conditional_edges import completeness_router
|
||||
from onyx.agents.agent_search.exploration.conditional_edges import decision_router
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.nodes.dr_a0_clarification import clarifier
|
||||
from onyx.agents.agent_search.exploration.nodes.dr_a1_orchestrator import orchestrator
|
||||
from onyx.agents.agent_search.exploration.nodes.dr_a2_closer import closer
|
||||
from onyx.agents.agent_search.exploration.nodes.dr_a3_logger import logging
|
||||
from onyx.agents.agent_search.exploration.states import MainInput
|
||||
from onyx.agents.agent_search.exploration.states import MainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_graph_builder import (
|
||||
dr_basic_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
|
||||
dr_custom_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
|
||||
dr_generic_internal_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_graph_builder import (
|
||||
dr_image_generation_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
dr_ws_graph_builder,
|
||||
)
|
||||
|
||||
# from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
|
||||
|
||||
def exploration_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.

    Flow: START -> CLARIFIER -> (decision_router) -> tool sub-graphs, each of
    which returns to the ORCHESTRATOR; the CLOSER routes via
    completeness_router either back to the ORCHESTRATOR or on to the LOGGER,
    which ends the graph.
    """

    graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###

    graph.add_node(DRPath.CLARIFIER, clarifier)

    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    # Each tool runs as a pre-compiled sub-graph mounted as a single node.
    basic_search_graph = dr_basic_search_graph_builder().compile()
    graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)

    kg_search_graph = dr_kg_search_graph_builder().compile()
    graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)

    internet_search_graph = dr_ws_graph_builder().compile()
    graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)

    image_generation_graph = dr_image_generation_graph_builder().compile()
    graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)

    custom_tool_graph = dr_custom_tool_graph_builder().compile()
    graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)

    generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
    graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)

    graph.add_node(DRPath.CLOSER, closer)
    graph.add_node(DRPath.LOGGER, logging)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)

    # Both the clarifier and the orchestrator defer to decision_router.
    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)

    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    # Every tool node unconditionally hands control back to the orchestrator.
    graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)

    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
    graph.add_edge(start_key=DRPath.LOGGER, end_key=END)

    return graph
|
||||
@@ -0,0 +1,180 @@
|
||||
import json
import re
from collections import defaultdict
from datetime import datetime
from datetime import timedelta

from sqlalchemy.orm import Session

from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_raw
from onyx.context.search.models import IndexFilters
from onyx.context.search.preprocessing.access_filters import (
    build_access_filters_for_user,
)
from onyx.db.hackathon_subscriptions import get_document_ids_by_cc_pair_name
from onyx.db.hackathon_subscriptions import get_subscription_registration
from onyx.db.hackathon_subscriptions import get_subscription_result
from onyx.db.hackathon_subscriptions import save_subscription_result
from onyx.db.models import SubscriptionResult
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _parse_llm_json(response_content: str) -> dict:
    """Best-effort JSON extraction from an LLM reply.

    Tries, in order: the whole body as JSON, a ```json fenced block, then the
    first {...} span. Returns {} (and logs) when nothing parseable is found.
    """
    try:
        return json.loads(response_content)
    except json.JSONDecodeError:
        pass

    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", response_content, re.DOTALL)
    if fenced:
        return json.loads(fenced.group(1))

    braced = re.search(r"\{.*\}", response_content, re.DOTALL)
    if braced:
        return json.loads(braced.group(0))

    logger.error(f"Failed to parse LLM response as JSON: {response_content}")
    return {}


def process_notifications(
    db_session: Session,
    llm: LLM,
    user: User | None = None,
) -> None:
    """Run the user's document-extraction subscriptions and store the result.

    For every registered (cc-pair name -> extraction prompt) entry, retrieves
    recently updated documents, runs the prompt over each document's chunks
    via the LLM, aggregates the per-type findings, and saves them as a
    ``SubscriptionResult`` of type ``document_analysis``. No-op when *user*
    is None or has no registration.
    """
    if not user:
        return

    subscription_registration = get_subscription_registration(db_session, str(user.id))
    if not subscription_registration:
        return

    doc_extraction_contexts = subscription_registration.doc_extraction_contexts

    # Get the document index for retrieval
    search_settings = get_current_search_settings(db_session)
    document_index = get_default_document_index(search_settings, None)

    # Build access control filters for the user
    user_acl_filters = build_access_filters_for_user(user, db_session)

    # Get current tenant ID
    tenant_id = get_current_tenant_id()

    # {cc_pair_name: {analysis_type: [formatted findings, ...]}}
    analysis_results: dict[str, dict[str, list[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    invalid_doc_extraction_context_keys: list[str] = []

    # Only consider documents updated after September 18, 2025 (hackathon
    # cutoff). Loop-invariant, so computed once.
    date_threshold = datetime(2025, 9, 19) - timedelta(days=1)

    for (
        doc_extraction_context_key,
        doc_extraction_context_value,
    ) in doc_extraction_contexts.items():
        # Get all document IDs and links for this connector credential pair
        documents = get_document_ids_by_cc_pair_name(
            db_session, doc_extraction_context_key, date_threshold
        )

        for document_id, document_link in documents:
            # Retrieve all chunks for this specific document with proper access control
            filters = IndexFilters(
                tenant_id=tenant_id,
                access_control_list=user_acl_filters,
            )

            document_chunks = document_index.id_based_retrieval(
                chunk_requests=[VespaChunkRequest(document_id=document_id)],
                filters=filters,
                batch_retrieval=False,
            )

            # Order chunks by chunk_id and join them into one document string.
            sorted_chunks = sorted(document_chunks, key=lambda c: c.chunk_id)
            sorted_content_chunks_string = "\n".join(
                f"{chunk.chunk_id}: {chunk.content}" for chunk in sorted_chunks
            )

            # Replace placeholder in extraction context with document content
            analysis_prompt = doc_extraction_context_value.replace(
                "---doc_content---", sorted_content_chunks_string
            )

            # Invoke LLM with the analysis prompt
            analysis_response = invoke_llm_raw(
                llm=llm,
                prompt=analysis_prompt,
            )

            analysis_dict = _parse_llm_json(str(analysis_response.content))

            for analysis_type, analysis_value in analysis_dict.items():
                if analysis_value:
                    analysis_results[doc_extraction_context_key][analysis_type].append(
                        f"[{document_id}]({document_link}): {analysis_value}"
                    )
                    if analysis_type == "use" and analysis_value.lower() != "yes":
                        # NOTE(review): a document_id is appended to a list that
                        # is later compared against context *keys* — confirm
                        # whether keys or document ids were intended here.
                        invalid_doc_extraction_context_keys.append(document_id)

    # Build the analysis string from all results
    analysis_string_components = []

    for doc_extraction_context_key, analysis_type_dict in analysis_results.items():
        if doc_extraction_context_key in invalid_doc_extraction_context_keys:
            continue
        for analysis_type, analysis_values in analysis_type_dict.items():
            if analysis_type != "use":
                analysis_string_components.append(f"## Calls - {analysis_type}")
            for analysis_value in analysis_values:
                # Skip empty values, anything mentioning an invalidated id,
                # and the bookkeeping "use" type itself.
                if (
                    analysis_value
                    and not any(
                        invalid_id in analysis_value
                        for invalid_id in invalid_doc_extraction_context_keys
                    )
                    and analysis_type != "use"
                ):
                    analysis_string_components.append(analysis_value)
    analysis_string = "\n \n \n ".join(analysis_string_components)

    # Save the results to the subscription_results table
    subscription_result = SubscriptionResult(
        user_id=user.id,
        type="document_analysis",
        notifications={"analysis": analysis_string},
    )
    save_subscription_result(db_session, subscription_result)
|
||||
|
||||
|
||||
def get_notifications(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
subscription_result = get_subscription_result(db_session, str(user.id))
|
||||
if not subscription_result:
|
||||
return
|
||||
return subscription_result.notifications["analysis"]
|
||||
147
backend/onyx/agents/agent_search/exploration/models.py
Normal file
147
backend/onyx/agents/agent_search/exploration/models.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.models import (
|
||||
GeneratedImage,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
class OrchestratorStep(BaseModel):
|
||||
tool: str
|
||||
questions: list[str]
|
||||
|
||||
|
||||
class OrchestratorDecisonsNoPlan(BaseModel):
|
||||
reasoning: str
|
||||
next_step: OrchestratorStep
|
||||
|
||||
|
||||
class OrchestrationPlan(BaseModel):
|
||||
reasoning: str
|
||||
plan: str
|
||||
|
||||
|
||||
class ClarificationGenerationResponse(BaseModel):
|
||||
clarification_needed: bool
|
||||
clarification_question: str
|
||||
|
||||
|
||||
class DecisionResponse(BaseModel):
|
||||
reasoning: str
|
||||
decision: str
|
||||
|
||||
|
||||
class QueryEvaluationResponse(BaseModel):
|
||||
reasoning: str
|
||||
query_permitted: bool
|
||||
|
||||
|
||||
class OrchestrationClarificationInfo(BaseModel):
|
||||
clarification_question: str
|
||||
clarification_response: str | None = None
|
||||
|
||||
|
||||
class WebSearchAnswer(BaseModel):
|
||||
urls_to_open_indices: list[int]
|
||||
|
||||
|
||||
class SearchAnswer(BaseModel):
|
||||
reasoning: str
|
||||
answer: str
|
||||
claims: list[str] | None = None
|
||||
|
||||
|
||||
class TestInfoCompleteResponse(BaseModel):
|
||||
reasoning: str
|
||||
complete: bool
|
||||
gaps: list[str]
|
||||
|
||||
|
||||
# TODO: revisit with custom tools implementation in v2
|
||||
# each tool should be a class with the attributes below, plus the actual tool implementation
|
||||
# this will also allow custom tools to have their own cost
|
||||
class OrchestratorTool(BaseModel):
|
||||
tool_id: int
|
||||
name: str
|
||||
llm_path: str # the path for the LLM to refer by
|
||||
path: DRPath # the actual path in the graph
|
||||
description: str
|
||||
metadata: dict[str, str]
|
||||
cost: float
|
||||
tool_object: Tool | None = None # None for CLOSER
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class IterationInstructions(BaseModel):
|
||||
iteration_nr: int
|
||||
plan: str | None
|
||||
reasoning: str
|
||||
purpose: str
|
||||
|
||||
|
||||
class IterationAnswer(BaseModel):
|
||||
tool: str
|
||||
tool_id: int
|
||||
iteration_nr: int
|
||||
parallelization_nr: int
|
||||
question: str
|
||||
reasoning: str | None
|
||||
answer: str
|
||||
cited_documents: dict[int, InferenceSection]
|
||||
background_info: str | None = None
|
||||
claims: list[str] | None = None
|
||||
additional_data: dict[str, str] | None = None
|
||||
response_type: str | None = None
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
file_ids: list[str] | None = None
|
||||
# TODO: This is not ideal, but we'll can rework the schema
|
||||
# for deep research later
|
||||
is_web_fetch: bool = False
|
||||
# for image generation step-types
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# for multi-query search tools (v2 web search and internal search)
|
||||
# TODO: Clean this up to be more flexible to tools
|
||||
queries: list[str] | None = None
|
||||
|
||||
|
||||
class AggregatedDRContext(BaseModel):
|
||||
context: str
|
||||
cited_documents: list[InferenceSection]
|
||||
is_internet_marker_dict: dict[str, bool]
|
||||
global_iteration_responses: list[IterationAnswer]
|
||||
|
||||
|
||||
class DRPromptPurpose(str, Enum):
|
||||
PLAN = "PLAN"
|
||||
NEXT_STEP = "NEXT_STEP"
|
||||
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
|
||||
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
|
||||
CLARIFICATION = "CLARIFICATION"
|
||||
|
||||
|
||||
class BaseSearchProcessingResponse(BaseModel):
|
||||
specified_source_types: list[str]
|
||||
rewritten_query: str
|
||||
time_filter: str
|
||||
|
||||
|
||||
# EXPLORATION TESTING
|
||||
|
||||
|
||||
class CheatSheetContext(BaseModel):
|
||||
history: Dict[str, str]
|
||||
user_context: Dict[str, str]
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
user: list[Dict[str, str]]
|
||||
company: list[Dict[str, str]]
|
||||
search_strategy: list[Dict[str, str]]
|
||||
reasoning_strategy: list[Dict[str, str]]
|
||||
@@ -0,0 +1,716 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.exploration.constants import AVERAGE_TOOL_COSTS
|
||||
from onyx.agents.agent_search.exploration.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
|
||||
BASE_SYSTEM_MESSAGE_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
|
||||
PLAN_PROMPT_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.exploration.hackathon_functions import get_notifications
|
||||
from onyx.agents.agent_search.exploration.hackathon_functions import (
|
||||
process_notifications,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.exploration.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.exploration.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.exploration.states import MainState
|
||||
from onyx.agents.agent_search.exploration.states import OrchestrationSetup
|
||||
from onyx.agents.agent_search.exploration.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.chat_utils import build_citation_map_from_numbers
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.memories import get_memories
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceDescription
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.exploration_research_configs import (
|
||||
EXPLORATION_TEST_USE_CALRIFIER_DEFAULT,
|
||||
)
|
||||
from onyx.configs.exploration_research_configs import (
|
||||
EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT,
|
||||
)
|
||||
from onyx.configs.exploration_research_configs import EXPLORATION_TEST_USE_PLAN_DEFAULT
|
||||
from onyx.configs.exploration_research_configs import (
|
||||
EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT,
|
||||
)
|
||||
from onyx.configs.exploration_research_configs import (
|
||||
EXPLORATION_TEST_USE_THINKING_DEFAULT,
|
||||
)
|
||||
from onyx.db.chat import create_search_doc_from_saved_search_doc
|
||||
from onyx.db.connector import fetch_unique_document_sources
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.db.users import get_user_cheat_sheet_context
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.prompts.prompt_utils import handle_memories
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_PLAN_INSTRUCTION_INSERTION = """ - Early on, the user MAY ask you to create a plan for the answer process. \
|
||||
Think about the tools you have available and how you can use them, and then \
|
||||
create a HIGH-LEVEL PLAN of how you want to approach the answer process."""
|
||||
|
||||
|
||||
def _get_available_tools(
|
||||
db_session: Session,
|
||||
graph_config: GraphConfig,
|
||||
kg_enabled: bool,
|
||||
active_source_types: list[DocumentSource],
|
||||
use_clarifier: bool = False,
|
||||
use_thinking: bool = False,
|
||||
) -> dict[str, OrchestratorTool]:
|
||||
|
||||
available_tools: dict[str, OrchestratorTool] = {}
|
||||
|
||||
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
persona = graph_config.inputs.persona
|
||||
|
||||
if persona:
|
||||
include_kg = persona.name == TMP_DRALPHA_PERSONA_NAME and kg_enabled
|
||||
else:
|
||||
include_kg = False
|
||||
|
||||
tool_dict: dict[int, Tool] = {
|
||||
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
|
||||
}
|
||||
|
||||
for tool in graph_config.tooling.tools:
|
||||
|
||||
if not tool.is_available(db_session):
|
||||
logger.info(f"Tool {tool.name} is not available, skipping")
|
||||
continue
|
||||
|
||||
tool_db_info = tool_dict.get(tool.id)
|
||||
if tool_db_info:
|
||||
incode_tool_id = tool_db_info.in_code_tool_id
|
||||
else:
|
||||
raise ValueError(f"Tool {tool.name} is not found in the database")
|
||||
|
||||
if isinstance(tool, WebSearchTool):
|
||||
llm_path = DRPath.WEB_SEARCH.value
|
||||
path = DRPath.WEB_SEARCH
|
||||
elif isinstance(tool, SearchTool):
|
||||
llm_path = DRPath.INTERNAL_SEARCH.value
|
||||
path = DRPath.INTERNAL_SEARCH
|
||||
elif isinstance(tool, KnowledgeGraphTool) and include_kg:
|
||||
# TODO (chris): move this into the `is_available` check
|
||||
if len(active_source_types) == 0:
|
||||
logger.error(
|
||||
"No active source types found, skipping Knowledge Graph tool"
|
||||
)
|
||||
continue
|
||||
llm_path = DRPath.KNOWLEDGE_GRAPH.value
|
||||
path = DRPath.KNOWLEDGE_GRAPH
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
llm_path = DRPath.IMAGE_GENERATION.value
|
||||
path = DRPath.IMAGE_GENERATION
|
||||
elif incode_tool_id:
|
||||
# if incode tool id is found, it is a generic internal tool
|
||||
llm_path = DRPath.GENERIC_INTERNAL_TOOL.value
|
||||
path = DRPath.GENERIC_INTERNAL_TOOL
|
||||
else:
|
||||
# otherwise it is a custom tool
|
||||
llm_path = DRPath.GENERIC_TOOL.value
|
||||
path = DRPath.GENERIC_TOOL
|
||||
|
||||
if path not in {DRPath.GENERIC_INTERNAL_TOOL, DRPath.GENERIC_TOOL}:
|
||||
description = TOOL_DESCRIPTION.get(path, tool.description)
|
||||
cost = AVERAGE_TOOL_COSTS[path]
|
||||
else:
|
||||
description = tool.description
|
||||
cost = 1.0
|
||||
|
||||
tool_info = OrchestratorTool(
|
||||
tool_id=tool.id,
|
||||
name=tool.llm_name,
|
||||
llm_path=llm_path,
|
||||
path=path,
|
||||
description=description,
|
||||
metadata={},
|
||||
cost=cost,
|
||||
tool_object=tool,
|
||||
)
|
||||
|
||||
# TODO: handle custom tools with same name as other tools (e.g., CLOSER)
|
||||
available_tools[tool.llm_name] = tool_info
|
||||
|
||||
available_tool_paths = [tool.path for tool in available_tools.values()]
|
||||
|
||||
# make sure KG isn't enabled without internal search
|
||||
if (
|
||||
DRPath.KNOWLEDGE_GRAPH in available_tool_paths
|
||||
and DRPath.INTERNAL_SEARCH not in available_tool_paths
|
||||
):
|
||||
raise ValueError(
|
||||
"The Knowledge Graph is not supported without internal search tool"
|
||||
)
|
||||
|
||||
# add CLOSER tool, which is always available
|
||||
available_tools[DRPath.CLOSER.value] = OrchestratorTool(
|
||||
tool_id=-1,
|
||||
name=DRPath.CLOSER.value,
|
||||
llm_path=DRPath.CLOSER.value,
|
||||
path=DRPath.CLOSER,
|
||||
description=TOOL_DESCRIPTION[DRPath.CLOSER],
|
||||
metadata={},
|
||||
cost=0.0,
|
||||
tool_object=None,
|
||||
)
|
||||
|
||||
if use_thinking:
|
||||
available_tools[DRPath.THINKING.value] = OrchestratorTool(
|
||||
tool_id=102,
|
||||
name=DRPath.THINKING.value,
|
||||
llm_path=DRPath.THINKING.value,
|
||||
path=DRPath.THINKING,
|
||||
description="""This tool should be used if the next step is not particularly clear, \
|
||||
or if you think you need to think through the original question and the questions and answers \
|
||||
you have received so far in order to make a decision about what to do next AMONGST THE TOOLS AVAILABLE TO YOU \
|
||||
IN THIS REQUEST! (Note: some tools described earlier may be excluded!).
|
||||
If in doubt, use this tool. No action will be taken, just some reasoning will be done.""",
|
||||
metadata={},
|
||||
cost=0.0,
|
||||
tool_object=None,
|
||||
)
|
||||
|
||||
if use_clarifier:
|
||||
available_tools[DRPath.CLARIFIER.value] = OrchestratorTool(
|
||||
tool_id=103,
|
||||
name=DRPath.CLARIFIER.value,
|
||||
llm_path=DRPath.CLARIFIER.value,
|
||||
path=DRPath.CLARIFIER,
|
||||
description="""This tool should be used ONLY if you need to have clarification on something IMPORTANT FROM \
|
||||
the user. This can pertain to the original question or something you found out during the process so far.""",
|
||||
metadata={},
|
||||
cost=0.0,
|
||||
tool_object=None,
|
||||
)
|
||||
|
||||
return available_tools
|
||||
|
||||
|
||||
def _construct_uploaded_text_context(files: list[InMemoryChatFile]) -> str:
|
||||
"""Construct the uploaded context from the files."""
|
||||
file_contents = []
|
||||
for file in files:
|
||||
if file.file_type in (
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.CSV,
|
||||
):
|
||||
file_contents.append(file.content.decode("utf-8"))
|
||||
if len(file_contents) > 0:
|
||||
return "Uploaded context:\n\n\n" + "\n\n".join(file_contents)
|
||||
return ""
|
||||
|
||||
|
||||
def _construct_uploaded_image_context(
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
img_urls: list[str] | None = None,
|
||||
b64_imgs: list[str] | None = None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Construct the uploaded image context from the files."""
|
||||
# Only include image files for user messages
|
||||
if files is None:
|
||||
return None
|
||||
|
||||
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]
|
||||
|
||||
img_urls = img_urls or []
|
||||
b64_imgs = b64_imgs or []
|
||||
|
||||
if not (img_files or img_urls or b64_imgs):
|
||||
return None
|
||||
|
||||
return cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": (
|
||||
f"data:{get_image_type_from_bytes(file.content)};"
|
||||
f"base64,{file.to_base64()}"
|
||||
),
|
||||
},
|
||||
}
|
||||
for file in img_files
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
|
||||
},
|
||||
}
|
||||
for b64_img in b64_imgs
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
for url in img_urls
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _get_existing_clarification_request(
|
||||
graph_config: GraphConfig,
|
||||
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
|
||||
"""
|
||||
Returns the clarification info, original question, and updated chat history if
|
||||
a clarification request and response exists, otherwise returns None.
|
||||
"""
|
||||
# check for clarification request and response in message history
|
||||
previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
|
||||
|
||||
if len(previous_raw_messages) == 0 or (
|
||||
previous_raw_messages[-1].research_answer_purpose
|
||||
!= ResearchAnswerPurpose.CLARIFICATION_REQUEST
|
||||
):
|
||||
return None
|
||||
|
||||
# get the clarification request and response
|
||||
previous_messages = graph_config.inputs.prompt_builder.message_history
|
||||
last_message = previous_raw_messages[-1].message
|
||||
|
||||
clarification = OrchestrationClarificationInfo(
|
||||
clarification_question=last_message.strip(),
|
||||
clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
|
||||
)
|
||||
original_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
chat_history_string = "(No chat history yet available)"
|
||||
|
||||
# get the original user query and chat history string before the original query
|
||||
# e.g., if history = [user query, assistant clarification request, user clarification response],
|
||||
# previous_messages = [user query, assistant clarification request], we want the user query
|
||||
for i, message in enumerate(reversed(previous_messages), 1):
|
||||
if (
|
||||
isinstance(message, HumanMessage)
|
||||
and message.content
|
||||
and isinstance(message.content, str)
|
||||
):
|
||||
original_question = message.content
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history[:-i],
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
break
|
||||
|
||||
return clarification, original_question, chat_history_string
|
||||
|
||||
|
||||
def _persist_final_docs_and_citations(
|
||||
db_session: Session,
|
||||
context_llm_docs: list[Any] | None,
|
||||
full_answer: str | None,
|
||||
) -> tuple[list[SearchDoc], dict[int, int] | None]:
|
||||
"""Persist final documents from in-context docs and derive citation mapping.
|
||||
|
||||
Returns the list of persisted `SearchDoc` records and an optional
|
||||
citation map translating inline [[n]] references to DB doc indices.
|
||||
"""
|
||||
final_documents_db: list[SearchDoc] = []
|
||||
citations_map: dict[int, int] | None = None
|
||||
|
||||
if not context_llm_docs:
|
||||
return final_documents_db, citations_map
|
||||
|
||||
saved_search_docs = saved_search_docs_from_llm_docs(context_llm_docs)
|
||||
for saved_doc in saved_search_docs:
|
||||
db_doc = create_search_doc_from_saved_search_doc(saved_doc)
|
||||
db_session.add(db_doc)
|
||||
final_documents_db.append(db_doc)
|
||||
db_session.flush()
|
||||
|
||||
cited_numbers: set[int] = set()
|
||||
try:
|
||||
# Match [[1]] or [[1, 2]] optionally followed by a link like ([[1]](http...))
|
||||
matches = re.findall(
|
||||
r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", full_answer or ""
|
||||
)
|
||||
for match in matches:
|
||||
for num_str in match.split(","):
|
||||
num = int(num_str.strip())
|
||||
cited_numbers.add(num)
|
||||
except Exception:
|
||||
cited_numbers = set()
|
||||
|
||||
if cited_numbers and final_documents_db:
|
||||
translations = build_citation_map_from_numbers(
|
||||
cited_numbers=cited_numbers,
|
||||
db_docs=final_documents_db,
|
||||
)
|
||||
citations_map = translations or None
|
||||
|
||||
return final_documents_db, citations_map
|
||||
|
||||
|
||||
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_any_knowledge_retrieval_and_any_action_tool",
|
||||
"description": "Use this tool to get ANY external information \
|
||||
that is relevant to the question, or for any action to be taken, including image generation. In fact, \
|
||||
ANY tool mentioned can be accessed through this generic tool. If in doubt, use this tool.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"request": {
|
||||
"type": "string",
|
||||
"description": "The request to be made to the tool",
|
||||
},
|
||||
},
|
||||
"required": ["request"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def clarifier(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> OrchestrationSetup:
|
||||
"""
|
||||
Perform a quick search on the question as is and see whether a set of clarification
|
||||
questions is needed. For now this is based on the models
|
||||
"""
|
||||
|
||||
_EXPLORATION_TEST_USE_CALRIFIER = EXPLORATION_TEST_USE_CALRIFIER_DEFAULT
|
||||
_EXPLORATION_TEST_USE_PLAN = EXPLORATION_TEST_USE_PLAN_DEFAULT
|
||||
_EXPLORATION_TEST_USE_PLAN_UPDATES = EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT
|
||||
_EXPLORATION_TEST_USE_CORPUS_HISTORY = EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT
|
||||
_EXPLORATION_TEST_USE_THINKING = EXPLORATION_TEST_USE_THINKING_DEFAULT
|
||||
|
||||
_EXPLORATION_TEST_USE_PLAN = False
|
||||
|
||||
node_start_time = datetime.now()
|
||||
current_step_nr = 0
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
llm_provider = graph_config.tooling.primary_llm.config.model_provider
|
||||
llm_model_name = graph_config.tooling.primary_llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=llm_model_name,
|
||||
model_provider=llm_provider,
|
||||
)
|
||||
|
||||
db_session = graph_config.persistence.db_session
|
||||
|
||||
original_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
# Perform a commit to ensure the message_id is set and saved
|
||||
db_session.commit()
|
||||
|
||||
# get the connected tools and format for the Deep Research flow
|
||||
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
active_source_types = fetch_unique_document_sources(db_session)
|
||||
|
||||
available_tools = _get_available_tools(
|
||||
db_session,
|
||||
graph_config,
|
||||
kg_enabled,
|
||||
active_source_types,
|
||||
use_clarifier=_EXPLORATION_TEST_USE_CALRIFIER,
|
||||
use_thinking=_EXPLORATION_TEST_USE_THINKING,
|
||||
)
|
||||
|
||||
available_tool_descriptions_str = "\n -" + "\n -".join(
|
||||
[
|
||||
tool.name + ": " + tool.description
|
||||
for tool in available_tools.values()
|
||||
if tool.path != DRPath.CLOSER
|
||||
]
|
||||
)
|
||||
|
||||
active_source_types_descriptions = [
|
||||
DocumentSourceDescription[source_type] for source_type in active_source_types
|
||||
]
|
||||
|
||||
if len(active_source_types_descriptions) > 0:
|
||||
active_source_type_descriptions_str = "\n -" + "\n -".join(
|
||||
active_source_types_descriptions
|
||||
)
|
||||
else:
|
||||
active_source_type_descriptions_str = ""
|
||||
|
||||
if graph_config.inputs.persona:
|
||||
assistant_system_prompt = PromptTemplate(
|
||||
graph_config.inputs.persona.system_prompt or DEFAULT_DR_SYSTEM_PROMPT
|
||||
).build()
|
||||
if graph_config.inputs.persona.task_prompt:
|
||||
assistant_task_prompt = (
|
||||
"\n\nHere are more specifications from the user:\n\n"
|
||||
+ PromptTemplate(graph_config.inputs.persona.task_prompt).build()
|
||||
)
|
||||
else:
|
||||
assistant_task_prompt = ""
|
||||
|
||||
else:
|
||||
assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
|
||||
assistant_task_prompt = ""
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
user = (
|
||||
graph_config.tooling.search_tool.user
|
||||
if graph_config.tooling.search_tool
|
||||
else None
|
||||
)
|
||||
|
||||
continue_to_answer = True
|
||||
if original_question == "process_notifications" and user:
|
||||
process_notifications(db_session, llm=graph_config.tooling.fast_llm, user=user)
|
||||
continue_to_answer = False
|
||||
|
||||
# Stream the notifications message
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(content="", final_documents=None),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageDelta(content="Done!"),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
elif original_question == "get_notifications" and user:
|
||||
notifications = get_notifications(db_session, user)
|
||||
if notifications:
|
||||
# Stream the notifications message
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(content="", final_documents=None),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageDelta(content=notifications),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
continue_to_answer = False
|
||||
|
||||
if not continue_to_answer:
|
||||
return OrchestrationSetup(
|
||||
original_question=original_question,
|
||||
chat_history_string="",
|
||||
tools_used=[DRPath.END.value],
|
||||
query_list=[],
|
||||
iteration_nr=0,
|
||||
current_step_nr=current_step_nr,
|
||||
)
|
||||
|
||||
memories = get_memories(user, db_session)
|
||||
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
|
||||
assistant_system_prompt = handle_memories(assistant_system_prompt, memories)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
|
||||
uploaded_text_context = (
|
||||
_construct_uploaded_text_context(graph_config.inputs.files)
|
||||
if graph_config.inputs.files
|
||||
else ""
|
||||
)
|
||||
|
||||
uploaded_context_tokens = check_number_of_tokens(
|
||||
uploaded_text_context, llm_tokenizer.encode
|
||||
)
|
||||
|
||||
if uploaded_context_tokens > 0.5 * max_input_tokens:
|
||||
raise ValueError(
|
||||
f"Uploaded context is too long. {uploaded_context_tokens} tokens, "
|
||||
f"but for this model we only allow {0.5 * max_input_tokens} tokens for uploaded context"
|
||||
)
|
||||
|
||||
uploaded_image_context = _construct_uploaded_image_context(
|
||||
graph_config.inputs.files
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
clarification = None
|
||||
|
||||
message_history_for_continuation: list[SystemMessage | HumanMessage | AIMessage] = (
|
||||
[]
|
||||
)
|
||||
|
||||
if user is not None:
|
||||
original_cheat_sheet_context = get_user_cheat_sheet_context(
|
||||
user=user, db_session=db_session
|
||||
)
|
||||
|
||||
if original_cheat_sheet_context:
|
||||
cheat_sheet_string = f"""\n\nHere is additional context learned that may inform the \
|
||||
process (plan generation if applicable, reasoning, tool calls, etc.):\n{str(original_cheat_sheet_context)}\n###\n\n"""
|
||||
else:
|
||||
cheat_sheet_string = ""
|
||||
|
||||
if _EXPLORATION_TEST_USE_PLAN:
|
||||
plan_instruction_insertion = _PLAN_INSTRUCTION_INSERTION
|
||||
else:
|
||||
plan_instruction_insertion = ""
|
||||
|
||||
system_message = (
|
||||
BASE_SYSTEM_MESSAGE_TEMPLATE.replace(
|
||||
"---user_prompt---", assistant_system_prompt
|
||||
)
|
||||
.replace("---current_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace(
|
||||
"---available_tool_descriptions_str---", available_tool_descriptions_str
|
||||
)
|
||||
.replace(
|
||||
"---active_source_types_descriptions_str---",
|
||||
active_source_type_descriptions_str,
|
||||
)
|
||||
.replace(
|
||||
"---cheat_sheet_string---",
|
||||
cheat_sheet_string,
|
||||
)
|
||||
.replace("---plan_instruction_insertion---", plan_instruction_insertion)
|
||||
)
|
||||
|
||||
message_history_for_continuation.append(SystemMessage(content=system_message))
|
||||
message_history_for_continuation.append(
|
||||
HumanMessage(
|
||||
content=f"""Here is the questions to answer:\n{original_question}"""
|
||||
)
|
||||
)
|
||||
message_history_for_continuation.append(
|
||||
AIMessage(content="""How should I proceed to answer the question?""")
|
||||
)
|
||||
|
||||
if _EXPLORATION_TEST_USE_PLAN:
|
||||
|
||||
user_plan_instructions_prompt = """Think carefully how you want to address the question. You may use multiple iterations \
|
||||
of tool calls, reasoning, etc.
|
||||
|
||||
Note:
|
||||
|
||||
- the plan should be HIGH-LEVEL! Do not specify any specific tools, but think about what you want to learn in each iteration.
|
||||
- if the question is simple, one iteration may be enough.
|
||||
- DO NOT close with 'summarize...' etc as the last steps. Just focus on the information gathering steps.
|
||||
|
||||
"""
|
||||
|
||||
plan_prompt = PLAN_PROMPT_TEMPLATE.replace(
|
||||
"---user_plan_instructions_prompt---", user_plan_instructions_prompt
|
||||
)
|
||||
|
||||
message_history_for_continuation.append(HumanMessage(content=plan_prompt))
|
||||
|
||||
plan_of_record = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=message_history_for_continuation,
|
||||
schema=OrchestrationPlan,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=3000,
|
||||
)
|
||||
|
||||
plan_string = f"""Here is how the answer process should be broken down: {plan_of_record.plan}"""
|
||||
|
||||
message_history_for_continuation.append(AIMessage(content=plan_string))
|
||||
|
||||
next_tool = DRPath.ORCHESTRATOR.value
|
||||
|
||||
return OrchestrationSetup(
|
||||
original_question=original_question,
|
||||
chat_history_string=chat_history_string,
|
||||
tools_used=[next_tool],
|
||||
query_list=[],
|
||||
iteration_nr=0,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="clarifier",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
clarification=clarification,
|
||||
available_tools=available_tools,
|
||||
active_source_types=active_source_types,
|
||||
active_source_types_descriptions="\n".join(active_source_types_descriptions),
|
||||
assistant_system_prompt=assistant_system_prompt,
|
||||
assistant_task_prompt=assistant_task_prompt,
|
||||
uploaded_test_context=uploaded_text_context,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
message_history_for_continuation=message_history_for_continuation,
|
||||
cheat_sheet_context=original_cheat_sheet_context,
|
||||
use_clarifier=_EXPLORATION_TEST_USE_CALRIFIER,
|
||||
use_thinking=_EXPLORATION_TEST_USE_THINKING,
|
||||
use_plan=_EXPLORATION_TEST_USE_PLAN,
|
||||
use_plan_updates=_EXPLORATION_TEST_USE_PLAN_UPDATES,
|
||||
use_corpus_history=_EXPLORATION_TEST_USE_CORPUS_HISTORY,
|
||||
)
|
||||
@@ -0,0 +1,325 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
|
||||
ORCHESTRATOR_PROMPT_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.states import IterationInstructions
|
||||
from onyx.agents.agent_search.exploration.states import MainState
|
||||
from onyx.agents.agent_search.exploration.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_raw
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# OpenAI-style tool schema advertising the search tool to the orchestrator LLM.
# Fixed typo in the description ("in he system prompt" -> "in the system prompt").
_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "search_tool",
        "description": "This tool is the search tool whose functionality and details are \
described in the system prompt. Use it if you think you have one or more questions that you believe are \
suitable for the search tool.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "array",
                    "description": "The list of questions to be asked of the search tool",
                    "items": {
                        "type": "string",
                        "description": "The question to be asked of the search tool",
                    },
                },
            },
            "required": ["request"],
        },
    },
}
|
||||
|
||||
# OpenAI-style tool schema for an explicit reasoning step; after "thinking" the
# flow returns to the orchestrator.
# Fixed typos in the description ("yoi" -> "you", "male" -> "make") and declared
# "request" as required, consistent with _SEARCH_TOOL (the orchestrator reads
# tool_call["args"]["request"] unconditionally).
_THINKING_TOOL = {
    "type": "function",
    "function": {
        "name": "thinking_tool",
        "description": "This tool is used if you think you need to think through the original question and the \
questions and answers you have received so far in order to make a decision about what to do next. If in doubt, use this tool.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "Please generate the thinking here you want to do that leads you to the next decision. This \
should end with a recommendation of which tool to invoke next.",
                },
            },
            "required": ["request"],
        },
    },
}
|
||||
|
||||
# OpenAI-style tool schema for ending the research loop and answering.
# Fixed copy-paste defect: the "request" description referred to the thinking
# tool instead of the closer tool.
_CLOSER_TOOL = {
    "type": "function",
    "function": {
        "name": "closer_tool",
        "description": "This tool is used to close the conversation. Use it if you think you have \
all of the information you need to answer the question, and you also do not want to request additional \
information or make checks.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "The request to be made to the closer tool",
                },
            },
        },
    },
}
|
||||
|
||||
# OpenAI-style tool schema for asking the user a clarifying question.
# Structural fix: "parameters" previously sat OUTSIDE the "function" object,
# which is invalid for OpenAI-style function/tool schemas — providers would see
# a clarifier tool with no declared arguments while the orchestrator reads
# tool_call["args"]["request"]. It now lives inside "function", and "request"
# is declared required, consistent with the other tool schemas.
_CLARIFIER_TOOL = {
    "type": "function",
    "function": {
        "name": "clarifier_tool",
        "description": "This tool is used if you need to have clarification on something IMPORTANT from \
the user. This can pertain to the original question or something you found out during the process so far.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "The question you would like to ask the user to get clarification.",
                },
            },
            "required": ["request"],
        },
    },
}
|
||||
|
||||
# Prefix prepended to user-supplied general instructions when building the
# orchestrator's decision prompt.
_DECISION_SYSTEM_PROMPT_PREFIX = (
    "Here are general instructions by the user, which "
    "may or may not influence the decision what to do next:\n\n"
)
|
||||
|
||||
|
||||
def _get_implied_next_tool_based_on_tool_call_history(
    tools_used: list[str],
) -> str | None:
    """
    Identify the next tool based on the tool call history. Initially, we only support
    special handling of the image generation tool.

    Args:
        tools_used: Ordered history of tool names invoked so far (may be empty).

    Returns:
        The forced next tool path, or None if no tool is implied.
    """
    # Guard: no tools called yet, so nothing can be implied. The caller passes
    # state.tools_used unguarded; the original tools_used[-1] raised IndexError
    # on an empty history.
    if not tools_used:
        return None
    # After image generation, always proceed straight to the logger.
    if tools_used[-1] == DRPath.IMAGE_GENERATION.value:
        return DRPath.LOGGER.value
    return None
|
||||
|
||||
|
||||
def orchestrator(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
    """
    LangGraph node to decide the next step in the DR process.

    Folds the previous iteration's tool results into the running message
    history, then asks the primary LLM (with tool schemas attached) which tool
    to invoke next, and returns an OrchestrationUpdate describing that choice.

    Args:
        state: The main DR graph state.
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            carries the GraphConfig.
        writer: Stream writer for custom events (unused in this node).

    Raises:
        ValueError: If the state has no original question, or the LLM calls an
            unknown tool.
    """

    node_start_time = datetime.now()

    # Experimentation flags carried on the state.
    # ("CALRIFIER" is a pre-existing misspelling of CLARIFIER, kept as-is.)
    _EXPLORATION_TEST_USE_CALRIFIER = state.use_clarifier
    # no-op attribute reads; these flags are not used in this node
    state.use_plan
    state.use_plan_updates
    state.use_corpus_history
    _EXPLORATION_TEST_USE_THINKING = state.use_thinking

    previous_tool_call_name = state.tools_used[-1] if state.tools_used else ""

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.original_question
    if not question:
        raise ValueError("Question is required for orchestrator")

    iteration_nr = state.iteration_nr

    plan_of_record = state.plan_of_record

    message_history_for_continuation = state.message_history_for_continuation
    # Messages newly added by this node; returned separately in the update.
    new_messages_for_continuation: list[SystemMessage | HumanMessage | AIMessage] = []

    if iteration_nr > 0:
        # Fold the previous iteration's substantive tool results (excluding
        # clarifier/thinking meta-steps) back into the conversation.
        last_iteration_responses = [
            x
            for x in state.iteration_responses
            if x.iteration_nr == iteration_nr
            and x.tool != DRPath.CLARIFIER.value
            and x.tool != DRPath.THINKING.value
        ]
        if last_iteration_responses:
            response_wrapper = f"For the previous iteration {iteration_nr}, here are the tool calls I decided to execute, \
the questions and tasks posed, and responses:\n\n"
            for last_iteration_response in last_iteration_responses:
                response_wrapper += f"{last_iteration_response.tool}: {last_iteration_response.question}\n"
                response_wrapper += f"Response: {last_iteration_response.answer}\n\n"

            message_history_for_continuation.append(AIMessage(content=response_wrapper))
            new_messages_for_continuation.append(AIMessage(content=response_wrapper))

    iteration_nr += 1
    current_step_nr = state.current_step_nr

    # no-op expression; research type is implicitly DEEP in this flow
    ResearchType.DEEP
    remaining_time_budget = state.remaining_time_budget
    # no-op expression; chat history string is not used in this node
    state.chat_history_string or "(No chat history yet available)"

    next_tool_name = None

    # Identify early exit condition based on tool call history

    next_tool_based_on_tool_call_history = (
        _get_implied_next_tool_based_on_tool_call_history(state.tools_used)
    )

    if next_tool_based_on_tool_call_history == DRPath.LOGGER.value:
        # Forced transition (e.g. right after image generation): skip the LLM
        # decision entirely and go straight to the logger.
        return OrchestrationUpdate(
            tools_used=[DRPath.LOGGER.value],
            query_list=[],
            iteration_nr=iteration_nr,
            current_step_nr=current_step_nr,
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="main",
                    node_name="orchestrator",
                    node_start_time=node_start_time,
                )
            ],
            plan_of_record=plan_of_record,
            remaining_time_budget=remaining_time_budget,
            iteration_instructions=[
                IterationInstructions(
                    iteration_nr=iteration_nr,
                    plan=plan_of_record.plan if plan_of_record else None,
                    reasoning="",
                    purpose="",
                )
            ],
        )

    # no early exit forced. Continue.

    # no-op expressions; leftovers from earlier experimentation
    state.available_tools or {}

    state.uploaded_test_context or ""
    state.uploaded_image_context or []

    # default to closer
    query_list = ["Answer the question with the information you have."]

    reasoning_result = "(No reasoning result provided yet.)"

    ORCHESTRATOR_PROMPT = ORCHESTRATOR_PROMPT_TEMPLATE

    message_history_for_continuation.append(HumanMessage(content=ORCHESTRATOR_PROMPT))
    new_messages_for_continuation.append(HumanMessage(content=ORCHESTRATOR_PROMPT))

    # Search and closer are always offered; thinking/clarifier are gated by the
    # experimentation flags and never offered twice in a row.
    tools = [_SEARCH_TOOL, _CLOSER_TOOL]

    if (
        _EXPLORATION_TEST_USE_THINKING
        and previous_tool_call_name != DRPath.THINKING.value
    ):
        tools.append(_THINKING_TOOL)

    if (
        _EXPLORATION_TEST_USE_CALRIFIER
        and previous_tool_call_name != DRPath.CLARIFIER.value
    ):
        tools.append(_CLARIFIER_TOOL)

    in_orchestration_iteration_answers: list[IterationAnswer] = []
    if remaining_time_budget > 0:

        orchestrator_action: AIMessage = invoke_llm_raw(
            llm=graph_config.tooling.primary_llm,
            prompt=message_history_for_continuation,
            tools=tools,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1500,
        )

        tool_calls = orchestrator_action.tool_calls
        if tool_calls:
            for tool_call in tool_calls:
                if tool_call["name"] == "search_tool":
                    query_list = tool_call["args"]["request"]
                    next_tool_name = DRPath.INTERNAL_SEARCH.value
                elif tool_call["name"] == "thinking_tool":
                    reasoning_result = tool_call["args"]["request"]
                    next_tool_name = (
                        DRPath.THINKING.value
                    )  # note: thinking already done. Will return to Orchestrator.
                    message_history_for_continuation.append(
                        AIMessage(content=reasoning_result)
                    )
                    new_messages_for_continuation.append(
                        AIMessage(content=reasoning_result)
                    )

                    # Record the thinking step as a pseudo-iteration answer so
                    # it shows up in the iteration history.
                    in_orchestration_iteration_answers.append(
                        IterationAnswer(
                            tool=DRPath.THINKING.value,
                            tool_id=102,
                            iteration_nr=iteration_nr,
                            parallelization_nr=0,
                            question="",
                            cited_documents={},
                            answer=reasoning_result,
                            reasoning=reasoning_result,
                        )
                    )

                elif tool_call["name"] == "closer_tool":
                    reasoning_result = "Time to wrap up."
                    next_tool_name = DRPath.CLOSER.value
                elif tool_call["name"] == "clarifier_tool":
                    reasoning_result = tool_call["args"]["request"]
                    next_tool_name = DRPath.CLARIFIER.value
                else:
                    raise ValueError(f"Unknown tool: {tool_call['name']}")
        # NOTE(review): if the LLM returns no tool calls, next_tool_name stays
        # None and tools_used below becomes [""] — confirm this is intended.

    else:
        # Time budget exhausted: force the closer without consulting the LLM.
        reasoning_result = "Time to wrap up. All information is available"
        new_messages_for_continuation.append(AIMessage(content=reasoning_result))
        next_tool_name = DRPath.CLOSER.value

    return OrchestrationUpdate(
        tools_used=[next_tool_name or ""],
        query_list=query_list or [],
        iteration_nr=iteration_nr,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="orchestrator",
                node_start_time=node_start_time,
            )
        ],
        plan_of_record=plan_of_record,
        # each orchestrator turn costs one unit of budget
        remaining_time_budget=remaining_time_budget - 1.0,
        iteration_instructions=[
            IterationInstructions(
                iteration_nr=iteration_nr,
                plan=plan_of_record.plan if plan_of_record else None,
                reasoning=reasoning_result,
                purpose="",
            )
        ],
        message_history_for_continuation=new_messages_for_continuation,
        iteration_responses=in_orchestration_iteration_answers,
    )
|
||||
@@ -0,0 +1,431 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.exploration.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.exploration.constants import MAX_NUM_CLOSER_SUGGESTIONS
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType
|
||||
from onyx.agents.agent_search.exploration.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.exploration.models import TestInfoCompleteResponse
|
||||
from onyx.agents.agent_search.exploration.states import FinalUpdate
|
||||
from onyx.agents.agent_search.exploration.states import MainState
|
||||
from onyx.agents.agent_search.exploration.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import aggregate_context
|
||||
from onyx.agents.agent_search.exploration.utils import (
|
||||
convert_inference_sections_to_search_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.exploration.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.exploration.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def extract_citation_numbers(text: str) -> list[int]:
    """
    Collect every citation number referenced in *text*.

    Citations appear as ``[[<n>]]`` or ``[[<n1>, <n2>, ...]]``. Returns the
    unique numbers found, in no particular order.
    """
    # One capture group per [[...]] occurrence; a group may hold a single
    # number or a comma-separated list of numbers.
    citation_group_pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    unique_numbers: set[int] = set()
    for group in re.findall(citation_group_pattern, text):
        unique_numbers.update(int(piece) for piece in group.split(","))
    return list(unique_numbers)
|
||||
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: "list[DbSearchDoc]") -> str:
    """
    Render one citation regex match as markdown link(s) into *docs*.

    Args:
        match: A match whose group 1 is e.g. "3" or "3, 5, 7".
        docs: The 1-indexed documents the citation numbers refer to.

    Returns:
        The citations re-emitted as "[[n]](link)" fragments; numbers without a
        corresponding doc are emitted as plain "[[n]]".
    """
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        # Citations are 1-indexed; also reject 0 so e.g. [[0]] does not
        # silently resolve to docs[-1] (the last document) via Python's
        # negative indexing, as the previous `num - 1 < len(docs)` check did.
        if 1 <= num <= len(docs):
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
|
||||
|
||||
|
||||
def insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Link a chat message to each of its search documents.

    Adds one ChatMessage__SearchDoc association row per search doc id to the
    session; the caller is responsible for committing.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
) -> None:
    """
    Persist the finished research run to the database.

    Stores the cited search docs, links them to the chat message, updates the
    chat message (answer text, citations, research plan), and records each
    agent iteration and sub-step. Commits once at the end.

    Args:
        state: Main graph state (plan of record, iteration instructions).
        graph_config: Carries the db session, message id, and research type.
        aggregated_context: Aggregated iteration responses to persist.
        final_answer: The final answer text containing [[n]] citations.
        is_internet_marker_dict: document_id -> whether the doc came from the
            internet.
        all_cited_documents: Inference sections cited in the final answer.
    """
    # db_session was previously assigned twice from the same attribute; deduplicated.
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map_search_docs to message
    insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = extract_citation_numbers(final_answer)
    for cited_doc_nr in cited_doc_nrs:
        # Guard against citation numbers with no corresponding stored doc
        # (previously an unchecked index that raised IndexError for too-large
        # numbers and, for [[0]], silently mapped to the last doc).
        if 1 <= cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
    )

    # One row per orchestrator iteration...
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    # ...and one row per tool sub-step within those iterations.
    for iteration_answer in aggregated_context.global_iteration_responses:

        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
|
||||
|
||||
|
||||
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    Optionally asks the LLM whether the gathered information is complete (and,
    if not, bounces back to the orchestrator with the identified gaps), then
    streams the final answer and its citations to the client and returns a
    FinalUpdate.

    Raises:
        ValueError: If the state has no original question, the research type is
            invalid, or streaming the final answer fails.
    """

    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    # no-op attribute reads; experimentation flags are unused in this node
    state.use_clarifier
    state.use_plan
    state.use_plan_updates
    state.use_corpus_history
    state.use_thinking

    current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    research_type = ResearchType.DEEP

    assistant_system_prompt: str = state.assistant_system_prompt or ""
    assistant_task_prompt = state.assistant_task_prompt

    uploaded_context = state.uploaded_test_context or ""

    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    # Aggregate iteration results twice: with full documents (for the rich
    # final-answer prompt) and without (as a shorter fallback, see below).
    aggregated_context_w_docs = aggregate_context(
        state.iteration_responses, include_documents=True
    )

    aggregated_context_wo_docs = aggregate_context(
        state.iteration_responses, include_documents=False
    )

    iteration_responses_w_docs_string = aggregated_context_w_docs.context
    iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
    all_cited_documents = aggregated_context_w_docs.cited_documents

    num_closer_suggestions = state.num_closer_suggestions

    # Completeness check: at most MAX_NUM_CLOSER_SUGGESTIONS times, ask the
    # LLM whether enough information has been gathered to answer.
    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_wo_docs_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )

        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt,
                test_info_complete_prompt + (assistant_task_prompt or ""),
            ),
            schema=TestInfoCompleteResponse,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1000,
        )

        if test_info_complete_json.complete:
            pass

        else:
            # Information gaps remain: hand control back to the orchestrator
            # with the identified gaps and bump the suggestion counter.
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
            )

    retrieved_search_docs = convert_inference_sections_to_search_docs(
        all_cited_documents
    )

    # Announce the final-answer message (with its source documents) to the
    # client before streaming the answer text.
    write_custom_event(
        current_step_nr,
        MessageStart(
            content="",
            final_documents=retrieved_search_docs,
        ),
        writer,
    )

    if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
    elif research_type == ResearchType.DEEP:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
    else:
        raise ValueError(f"Invalid research type: {research_type}")

    # Estimate the prompt size with full documents included.
    estimated_final_answer_prompt_tokens = check_number_of_tokens(
        final_answer_base_prompt.build(
            base_question=prompt_question,
            iteration_responses_string=iteration_responses_w_docs_string,
            chat_history_string=chat_history_string,
            uploaded_context=uploaded_context,
        )
    )

    # for DR, rely only on sub-answers and claims to save tokens if context is too long
    # TODO: consider compression step for Thoughtful mode if context is too long.
    # Should generally not be the case though.

    max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens

    if (
        estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
        and research_type == ResearchType.DEEP
    ):
        iteration_responses_string = iteration_responses_wo_docs_string
    else:
        iteration_responses_string = iteration_responses_w_docs_string

    final_answer_prompt = final_answer_base_prompt.build(
        base_question=prompt_question,
        iteration_responses_string=iteration_responses_string,
        chat_history_string=chat_history_string,
        uploaded_context=uploaded_context,
    )

    # Append project-level instructions to the system prompt, if configured.
    if graph_config.inputs.project_instructions:
        assistant_system_prompt = (
            assistant_system_prompt
            + PROJECT_INSTRUCTIONS_SEPARATOR
            + (graph_config.inputs.project_instructions or "")
        )

    all_context_llmdocs = [
        llm_doc_from_inference_section(inference_section)
        for inference_section in all_cited_documents
    ]

    try:
        # Stream the final answer. The outer timeout (3x) is more generous
        # than the LLM-level timeout (2x) so the stream can finish cleanly.
        streamed_output, _, citation_infos = run_with_timeout(
            int(3 * TF_DR_TIMEOUT_LONG),
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    assistant_system_prompt,
                    final_answer_prompt + (assistant_task_prompt or ""),
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
                answer_piece=StreamingType.MESSAGE_DELTA.value,
                ind=current_step_nr,
                context_docs=all_context_llmdocs,
                replace_citations=True,
                # max_tokens=None,
            ),
        )

        final_answer = "".join(streamed_output)
    except Exception as e:
        raise ValueError(f"Error in consolidate_research: {e}")

    write_custom_event(current_step_nr, SectionEnd(), writer)

    current_step_nr += 1

    # Emit the citation section as its own step.
    write_custom_event(current_step_nr, CitationStart(), writer)
    write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
    write_custom_event(current_step_nr, SectionEnd(), writer)

    current_step_nr += 1

    # Log the research agent steps
    # save_iteration(
    #     state,
    #     graph_config,
    #     aggregated_context,
    #     final_answer,
    #     all_cited_documents,
    #     is_internet_marker_dict,
    # )

    return FinalUpdate(
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="closer",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,326 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
|
||||
EXTRACTION_SYSTEM_PROMPT,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.enums import DRPath
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.exploration.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.states import MainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import aggregate_context
|
||||
from onyx.agents.agent_search.exploration.utils import (
|
||||
convert_inference_sections_to_search_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_raw
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.db.users import update_user_cheat_sheet_context
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_citation_numbers(text: str) -> list[int]:
    """
    Collect every citation number referenced in *text*.

    Citations appear as ``[[<n>]]`` or ``[[<n1>, <n2>, ...]]``. Returns the
    unique numbers found, in no particular order.
    """
    # One capture group per [[...]] occurrence; a group may hold a single
    # number or a comma-separated list of numbers.
    group_pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    seen: set[int] = set()
    for group in re.findall(group_pattern, text):
        seen.update(int(piece) for piece in group.split(","))
    return list(seen)
|
||||
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    """Render one citation regex match as markdown links into *docs*.

    The match's first group holds 1-based citation numbers, e.g. ``"3"`` or
    ``"3, 5, 7"``. Each in-range number becomes ``[[n]](link)``; out-of-range
    numbers are kept as plain ``[[n]]`` with no link.

    Args:
        match: Regex match whose group(1) is the citation number list.
        docs: Documents in citation order (citation n -> docs[n - 1]).

    Returns:
        The concatenated (possibly linked) citation markers.
    """
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        # Citations are 1-based; also reject 0/negative, since the previous
        # check (`num - 1 < len(docs)`) let [[0]] wrap around and silently
        # link to the *last* document.
        if 1 <= num <= len(docs):
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
|
||||
|
||||
|
||||
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Insert one row into the chat_message__search_doc association table for
    each (message_id, search_doc_id) pair.

    Rows are only added to the session; committing is the caller's
    responsibility.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
    num_tokens: int,
    new_cheat_sheet_context: dict[str, Any] | None = None,
) -> None:
    """Persist the results of a research run in a single transaction.

    Writes, in order: the cited search docs, the message<->doc association
    rows, the citation mapping plus final answer on the chat message, one
    row per research iteration, one row per iteration sub-step, and
    (optionally) the user's updated cheat-sheet context. Commits once at
    the end so the whole save is atomic.

    Args:
        state: Graph state holding the iteration instructions.
        graph_config: Carries the db session, message id and behavior flags.
        aggregated_context: Aggregated per-iteration responses and documents.
        final_answer: The answer text to store on the chat message.
        all_cited_documents: Inference sections cited by the final answer.
        is_internet_marker_dict: document_id -> whether it is an internet doc.
        num_tokens: Token count of the final answer.
        new_cheat_sheet_context: If set (and the search tool has a user),
            replaces the user's stored cheat-sheet context.
    """
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs (commit is deferred until the end)
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map_search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        for cited_doc_nr in cited_doc_nrs:
            # The LLM may cite numbers outside the stored range (or 0);
            # skip those instead of raising IndexError / wrapping to the
            # last document and aborting the whole save.
            if 1 <= cited_doc_nr <= len(search_docs):
                citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )

    # One row per iteration (the high-level reasoning/purpose of each step).
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    # One row per parallel sub-step, with its cited docs serialized to JSON.
    for iteration_answer in aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    if graph_config.tooling.search_tool and graph_config.tooling.search_tool.user:
        user = fetch_user_by_id(db_session, graph_config.tooling.search_tool.user.id)
        if new_cheat_sheet_context and user:
            update_user_cheat_sheet_context(
                user=user,
                new_cheat_sheet_context=new_cheat_sheet_context,
                db_session=db_session,
            )

    db_session.commit()
|
||||
|
||||
|
||||
def logging(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    Counts the final answer's tokens, emits the OverallStop event, runs an
    LLM extraction pass over the collected claims to update the user's
    cheat-sheet context, and persists everything via save_iteration.
    """

    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr
    base_cheat_sheet_context = state.cheat_sheet_context or {}

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )

    all_cited_documents = aggregated_context.cited_documents
    is_internet_marker_dict = aggregated_context.is_internet_marker_dict

    final_answer = state.final_answer or ""
    llm_provider = graph_config.tooling.primary_llm.config.model_provider
    llm_model_name = graph_config.tooling.primary_llm.config.model_name

    llm_tokenizer = get_tokenizer(
        model_name=llm_model_name,
        provider_type=llm_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))

    write_custom_event(current_step_nr, OverallStop(), writer)

    # build extraction context: internal-search claims plus the final answer
    extracted_facts: list[str] = []
    for iteration_response in state.iteration_responses:
        if iteration_response.tool == DRPath.INTERNAL_SEARCH.value:
            extracted_facts.extend(iteration_response.claims)

    extraction_context = (
        "Extracted facts: \n - "
        + "\n - ".join(extracted_facts)
        # Bug fix: "Providede" -> "Provided" in the prompt text.
        + f"\n\nProvided Answer: \n\n{final_answer}"
    )

    extraction_system_prompt = SystemMessage(content=EXTRACTION_SYSTEM_PROMPT)
    extraction_human_prompt = HumanMessage(content=extraction_context)
    extraction_prompt = [extraction_system_prompt, extraction_human_prompt]

    extraction_response = invoke_llm_raw(
        llm=graph_config.tooling.primary_llm,
        prompt=extraction_prompt,
        # schema=ExtractionResponse,
    )

    # NOTE(review): raw LLM output is parsed as JSON; a malformed response
    # will raise here — confirm whether a fallback is wanted.
    extraction_information = json.loads(extraction_response.content)

    consolidated_updates: dict[str, list[tuple[str, str]]] = {}

    for area in ["user", "company", "search_strategy", "reasoning_strategy"]:
        # The extraction response is expected to hold a list of update dicts
        # per area; default to an empty list when the area is absent.
        update_knowledge = extraction_information.get(area, [])
        base_area_knowledge = base_cheat_sheet_context.get(area, {})

        if area not in base_cheat_sheet_context:
            base_cheat_sheet_context[area] = {}

        for update_info in update_knowledge:
            update_type = update_info.get("type", "n/a")
            change_type = update_info.get("change_type", "n/a")
            information = update_info.get("information", "n/a")

            if update_type in base_area_knowledge:
                if change_type == "update":
                    # Bug fix: the area key was never initialized, so a plain
                    # consolidated_updates[area].append(...) raised KeyError
                    # on the first "update" entry.
                    consolidated_updates.setdefault(area, []).append(
                        (base_area_knowledge.get(update_type, ""), information)
                    )
                    # NOTE(review): consolidated_updates is collected but not
                    # applied anywhere below — confirm the intended follow-up.
                elif change_type == "delete":
                    # `area` is guaranteed present here (ensured above), so
                    # index directly instead of `del ...get(area, {})[...]`.
                    del base_cheat_sheet_context[area][update_type]
                elif change_type == "add":
                    base_cheat_sheet_context[area][update_type] = information
            else:
                base_cheat_sheet_context[area][update_type] = information

    # Log the research agent steps
    save_iteration(
        state,
        graph_config,
        aggregated_context,
        final_answer,
        all_cited_documents,
        is_internet_marker_dict,
        num_tokens,
        new_cheat_sheet_context=base_cheat_sheet_context,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="logger",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,132 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicSearchProcessedStreamResults(BaseModel):
    """Outcome of processing one LLM stream in the basic-search flow."""

    # Accumulated tool-call chunks from the stream (empty content when the
    # stream contained no tool calls).
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    # Concatenation of every streamed content piece.
    full_answer: str | None = None
    # Defaults to empty; not populated by process_llm_stream in this module —
    # presumably filled by other callers, TODO confirm.
    cited_references: list[InferenceSection] = []
    # Defaults to empty; not populated by process_llm_stream in this module.
    retrieved_documents: list[LlmDoc] = []
|
||||
|
||||
|
||||
def process_llm_stream(
    messages: Iterator[BaseMessage],
    should_stream_answer: bool,
    writer: StreamWriter,
    ind: int,
    search_results: list[LlmDoc] | None = None,
    generate_final_answer: bool = False,
    chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
    """Consume an LLM message stream, forwarding answer pieces to the UI.

    Tool-call chunks are accumulated and returned; plain answer content is
    (optionally) streamed out as MessageStart/MessageDelta events with
    citation processing applied when `search_results` is given.

    Args:
        messages: The raw LLM stream (answer content or tool-call chunks).
        should_stream_answer: If False, content is accumulated but no
            answer events are emitted.
        writer: LangGraph stream writer for custom events.
        ind: Step index used to tag all emitted events.
        search_results: Docs used for citation mapping; when None, answer
            text passes through unmodified.
        generate_final_answer: If True, answer pieces are streamed to the
            UI as the final answer.
        chat_message_id: Required when generating the final answer.

    Raises:
        ValueError: If a final answer is generated without chat_message_id.
    """
    tool_call_chunk = AIMessageChunk(content="")

    # With search results we rewrite/track citations; otherwise pass through.
    if search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=search_results,
            doc_id_to_rank_map=map_document_id_order(search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()

    full_answer = ""
    start_final_answer_streaming_set = False
    # Accumulate citation infos if handler emits them
    collected_citation_infos: list[CitationInfo] = []

    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for message in messages:

        answer_piece = message.content
        if not isinstance(answer_piece, str):
            # this is only used for logging, so fine to
            # just add the string representation
            answer_piece = str(answer_piece)
        full_answer += answer_piece

        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            # Tool-call fragments: merge them chunk-by-chunk for the caller.
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message):

                # only stream out answer parts
                if (
                    isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
                    and generate_final_answer
                    and response_part.answer_piece
                ):
                    if chat_message_id is None:
                        raise ValueError(
                            "chat_message_id is required when generating final answer"
                        )

                    # Emit MessageStart exactly once, on the first answer piece.
                    if not start_final_answer_streaming_set:
                        # Convert LlmDocs to SavedSearchDocs
                        # NOTE(review): search_results may be None here —
                        # assumes saved_search_docs_from_llm_docs handles
                        # that; confirm.
                        saved_search_docs = saved_search_docs_from_llm_docs(
                            search_results
                        )
                        write_custom_event(
                            ind,
                            MessageStart(content="", final_documents=saved_search_docs),
                            writer,
                        )
                        start_final_answer_streaming_set = True

                    write_custom_event(
                        ind,
                        MessageDelta(content=response_part.answer_piece),
                        writer,
                    )
                # collect citation info objects
                elif isinstance(response_part, CitationInfo):
                    collected_citation_infos.append(response_part)

    if generate_final_answer and start_final_answer_streaming_set:
        # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
        write_custom_event(
            ind,
            SectionEnd(),
            writer,
        )

    # Emit citations section if any were collected
    if collected_citation_infos:
        write_custom_event(ind, CitationStart(), writer)
        write_custom_event(
            ind, CitationDelta(citations=collected_citation_infos), writer
        )
        write_custom_event(ind, SectionEnd(), writer)

    logger.debug(f"Full answer: {full_answer}")
    return BasicSearchProcessedStreamResults(
        ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
    )
|
||||
99
backend/onyx/agents/agent_search/exploration/states.py
Normal file
99
backend/onyx/agents/agent_search/exploration/states.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.models import IterationInstructions
|
||||
from onyx.agents.agent_search.exploration.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.exploration.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.exploration.models import OrchestratorTool
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
    """Base state slice: per-node log lines."""

    # Annotated with `add` so LangGraph concatenates updates from parallel
    # nodes instead of overwriting.
    log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class OrchestrationUpdate(LoggerUpdate):
    """State slice the orchestrator updates between iterations."""

    # `add`-annotated fields are concatenated across parallel node updates.
    tools_used: Annotated[list[str], add] = []
    query_list: list[str] = []
    iteration_nr: int = 0
    current_step_nr: int = 1
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
    message_history_for_continuation: Annotated[
        list[SystemMessage | HumanMessage | AIMessage], add
    ] = []
    iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class OrchestrationSetup(OrchestrationUpdate):
    """Initial orchestration state assembled before the first iteration.

    Note: the previously duplicated `num_closer_suggestions` and
    `message_history_for_continuation` declarations were removed — they were
    byte-identical to the ones inherited from OrchestrationUpdate.
    """

    original_question: str | None = None
    chat_history_string: str | None = None
    clarification: OrchestrationClarificationInfo | None = None
    available_tools: dict[str, OrchestratorTool] | None = None

    active_source_types: list[DocumentSource] | None = None
    active_source_types_descriptions: str | None = None
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
    # NOTE(review): name looks like a typo for "uploaded_text_context", but
    # renaming would break callers — confirm before changing.
    uploaded_test_context: str | None = None
    uploaded_image_context: list[dict[str, Any]] | None = None
    # Builtin generics for consistency with the rest of this module
    # (was typing.Dict).
    cheat_sheet_context: dict[str, Any] | None = None
    use_clarifier: bool = False
    use_thinking: bool = False
    use_plan: bool = False
    use_plan_updates: bool = False
    use_corpus_history: bool = False
|
||||
|
||||
|
||||
class AnswerUpdate(LoggerUpdate):
    """State slice carrying per-iteration answers (concatenated via `add`)."""

    iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class FinalUpdate(LoggerUpdate):
    """State slice written by the closer: final answer plus its citations."""

    final_answer: str | None = None
    # Inference sections cited by the final answer.
    all_cited_documents: list[InferenceSection] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
## Graph Input State
class MainInput(CoreState):
    """Input state for the exploration graph (no fields beyond CoreState)."""

    pass
|
||||
|
||||
|
||||
## Graph State
|
||||
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    # Combined slices: orchestration setup, per-iteration answers, and the
    # closer's final output.
    OrchestrationSetup,
    AnswerUpdate,
    FinalUpdate,
):
    """Full working state of the exploration graph."""

    pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
## Graph Output State
class MainOutput(TypedDict):
    """Shape of the exploration graph's final output."""

    # Accumulated per-node log lines.
    log_messages: list[str]
    # Final answer text, if one was produced.
    final_answer: str | None
    # Inference sections cited by the final answer.
    all_cited_documents: list[InferenceSection]
|
||||
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """LangGraph node that opens the basic-search step of the DR process.

    Emits a SearchToolStart event so the UI can render the search step,
    then returns a log entry for this node.
    """
    started_at = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")

    # Tell the UI a (non-internet) search step is beginning.
    start_event = SearchToolStart(is_internet_search=False)
    write_custom_event(state.current_step_nr, start_event, writer)

    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,288 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType
|
||||
from onyx.agents.agent_search.exploration.models import BaseSearchProcessingResponse
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.models import SearchAnswer
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.exploration.utils import (
|
||||
convert_inference_sections_to_search_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    Pipeline: (1) LLM pre-processing rewrites the branch query and infers
    source/time filters, (2) the search tool retrieves documents, (3) a
    second LLM call answers the branch question over the top documents, and
    (4) the result is returned as an IterationAnswer for this branch.

    Raises:
        ValueError: On missing branch query/tools, mismatched search tool,
            or out-of-range citation numbers in the LLM answer.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    current_step_nr = state.current_step_nr
    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    # NOTE(review): hardcoded to DEEP, which makes the `else` branch of the
    # prompt/answer section below unreachable — confirm whether other
    # research types are meant to flow through here.
    research_type = ResearchType.DEEP

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    elif len(state.tools_used) == 0:
        raise ValueError("tools_used is empty")

    # The most recently selected tool is the one to run for this branch.
    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)
    force_use_tool = graph_config.tooling.force_use_tool

    # sanity check
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")

    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )

    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
    )

    try:
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, base_search_processing_prompt
            ),
            schema=BaseSearchProcessingResponse,
            timeout_override=TF_DR_TIMEOUT_SHORT,
            # max_tokens=100,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        raise e

    rewritten_query = search_processing.rewritten_query

    # give back the query so we can render it in the UI
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[rewritten_query],
            documents=[],
        ),
        writer,
    )

    implied_start_date = search_processing.time_filter

    # Validate time_filter format if it exists
    implied_time_filter = None
    if implied_start_date:

        # Check if time_filter is in YYYY-MM-DD format
        date_pattern = r"^\d{4}-\d{2}-\d{2}$"
        if re.match(date_pattern, implied_start_date):
            implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")

    specified_source_types: list[DocumentSource] | None = (
        strings_to_document_sources(search_processing.specified_source_types)
        if search_processing.specified_source_types
        else None
    )

    # Treat an empty filter list as "no filter".
    if specified_source_types is not None and len(specified_source_types) == 0:
        specified_source_types = None

    logger.debug(
        f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    retrieved_docs: list[InferenceSection] = []
    callback_container: list[list[InferenceSection]] = []

    # Optional file/project scoping forwarded from the forced tool kwargs.
    user_file_ids: list[UUID] | None = None
    project_id: int | None = None
    if force_use_tool.override_kwargs and isinstance(
        force_use_tool.override_kwargs, SearchToolOverrideKwargs
    ):
        override_kwargs = force_use_tool.override_kwargs
        user_file_ids = override_kwargs.user_file_ids
        project_id = override_kwargs.project_id

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=rewritten_query,
            document_sources=specified_source_types,
            time_filter=implied_time_filter,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
                original_query=rewritten_query,
                user_file_ids=user_file_ids,
                project_id=project_id,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections

                # Only the summary response is needed; stop consuming.
                break

    # render the retrieved docs in the UI
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                retrieved_docs, is_internet=False
            ),
        ),
        writer,
    )

    document_texts_list = []

    # Only the top 15 docs are fed to the answer LLM; doc numbering is 1-based.
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    logger.debug(
        f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Built prompt

    if research_type == ResearchType.DEEP:
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=branch_query,
            base_question=base_question,
            document_text=document_texts,
        )

        # Run LLM

        # search_answer_json = None
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1500,
        )

        logger.debug(
            f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )

        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning
        # answer_string = ""
        # claims = []

        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)

        # Citations are 1-based indices into retrieved_docs.
        if citation_numbers and (
            (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
        ):
            raise ValueError("Citation numbers are out of range for retrieved docs.")

        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
        }

    else:
        # NOTE(review): dead code while research_type is hardcoded to DEEP
        # above — no answer, all top docs treated as cited.
        answer_string = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
        reasoning = ""

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=search_tool_info.llm_path,
                tool_id=search_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,77 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    Consolidates the branch results of the current iteration: flattens the
    documents cited by each branch, converts them to SavedSearchDocs, closes
    the current streaming section, and advances the step counter.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # keep only the branch results produced by this iteration
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    # flatten all cited documents across this iteration's branches
    # (replaces a manual nested append loop and a dead list build)
    doc_list = [
        doc for update in new_updates for doc in update.cited_documents.values()
    ]

    # Convert InferenceSections to SavedSearchDocs
    search_docs = SearchDoc.from_chunks_or_sections(doc_list)
    retrieved_saved_search_docs = [
        SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
        for search_doc in search_docs
    ]

    # these docs came from internal search, never from the internet
    for retrieved_saved_search_doc in retrieved_saved_search_docs:
        retrieved_saved_search_doc.is_internet = False

    # close the current streaming section before advancing the step counter
    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )

    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_1_branch import (
|
||||
basic_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_2_act import (
|
||||
basic_search,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_basic_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Web Search Sub-Agent
    """

    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes: fan queries out, run each search, then merge the branch results.
    builder.add_node("branch", basic_search_branch)
    builder.add_node("act", basic_search)
    builder.add_node("reducer", is_reducer)

    # Edges: START -> branch -(fan-out)-> act -> reducer -> END
    builder.add_edge(START, "branch")
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("act", "reducer")
    builder.add_edge("reducer", END)

    return builder
|
||||
@@ -0,0 +1,30 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan each query out to the "act" node, capped at MAX_DR_PARALLEL_SEARCH."""
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            current_step_nr=state.current_step_nr,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """

    node_start_time = datetime.now()

    logger.debug(
        f"Search start for Generic Tool {state.iteration_nr} at {datetime.now()}"
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="custom_tool",
        node_name="branching",
        node_start_time=node_start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.

    Resolves tool-call arguments (via the tool-calling LLM when available,
    otherwise via the non-tool-calling fallback), runs the custom tool,
    summarizes the result with the LLM, and returns it as an IterationAnswer.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # the tool selected by the orchestrator is the most recent entry
    custom_tool_info = state.available_tools[state.tools_used[-1]]
    custom_tool_name = custom_tool_info.name
    custom_tool = cast(CustomTool, custom_tool_info.tool_object)

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=custom_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[custom_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_LONG,
        )

        # make sure we got exactly one tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = custom_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # run the tool and grab the first custom/MCP response packet
    response_summary: CustomToolCallSummary | None = None
    for tool_response in custom_tool.run(**tool_args):
        if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
            response_summary = cast(CustomToolCallSummary, tool_response.response)
            break

    if not response_summary:
        raise ValueError("Custom tool did not return a valid response summary")

    # summarise tool result
    if not response_summary.response_type:
        raise ValueError("Response type is not returned.")

    # file-backed responses (image/csv) carry their payload as file ids;
    # guard with getattr so a malformed result cannot raise AttributeError
    # (previously the string-building branch dereferenced .file_ids unguarded)
    file_ids = None
    if response_summary.response_type in {"image", "csv"}:
        file_ids = getattr(response_summary.tool_result, "file_ids", None)

    # stringify the raw tool result for the summarization prompt
    if response_summary.response_type == "json":
        tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    elif response_summary.response_type in {"image", "csv"} and file_ids is not None:
        tool_result_str = f"{response_summary.response_type} files: {file_ids}"
    else:
        tool_result_str = str(response_summary.tool_result)

    tool_str = (
        f"Tool used: {custom_tool_name}\n"
        f"Description: {custom_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()

    logger.debug(
        f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=custom_tool_name,
                tool_id=custom_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type=response_summary.response_type,
                data=response_summary.tool_result,
                file_ids=file_ids,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,82 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """

    node_start_time = datetime.now()

    current_step_nr = state.current_step_nr
    current_iteration = state.iteration_nr

    # keep only the branch results produced by this iteration
    iteration_updates = [
        update
        for update in state.branch_iteration_responses
        if update.iteration_nr == current_iteration
    ]

    # stream each result as its own step: start -> delta -> section end
    for update in iteration_updates:

        if not update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=update.tool,
                response_type=update.response_type,
                data=update.data,
                file_ids=update.file_ids,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=iteration_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Dispatch the first query to the "act" node (no parallel calls for now)."""
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:1]  # no parallel call for now
    ):
        dispatches.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=parallelization_nr,
                    branch_question=query,
                    context="",
                    active_source_types=state.active_source_types,
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                ),
            )
        )
    return dispatches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_1_branch import (
|
||||
custom_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_2_act import (
|
||||
custom_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
|
||||
custom_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_custom_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Generic Tool Sub-Agent
    """

    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes: branch the request, execute the tool, then consolidate results.
    builder.add_node("branch", custom_tool_branch)
    builder.add_node("act", custom_tool_act)
    builder.add_node("reducer", custom_tool_reducer)

    # Edges: START -> branch -(fan-out)-> act -> reducer -> END
    builder.add_edge(START, "branch")
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("act", "reducer")
    builder.add_edge("reducer", END)

    return builder
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """

    node_start_time = datetime.now()

    logger.debug(
        f"Search start for Generic Tool {state.iteration_nr} at {datetime.now()}"
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="generic_internal_tool",
        node_name="branching",
        node_start_time=node_start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.

    Resolves tool-call arguments (via the tool-calling LLM when available,
    otherwise via the non-tool-calling fallback), runs the internal tool,
    summarizes the result with the LLM, and returns it as an IterationAnswer.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # the tool selected by the orchestrator is the most recent entry
    generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
    generic_internal_tool_name = generic_internal_tool_info.llm_path
    generic_internal_tool = generic_internal_tool_info.tool_object

    if generic_internal_tool is None:
        raise ValueError("generic_internal_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=generic_internal_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[generic_internal_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        # make sure we got exactly one tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # run the tool and fold all response packets into the final result
    tool_responses = list(generic_internal_tool.run(**tool_args))
    final_data = generic_internal_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_data, ensure_ascii=False)

    tool_str = (
        f"Tool used: {generic_internal_tool.display_name}\n"
        f"Description: {generic_internal_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    # the Okta Profile tool has a dedicated summarization prompt
    if generic_internal_tool.display_name == "Okta Profile":
        tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
    else:
        tool_prompt = CUSTOM_TOOL_USE_PROMPT

    tool_summary_prompt = tool_prompt.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()

    logger.debug(
        f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=generic_internal_tool.llm_name,
                tool_id=generic_internal_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="text",  # TODO: convert all response types to enums
                data=answer_string,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                # was "custom_tool" — fixed to match this sub-agent's branch node
                graph_component="generic_internal_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,82 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.

    Consolidates the branch results of the current iteration and streams each
    one to the frontend (start -> delta -> section end).
    """

    node_start_time = datetime.now()

    current_step_nr = state.current_step_nr

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr

    # keep only the branch results produced by this iteration
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    for new_update in new_updates:

        if not new_update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                # internal tools produce no file artifacts
                file_ids=[],
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                # was "custom_tool" — fixed to match this sub-agent's branch node
                graph_component="generic_internal_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Dispatch the first query to the "act" node (no parallel calls for now)."""
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:1]  # no parallel call for now
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
|
||||
generic_internal_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
|
||||
generic_internal_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
|
||||
generic_internal_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_generic_internal_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Generic Tool Sub-Agent
    """

    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes: branch the request, execute the tool, then consolidate results.
    builder.add_node("branch", generic_internal_tool_branch)
    builder.add_node("act", generic_internal_tool_act)
    builder.add_node("reducer", generic_internal_tool_reducer)

    # Edges: START -> branch -(fan-out)-> act -> reducer -> END
    builder.add_edge(START, "branch")
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("act", "reducer")
    builder.add_edge("reducer", END)

    return builder
|
||||
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def image_generation_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a image generation as part of the DR process.
    """

    node_start_time = datetime.now()

    logger.debug(f"Image generation start {state.iteration_nr} at {datetime.now()}")

    # tell frontend that we are starting the image generation tool
    write_custom_event(
        state.current_step_nr,
        ImageGenerationToolStart(),
        writer,
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="image_generation",
        node_name="branching",
        node_start_time=node_start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,189 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.models import GeneratedImage
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_HEARTBEAT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def image_generation(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node that runs the image generation tool for one branch question.

    Parses the branch question (either a plain prompt, or a JSON object with
    "prompt" and optional "shape" keys), invokes the image generation tool
    while streaming heartbeats to the frontend, persists the generated images
    to the file store, and returns an IterationAnswer describing the results.

    Raises:
        ValueError: if the branch question or the available tools are not set.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The orchestrator appends the selected tool last; look up its object.
    image_tool_info = state.available_tools[state.tools_used[-1]]
    image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)

    image_prompt = branch_query
    requested_shape: ImageShape | None = None

    # The LLM may hand us either a raw prompt string or a JSON object with
    # explicit "prompt" / "shape" fields; fall back to the raw string.
    try:
        parsed_query = json.loads(branch_query)
    except json.JSONDecodeError:
        parsed_query = None

    if isinstance(parsed_query, dict):
        prompt_from_llm = parsed_query.get("prompt")
        if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
            image_prompt = prompt_from_llm.strip()

        raw_shape = parsed_query.get("shape")
        if isinstance(raw_shape, str):
            try:
                requested_shape = ImageShape(raw_shape)
            except ValueError:
                logger.warning(
                    "Received unsupported image shape '%s' from LLM. Falling back to square.",
                    raw_shape,
                )

    logger.debug(
        f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Generate images using the image generation tool
    image_generation_responses: list[ImageGenerationResponse] = []

    if requested_shape is not None:
        tool_iterator = image_tool.run(
            prompt=image_prompt,
            shape=requested_shape.value,
        )
    else:
        tool_iterator = image_tool.run(prompt=image_prompt)

    for tool_response in tool_iterator:
        if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
            # Stream heartbeat to frontend so it can show progress
            write_custom_event(
                state.current_step_nr,
                ImageGenerationToolHeartbeat(),
                writer,
            )
        elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
            response = cast(list[ImageGenerationResponse], tool_response.response)
            image_generation_responses = response
            break

    # save images to file store
    file_ids = save_files(
        urls=[],
        base64_files=[img.image_data for img in image_generation_responses],
    )

    final_generated_images = [
        GeneratedImage(
            file_id=file_id,
            url=build_frontend_file_url(file_id),
            revised_prompt=img.revised_prompt,
            shape=(requested_shape or ImageShape.SQUARE).value,
        )
        for file_id, img in zip(file_ids, image_generation_responses)
    ]

    logger.debug(
        f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Create answer string describing the generated images
    if final_generated_images:
        image_descriptions = []
        for i, img in enumerate(final_generated_images, 1):
            # Only call out the shape when it deviates from the default square.
            if img.shape and img.shape != ImageShape.SQUARE.value:
                image_descriptions.append(
                    f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
                )
            else:
                image_descriptions.append(f"Image {i}: {img.revised_prompt}")

        answer_string = (
            f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
            + "\n".join(image_descriptions)
        )
        if requested_shape:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
            )
        else:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) based on the user's request."
            )
    else:
        answer_string = f"Failed to generate images for request: {image_prompt}"
        reasoning = "Image generation tool did not return any results."

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=image_tool_info.llm_path,
                tool_id=image_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning=reasoning,
                generated_images=final_generated_images,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="generating",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,71 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.models import GeneratedImage
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the image-generation branch results for
    the current iteration and streams the generated images to the frontend.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # Only keep the responses produced by this iteration's branches.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    generated_images: list[GeneratedImage] = []
    for update in new_updates:
        if update.generated_images:
            generated_images.extend(update.generated_images)

    # Write the results to the stream
    write_custom_event(
        current_step_nr,
        ImageGenerationToolDelta(
            images=generated_images,
        ),
        writer,
    )

    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )

    # Advance the step counter so the next UI section gets a fresh number.
    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "act" branch per query, capped at MAX_DR_PARALLEL_SEARCH."""
    branches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        branches.append(Send("act", branch_input))
    return branches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_1_branch import (
|
||||
image_generation_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_2_act import (
|
||||
image_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_image_generation_graph_builder() -> StateGraph:
    """
    Build the LangGraph state graph for the Image Generation sub-agent:
    branch -> (conditional fan-out) -> act -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes
    graph.add_node("branch", image_generation_branch)
    graph.add_node("act", image_generation)
    graph.add_node("reducer", is_reducer)

    # Edges
    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph
|
||||
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeneratedImage(BaseModel):
    """A single image produced by the image generation tool."""

    file_id: str  # id of the stored image in the file store
    url: str  # frontend-accessible URL built from the file id
    revised_prompt: str  # prompt as rewritten by the image model
    shape: str | None = None  # requested shape value; None if unspecified
|
||||
|
||||
|
||||
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
    """Wrapper holding the full list of generated images."""

    images: list[GeneratedImage]
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """Entry node of the KG-search sub-agent.

    Only logs the start time; branch fan-out happens via conditional edges.
    """
    started_at = datetime.now()

    logger.debug(f"Search start for KG Search {state.iteration_nr} at {datetime.now()}")

    log_entry = get_langgraph_node_log_string(
        graph_component="kg_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,97 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.exploration.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node that answers one branch question against the Knowledge Graph.

    Invokes the compiled KB sub-graph with the branch question, extracts
    citation numbers from the answer text, and maps them back to the
    retrieved documents.

    Raises:
        ValueError: if the branch question or the available tools are not set.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The orchestrator appends the selected tool last; look up its metadata.
    kg_tool_info = state.available_tools[state.tools_used[-1]]

    kb_graph = kb_graph_builder().compile()

    kb_results = kb_graph.invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )

    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])

    (
        citation_numbers,
        answer_string,
        claims,
    ) = extract_document_citations(answer_string, claims)

    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]

    # Citation numbers are 1-based indices into retrieved_docs; drop any
    # citation that points past the retrieved list.
    cited_documents = {
        citation_number: retrieved_docs[citation_number - 1]
        for citation_number in citation_numbers
        if citation_number <= len(retrieved_docs)
    }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,123 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.exploration.utils import (
|
||||
convert_inference_sections_to_search_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
|
||||
|
||||
|
||||
def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the KG-search branch results for the
    current iteration, streams the cited documents (and, for single-query
    runs, a truncated KG answer) to the frontend, and advances the step
    counter once per emitted UI section.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # Only keep the responses produced by this iteration's branches.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    queries = [update.question for update in new_updates]
    doc_lists = [list(update.cited_documents.values()) for update in new_updates]

    # Flatten the per-branch cited-document lists into one list.
    doc_list = []

    for xs in doc_lists:
        for x in xs:
            doc_list.append(x)

    retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
    # Only surface the KG answer directly when there was exactly one query;
    # with multiple queries there is no single authoritative answer to show.
    kg_answer = (
        "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
        if len(queries) == 1
        else None
    )

    if len(retrieved_search_docs) > 0:
        # Emit a search-tool section: start, the documents, then section end.
        write_custom_event(
            current_step_nr,
            SearchToolStart(
                is_internet_search=False,
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SearchToolDelta(
                queries=queries,
                documents=retrieved_search_docs,
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        # One UI section consumed; the next section needs a new number.
        current_step_nr += 1

    if kg_answer is not None:

        # Truncate very long answers before streaming them as "reasoning".
        kg_display_answer = (
            f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
            if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
            else kg_answer
        )

        write_custom_event(
            current_step_nr,
            ReasoningStart(),
            writer,
        )
        write_custom_event(
            current_step_nr,
            ReasoningDelta(reasoning=kg_display_answer),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,27 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out "act" branches for the KG search (currently capped at one)."""
    branches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:1]  # no parallel search for now
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        branches.append(Send("act", branch_input))
    return branches
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_1_branch import (
|
||||
kg_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_2_act import (
|
||||
kg_search,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_3_reduce import (
|
||||
kg_search_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_kg_search_graph_builder() -> StateGraph:
    """
    Build the LangGraph state graph for the KG Search sub-agent:
    branch -> (conditional fan-out) -> act -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes
    graph.add_node("branch", kg_search_branch)
    graph.add_node("act", kg_search)
    graph.add_node("reducer", kg_search_reducer)

    # Edges
    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph
|
||||
@@ -0,0 +1,46 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
|
||||
class SubAgentUpdate(LoggerUpdate):
    """State update a sub-agent's reducer returns to the main graph."""

    iteration_responses: Annotated[list[IterationAnswer], add] = []  # merged across nodes via list concat
    current_step_nr: int = 1  # next frontend step/section number
|
||||
|
||||
|
||||
class BranchUpdate(LoggerUpdate):
    """State update returned by a single parallel branch ("act" node)."""

    branch_iteration_responses: Annotated[list[IterationAnswer], add] = []  # merged across branches via list concat
|
||||
|
||||
|
||||
class SubAgentInput(LoggerUpdate):
    """Input state handed to a sub-agent graph by the orchestrator."""

    iteration_nr: int = 0  # which DR iteration this sub-agent call belongs to
    current_step_nr: int = 1  # current frontend step/section number
    query_list: list[str] = []  # branch questions to fan out over
    context: str | None = None
    active_source_types: list[DocumentSource] | None = None
    tools_used: Annotated[list[str], add] = []  # last entry is the tool selected for this call
    available_tools: dict[str, OrchestratorTool] | None = None  # keyed by tool name used in tools_used
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
|
||||
|
||||
|
||||
class SubAgentMainState(
    # This includes the core state
    SubAgentInput,
    SubAgentUpdate,
    BranchUpdate,
):
    """Combined state seen by the sub-agent's own nodes (input plus all updates)."""

    pass
|
||||
|
||||
|
||||
class BranchInput(SubAgentInput):
    """Per-branch input produced by a branching router."""

    parallelization_nr: int = 0  # index of this branch within the fan-out
    branch_question: str  # the single question this branch answers
|
||||
|
||||
|
||||
class CustomToolBranchInput(LoggerUpdate):
    """Branch input for custom-tool sub-agents; carries the tool metadata."""

    tool_info: OrchestratorTool
|
||||
@@ -0,0 +1,71 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
class ExaClient(WebSearchProvider):
    """Web search provider backed by the Exa search API."""

    def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
        # NOTE(review): a None api_key is passed straight through to the Exa
        # client — presumably it fails downstream; confirm desired behavior.
        self.exa = Exa(api_key=api_key)

    @retry_builder(tries=3, delay=1, backoff=2)
    def search(self, query: str) -> list[WebSearchResult]:
        """Run an Exa search and map its results to WebSearchResult objects."""
        response = self.exa.search_and_contents(
            query,
            type="auto",
            highlights=HighlightsContentsOptions(
                num_sentences=2,
                highlights_per_url=1,
            ),
            num_results=10,
        )

        return [
            WebSearchResult(
                title=result.title or "",
                link=result.url,
                # Only one highlight is requested per URL; use it as the snippet.
                snippet=result.highlights[0] if result.highlights else "",
                author=result.author,
                published_date=(
                    time_str_to_utc(result.published_date)
                    if result.published_date
                    else None
                ),
            )
            for result in response.results
        ]

    @retry_builder(tries=3, delay=1, backoff=2)
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Fetch full page text for *urls*, preferring a live crawl."""
        response = self.exa.get_contents(
            urls=list(urls),
            text=True,
            livecrawl="preferred",
        )

        return [
            WebContent(
                title=result.title or "",
                link=result.url,
                full_content=result.text or "",
                published_date=(
                    time_str_to_utc(result.published_date)
                    if result.published_date
                    else None
                ),
            )
            for result in response.results
        ]
|
||||
@@ -0,0 +1,148 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(WebSearchProvider):
    """Web search provider backed by the Serper.dev search and scrape APIs."""

    def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
        # Serper authenticates via the X-API-KEY header on every request.
        self.headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json",
        }

    @retry_builder(tries=3, delay=1, backoff=2)
    def search(self, query: str) -> list[WebSearchResult]:
        """Run a Serper search and map the organic results to WebSearchResult."""
        payload = {
            "q": query,
        }

        response = requests.post(
            SERPER_SEARCH_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )

        response.raise_for_status()

        results = response.json()
        # NOTE(review): a response without an "organic" key raises KeyError
        # here (then retried) — confirm that is the intended failure mode.
        organic_results = results["organic"]

        return [
            WebSearchResult(
                title=result["title"],
                link=result["link"],
                snippet=result["snippet"],
                author=None,  # Serper search results carry no author metadata
                published_date=None,  # nor a reliable publish date
            )
            for result in organic_results
        ]

    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Scrape *urls* in parallel; failures yield an unsuccessful WebContent."""
        if not urls:
            return []

        # Serper can respond with 500s regularly. We want to retry,
        # but in the event of failure, return an unsuccessful scrape.
        def safe_get_webpage_content(url: str) -> WebContent:
            try:
                return self._get_webpage_content(url)
            except Exception:
                return WebContent(
                    title="",
                    link=url,
                    full_content="",
                    published_date=None,
                    scrape_successful=False,
                )

        with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
            return list(e.map(safe_get_webpage_content, urls))

    @retry_builder(tries=3, delay=1, backoff=2)
    def _get_webpage_content(self, url: str) -> WebContent:
        """Scrape a single URL via Serper and normalize the response."""
        payload = {
            "url": url,
        }

        response = requests.post(
            SERPER_CONTENTS_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )

        # 400 returned when serper cannot scrape
        if response.status_code == 400:
            return WebContent(
                title="",
                link=url,
                full_content="",
                published_date=None,
                scrape_successful=False,
            )

        response.raise_for_status()

        response_json = response.json()

        # Response only guarantees text
        text = response_json["text"]

        # metadata & jsonld is not guaranteed to be present
        metadata = response_json.get("metadata", {})
        jsonld = response_json.get("jsonld", {})

        title = extract_title_from_metadata(metadata)

        # Serper does not provide a reliable mechanism to extract the url
        response_url = url
        published_date_str = extract_published_date_from_jsonld(jsonld)
        published_date = None

        # Dates come back in arbitrary formats; a parse failure is non-fatal.
        if published_date_str:
            try:
                published_date = time_str_to_utc(published_date_str)
            except Exception:
                published_date = None

        return WebContent(
            title=title or "",
            link=response_url,
            full_content=text or "",
            published_date=published_date,
        )
|
||||
|
||||
|
||||
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
    """Return the page title from scrape metadata, or None if absent."""
    return extract_value_from_dict(metadata, ["title", "og:title"])
|
||||
|
||||
|
||||
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
    """Return the modification date string from JSON-LD data, or None if absent."""
    return extract_value_from_dict(jsonld, ["dateModified"])
|
||||
|
||||
|
||||
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
|
||||
for key in keys:
|
||||
if key in data:
|
||||
return data[key]
|
||||
return None
|
||||
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """Entry node of the web-search sub-agent.

    Emits a SearchToolStart event for the current step and records a log
    entry; the actual per-query fan-out happens in the conditional edge
    that follows this node.
    """
    start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")

    # Signal to the stream consumer that an internet search step has begun.
    write_custom_event(
        state.current_step_nr,
        SearchToolStart(is_internet_search=True),
        writer,
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="internet_search",
        node_name="branching",
        node_start_time=start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])
|
||||
@@ -0,0 +1,128 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from langsmith import traceable
|
||||
|
||||
from onyx.agents.agent_search.exploration.models import WebSearchAnswer
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
InternetSearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def web_search(
    state: InternetSearchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchUpdate:
    """
    LangGraph node to perform internet search and decide which URLs to fetch.

    Runs the branch question through the configured search provider, then
    asks the fast LLM to pick which result URLs are worth opening. Returns
    the selected (query, result) pairs in `results_to_open`.

    Raises:
        ValueError: if the step number, available tools, persona, or a
            search provider is missing.
    """

    node_start_time = datetime.now()
    current_step_nr = state.current_step_nr

    if not current_step_nr:
        raise ValueError("Current step number is not set. This should not happen.")

    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    search_query = state.branch_question

    # Stream the query to the UI before the (slow) provider call starts.
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[search_query],
            documents=[],
        ),
        writer,
    )

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")

    provider = get_default_provider()
    if not provider:
        raise ValueError("No internet search provider found")

    # Provider failures are swallowed (logged) so one bad search does not
    # kill the whole branch; an empty result list is returned instead.
    @traceable(name="Search Provider API Call")
    def _search(search_query: str) -> list[WebSearchResult]:
        search_results: list[WebSearchResult] = []
        try:
            search_results = list(provider.search(search_query))
        except Exception as e:
            logger.error(f"Error performing search: {e}")
        return search_results

    search_results: list[WebSearchResult] = _search(search_query)
    # Render results as a 0-indexed numbered list; the indices here must line
    # up with the `urls_to_open_indices` the LLM returns below.
    search_results_text = "\n\n".join(
        [
            f"{i}. {result.title}\n URL: {result.link}\n"
            + (f" Author: {result.author}\n" if result.author else "")
            + (
                f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
                if result.published_date
                else ""
            )
            + (f" Snippet: {result.snippet}\n" if result.snippet else "")
            for i, result in enumerate(search_results)
        ]
    )
    agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
        search_query=search_query,
        base_question=base_question,
        search_results_text=search_results_text,
    )
    # Fast LLM picks which results to open; structured output per WebSearchAnswer.
    agent_decision = invoke_llm_json(
        llm=graph_config.tooling.fast_llm,
        prompt=create_question_prompt(
            assistant_system_prompt,
            agent_decision_prompt + (assistant_task_prompt or ""),
        ),
        schema=WebSearchAnswer,
        timeout_override=TF_DR_TIMEOUT_SHORT,
    )
    # Drop any hallucinated indices outside the valid 0-based range.
    results_to_open = [
        (search_query, search_results[i])
        for i in agent_decision.urls_to_open_indices
        if i < len(search_results) and i >= 0
    ]
    return InternetSearchUpdate(
        results_to_open=results_to_open,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,54 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_search_result,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import (
|
||||
convert_inference_sections_to_search_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
|
||||
|
||||
def dedup_urls(
    state: InternetSearchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Deduplicate the URLs selected across parallel search branches.

    Records, per branch question, the list of links it selected, emits one
    SearchToolDelta event advertising the unique documents about to be
    fetched, and clears `results_to_open` for the next stage.
    """
    question_to_urls: dict[str, list[str]] = defaultdict(list)
    first_result_for_url: dict[str, WebSearchResult] = {}
    for branch_question, search_result in state.results_to_open:
        question_to_urls[branch_question].append(search_result.link)
        # Keep only the first result seen for each link.
        first_result_for_url.setdefault(search_result.link, search_result)

    dummy_sections = [
        dummy_inference_section_from_internet_search_result(search_result)
        for search_result in first_result_for_url.values()
    ]

    write_custom_event(
        state.current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                dummy_sections, is_internet=True
            ),
        ),
        writer,
    )

    return InternetSearchInput(
        results_to_open=[],
        branch_questions_to_urls=question_to_urls,
    )
|
||||
@@ -0,0 +1,71 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
FetchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_content,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def web_fetch(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> FetchUpdate:
    """Fetch page contents for the node's URLs via the configured provider.

    Fetch failures are logged and swallowed, so the returned `raw_documents`
    may contain fewer entries than `state.urls_to_open`.

    Raises:
        ValueError: if available tools, the persona, or a provider is missing.
    """
    start_time = datetime.now()

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")

    provider = get_default_provider()
    if provider is None:
        raise ValueError("No web search provider found")

    fetched_sections: list[InferenceSection] = []
    try:
        fetched_sections = [
            dummy_inference_section_from_internet_content(content)
            for content in provider.contents(state.urls_to_open)
        ]
    except Exception as e:
        # Best effort: a failed fetch should not abort the whole branch.
        logger.error(f"Error fetching URLs: {e}")

    if not fetched_sections:
        logger.warning("No content retrieved from URLs")

    return FetchUpdate(
        raw_documents=fetched_sections,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="fetching",
                node_start_time=start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,19 @@
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
|
||||
|
||||
def collect_raw_docs(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Fan-in node: carry the fetched documents into the shared search state.

    `config` and `writer` are unused but required by the LangGraph node
    signature.
    """
    return InternetSearchInput(raw_documents=state.raw_documents)
|
||||
@@ -0,0 +1,135 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.models import SearchAnswer
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
SummarizeInput,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import normalize_url
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_summarize(
    state: SummarizeInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node that turns the fetched web documents for one branch
    question into an IterationAnswer.

    For DEEP research, the primary LLM answers the branch question from the
    fetched documents (with citations); otherwise the documents are passed
    through uncited without an LLM call.

    Raises:
        ValueError: if available_tools is not set.
    """

    node_start_time = datetime.now()

    # Build a lookup from normalized URL to fetched document.
    # Normalize URLs to handle mismatches from query parameters
    # (e.g., ?activeTab=explore).
    url_to_raw_document: dict[str, InferenceSection] = {}
    for raw_document in state.raw_documents:
        normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
        url_to_raw_document[normalized_url] = raw_document

    # Normalize the URLs from branch_questions_to_urls as well
    urls = [
        normalize_url(url)
        for url in state.branch_questions_to_urls[state.branch_question]
    ]
    current_iteration = state.iteration_nr
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    research_type = graph_config.behavior.research_type
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    is_tool_info = state.available_tools[state.tools_used[-1]]

    # Fetching may have failed for some URLs (web_fetch logs and swallows
    # errors), so raw_documents can be missing entries. Skip those URLs
    # instead of raising a KeyError.
    cited_raw_documents = [
        url_to_raw_document[url] for url in urls if url in url_to_raw_document
    ]

    if research_type == ResearchType.DEEP:
        document_texts = _create_document_texts(cited_raw_documents)
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=state.branch_question,
            base_question=graph_config.inputs.prompt_builder.raw_user_query,
            document_text=document_texts,
        )
        assistant_system_prompt = state.assistant_system_prompt
        assistant_task_prompt = state.assistant_task_prompt
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        # Citation numbers are 1-based indices into cited_raw_documents.
        # Guard against hallucinated numbers outside the valid range.
        cited_documents = {
            citation_number: cited_raw_documents[citation_number - 1]
            for citation_number in citation_numbers
            if 0 < citation_number <= len(cited_raw_documents)
        }

    else:
        # Fast path: no LLM summarization; cite all fetched documents as-is.
        answer_string = ""
        reasoning = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(cited_raw_documents)
        }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=current_iteration,
                parallelization_nr=0,
                question=state.branch_question,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="summarizing",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
|
||||
|
||||
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
    """Render the documents as one prompt string, one numbered context block
    per document (1-based, matching the citation numbering).

    Raises:
        ValueError: if any entry is not an InferenceSection.
    """
    rendered_blocks: list[str] = []
    for doc_num, retrieved_doc in enumerate(raw_documents):
        if not isinstance(retrieved_doc, InferenceSection):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        rendered_blocks.append(build_document_context(retrieved_doc, doc_num + 1))
    return "\n\n".join(rendered_blocks)
|
||||
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """Consolidate the branch answers produced during the current iteration,
    close the streaming section, and advance the step counter.
    """
    start_time = datetime.now()
    current_iteration = state.iteration_nr

    # Only keep the branch answers that belong to this iteration.
    iteration_responses = [
        response
        for response in state.branch_iteration_responses
        if response.iteration_nr == current_iteration
    ]

    # Close out the current streaming section before advancing the step.
    write_custom_event(state.current_step_nr, SectionEnd(), writer)

    return SubAgentUpdate(
        iteration_responses=iteration_responses,
        current_step_nr=state.current_step_nr + 1,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="consolidation",
                node_start_time=start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,81 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.exploration.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
|
||||
SummarizeInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "search" node invocation per query, capped at
    MAX_DR_PARALLEL_SEARCH parallel branches.
    """
    sends: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_input = InternetSearchInput(
            iteration_nr=state.iteration_nr,
            current_step_nr=state.current_step_nr,
            parallelization_nr=parallelization_nr,
            query_list=[query],
            branch_question=query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            results_to_open=[],
        )
        sends.append(Send("search", branch_input))
    return sends
|
||||
|
||||
|
||||
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one "fetch" node invocation per unique URL across all
    branch questions (each fetch opens a single URL).
    """
    question_to_urls = state.branch_questions_to_urls
    unique_urls = {url for urls in question_to_urls.values() for url in urls}
    return [
        Send(
            "fetch",
            FetchInput(
                iteration_nr=state.iteration_nr,
                urls_to_open=[url],
                tools_used=state.tools_used,
                available_tools=state.available_tools,
                assistant_system_prompt=state.assistant_system_prompt,
                assistant_task_prompt=state.assistant_task_prompt,
                current_step_nr=state.current_step_nr,
                branch_questions_to_urls=question_to_urls,
                raw_documents=state.raw_documents,
            ),
        )
        for url in unique_urls
    ]
|
||||
|
||||
|
||||
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one "summarize" node invocation per branch question."""
    sends: list[Send | Hashable] = []
    for branch_question in state.branch_questions_to_urls:
        summarize_input = SummarizeInput(
            iteration_nr=state.iteration_nr,
            raw_documents=state.raw_documents,
            branch_questions_to_urls=state.branch_questions_to_urls,
            branch_question=branch_question,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            current_step_nr=state.current_step_nr,
        )
        sends.append(Send("summarize", summarize_input))
    return sends
|
||||
@@ -0,0 +1,84 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_1_branch import (
|
||||
is_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_2_search import (
|
||||
web_search,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_3_dedup_urls import (
|
||||
dedup_urls,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_4_fetch import (
|
||||
web_fetch,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
|
||||
collect_raw_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_6_summarize import (
|
||||
is_summarize,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_7_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
fetch_router,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
summarize_router,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_ws_graph_builder() -> StateGraph:
    """Build the LangGraph state graph for the internet (web) search sub-agent.

    Pipeline: branch -> search (parallel per query) -> dedup_urls ->
    fetch (parallel per unique URL) -> collect_raw_docs ->
    summarize (parallel per branch question) -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes
    graph.add_node("branch", is_branch)
    graph.add_node("search", web_search)
    graph.add_node("dedup_urls", dedup_urls)
    graph.add_node("fetch", web_fetch)
    graph.add_node("collect_raw_docs", collect_raw_docs)
    graph.add_node("summarize", is_summarize)
    graph.add_node("reducer", is_reducer)

    # Edges (conditional edges perform the parallel fan-outs)
    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="search", end_key="dedup_urls")
    graph.add_conditional_edges("dedup_urls", fetch_router)
    graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
    graph.add_conditional_edges("collect_raw_docs", summarize_router)
    graph.add_edge(start_key="summarize", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph
|
||||
@@ -0,0 +1,53 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.utils.url import normalize_url
|
||||
|
||||
|
||||
class ProviderType(Enum):
    """Enum for internet search provider types.

    Values are the lowercase provider identifiers.
    """

    GOOGLE = "google"
    EXA = "exa"
|
||||
|
||||
|
||||
class WebSearchResult(BaseModel):
    """A single hit returned by a search provider (metadata only, no body)."""

    title: str
    # Normalized URL of the result; downstream code also uses it as the
    # result's identity when deduplicating.
    link: str
    snippet: str | None = None
    author: str | None = None
    published_date: datetime | None = None

    @field_validator("link")
    @classmethod
    def normalize_link(cls, v: str) -> str:
        # Normalize at construction time so equal pages compare equal by link.
        return normalize_url(v)
|
||||
|
||||
|
||||
class WebContent(BaseModel):
    """The fetched contents of a single web page."""

    title: str
    # Normalized URL the content was fetched from.
    link: str
    full_content: str
    published_date: datetime | None = None
    # False when the provider could not scrape the page; downstream code
    # marks such documents as hidden.
    scrape_successful: bool = True

    @field_validator("link")
    @classmethod
    def normalize_link(cls, v: str) -> str:
        # Normalize at construction time so equal pages compare equal by link.
        return normalize_url(v)
|
||||
|
||||
|
||||
class WebSearchProvider(ABC):
    """Interface implemented by web search backends (e.g. Exa, Serper)."""

    @abstractmethod
    def search(self, query: str) -> Sequence[WebSearchResult]:
        """Run a search query and return result summaries (no page bodies)."""
        pass

    @abstractmethod
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Fetch the full page contents for the given URLs."""
        pass
|
||||
@@ -0,0 +1,19 @@
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.clients.serper_client import (
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def get_default_provider() -> WebSearchProvider | None:
    """Return the first configured web search provider, or None.

    Exa takes precedence over Serper when both API keys are configured.
    """
    if EXA_API_KEY:
        return ExaClient()
    return SerperClient() if SERPER_API_KEY else None
|
||||
@@ -0,0 +1,37 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from onyx.agents.agent_search.exploration.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class InternetSearchInput(SubAgentInput):
    """State for the search/dedup stages of the web-search sub-agent.

    The Annotated reducers control how LangGraph merges updates from
    concurrent branches: `add` concatenates lists; `lambda x, y: y` keeps
    the most recent value.
    """

    # (branch question, search result) pairs the agent chose to open.
    results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
    # Index of this branch within the parallel search fan-out.
    parallelization_nr: int = 0
    # The sub-question this branch is researching.
    branch_question: Annotated[str, lambda x, y: y] = ""
    # Mapping from branch question to the URLs selected for it.
    branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
    # Fetched page contents wrapped as InferenceSections.
    raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
class InternetSearchUpdate(LoggerUpdate):
    """Output of the search node: the (query, result) pairs chosen to open."""

    results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
|
||||
|
||||
class FetchInput(SubAgentInput):
    """Input state for a single fetch branch."""

    # URLs this fetch invocation should open (one per branch in practice).
    urls_to_open: Annotated[list[str], add] = []
    # Mapping from branch question to its selected URLs (no reducer:
    # last writer wins).
    branch_questions_to_urls: dict[str, list[str]]
    # Fetched page contents accumulated across parallel fetches.
    raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
class FetchUpdate(LoggerUpdate):
    """Output of the fetch node: the fetched documents for its URLs."""

    raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
class SummarizeInput(SubAgentInput):
    """Input state for one summarize branch (one branch question)."""

    # All fetched documents across every branch of this iteration.
    raw_documents: Annotated[list[InferenceSection], add] = []
    # Mapping from branch question to its selected URLs.
    branch_questions_to_urls: dict[str, list[str]]
    # The branch question this summarize invocation answers.
    branch_question: str
|
||||
@@ -0,0 +1,99 @@
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
    """Truncate search result content to a maximum number of characters.

    Appends "..." when the content was actually cut; returns the content
    unchanged otherwise.
    """
    if len(content) > max_chars:
        return content[:max_chars] + "..."
    return content
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_content(
    result: WebContent,
) -> InferenceSection:
    """Wrap fetched web content in a synthetic single-chunk InferenceSection.

    The URL doubles as the semantic identifier (and is prefixed into the
    document_id), content is truncated to the default limit, and documents
    whose scrape failed are marked hidden.
    """
    truncated_content = truncate_search_result_content(result.full_content)
    return InferenceSection(
        center_chunk=InferenceChunk(
            chunk_id=0,
            blurb=result.title,
            content=truncated_content,
            source_links={0: result.link},
            section_continuation=False,
            # Prefix distinguishes synthetic internet docs from indexed ones.
            document_id="INTERNET_SEARCH_DOC_" + result.link,
            source_type=DocumentSource.WEB,
            semantic_identifier=result.link,
            title=result.title,
            boost=1,
            recency_bias=1.0,
            score=1.0,
            # Hide documents the provider failed to scrape.
            hidden=(not result.scrape_successful),
            metadata={},
            match_highlights=[],
            doc_summary=truncated_content,
            chunk_context=truncated_content,
            updated_at=result.published_date,
            image_file_id=None,
        ),
        chunks=[],
        combined_content=truncated_content,
    )
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_search_result(
    result: WebSearchResult,
) -> InferenceSection:
    """Wrap a search hit (no page body yet) in a synthetic InferenceSection.

    Content fields are left empty since the page has not been fetched; the
    URL doubles as the semantic identifier.
    """
    return InferenceSection(
        center_chunk=InferenceChunk(
            chunk_id=0,
            blurb=result.title,
            content="",
            source_links={0: result.link},
            section_continuation=False,
            # Prefix distinguishes synthetic internet docs from indexed ones.
            document_id="INTERNET_SEARCH_DOC_" + result.link,
            source_type=DocumentSource.WEB,
            semantic_identifier=result.link,
            title=result.title,
            boost=1,
            recency_bias=1.0,
            score=1.0,
            hidden=False,
            metadata={},
            match_highlights=[],
            doc_summary="",
            chunk_context="",
            updated_at=result.published_date,
            image_file_id=None,
        ),
        chunks=[],
        combined_content="",
    )
|
||||
|
||||
|
||||
def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
    """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix.

    Content is truncated to the default limit; the URL is used for both the
    blurb and the semantic identifier.
    """
    return LlmDoc(
        # TODO: Is this what we want to do for document_id? We're kind of overloading it since it
        # should ideally correspond to a document in the database. But I guess if you're calling this
        # function you know it won't be in the database.
        document_id="INTERNET_SEARCH_DOC_" + web_content.link,
        content=truncate_search_result_content(web_content.full_content),
        blurb=web_content.link,
        semantic_identifier=web_content.link,
        source_type=DocumentSource.WEB,
        metadata={},
        link=web_content.link,
        # No citation assigned yet; filled in downstream.
        document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
        updated_at=web_content.published_date,
        source_links={},
        match_highlights=[],
    )
|
||||
277
backend/onyx/agents/agent_search/exploration/utils.py
Normal file
277
backend/onyx/agents/agent_search/exploration/utils.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import copy
|
||||
import re
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from onyx.agents.agent_search.exploration.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.exploration.models import IterationAnswer
|
||||
from onyx.agents.agent_search.exploration.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
|
||||
|
||||
CITATION_PREFIX = "CITE:"
|
||||
|
||||
|
||||
def extract_document_citations(
|
||||
answer: str, claims: list[str]
|
||||
) -> tuple[list[int], str, list[str]]:
|
||||
"""
|
||||
Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
|
||||
as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
|
||||
etc., to help with citation deduplication later on.
|
||||
"""
|
||||
citations: set[int] = set()
|
||||
|
||||
# Pattern to match both single citations [1] and multiple citations [1, 2, 3]
|
||||
# This regex matches:
|
||||
# - \[(\d+)\] for single citations like [1]
|
||||
# - \[(\d+(?:,\s*\d+)*)\] for multiple citations like [1, 2, 3]
|
||||
pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
|
||||
|
||||
def _extract_and_replace(match: re.Match[str]) -> str:
|
||||
numbers = [int(num) for num in match.group(1).split(",")]
|
||||
citations.update(numbers)
|
||||
return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)
|
||||
|
||||
new_answer = pattern.sub(_extract_and_replace, answer)
|
||||
new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]
|
||||
|
||||
return list(citations), new_answer, new_claims
|
||||
|
||||
|
||||
def aggregate_context(
|
||||
iteration_responses: list[IterationAnswer], include_documents: bool = True
|
||||
) -> AggregatedDRContext:
|
||||
"""
|
||||
Converts the iteration response into a single string with unified citations.
|
||||
For example,
|
||||
it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
|
||||
it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
|
||||
Output:
|
||||
it 1: the answer is x [1][2].
|
||||
it 2: blah blah [2][3]
|
||||
[1]: doc_xyz
|
||||
[2]: doc_abc
|
||||
[3]: doc_pqr
|
||||
"""
|
||||
# dedupe and merge inference section contents
|
||||
unrolled_inference_sections: list[InferenceSection] = []
|
||||
is_internet_marker_dict: dict[str, bool] = {}
|
||||
for iteration_response in sorted(
|
||||
iteration_responses,
|
||||
key=lambda x: (x.iteration_nr, x.parallelization_nr),
|
||||
):
|
||||
|
||||
iteration_tool = iteration_response.tool
|
||||
is_internet = iteration_tool == WebSearchTool._NAME
|
||||
|
||||
for cited_doc in iteration_response.cited_documents.values():
|
||||
unrolled_inference_sections.append(cited_doc)
|
||||
if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
|
||||
is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
|
||||
is_internet
|
||||
)
|
||||
cited_doc.center_chunk.score = None # None means maintain order
|
||||
|
||||
global_documents = dedup_inference_section_list(unrolled_inference_sections)
|
||||
|
||||
global_citations = {
|
||||
doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
|
||||
}
|
||||
|
||||
# build output string
|
||||
output_strings: list[str] = []
|
||||
global_iteration_responses: list[IterationAnswer] = []
|
||||
|
||||
for iteration_response in sorted(
|
||||
iteration_responses,
|
||||
key=lambda x: (x.iteration_nr, x.parallelization_nr),
|
||||
):
|
||||
# add basic iteration info
|
||||
output_strings.append(
|
||||
f"Iteration: {iteration_response.iteration_nr}, "
|
||||
f"Question {iteration_response.parallelization_nr}"
|
||||
)
|
||||
output_strings.append(f"Tool: {iteration_response.tool}")
|
||||
output_strings.append(f"Question: {iteration_response.question}")
|
||||
|
||||
# get answer and claims with global citations
|
||||
answer_str = iteration_response.answer
|
||||
claims = iteration_response.claims or []
|
||||
|
||||
iteration_citations: list[int] = []
|
||||
for local_number, cited_doc in iteration_response.cited_documents.items():
|
||||
global_number = global_citations[cited_doc.center_chunk.document_id]
|
||||
# translate local citations to global citations
|
||||
answer_str = answer_str.replace(
|
||||
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
|
||||
)
|
||||
claims = [
|
||||
claim.replace(
|
||||
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
|
||||
)
|
||||
for claim in claims
|
||||
]
|
||||
iteration_citations.append(global_number)
|
||||
|
||||
# add answer, claims, and citation info
|
||||
if answer_str:
|
||||
output_strings.append(f"Answer: {answer_str}")
|
||||
if claims:
|
||||
output_strings.append(
|
||||
"Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
|
||||
or "No claims provided"
|
||||
)
|
||||
if not answer_str and not claims:
|
||||
output_strings.append(
|
||||
"Retrieved documents: "
|
||||
+ (
|
||||
"".join(
|
||||
f"[{global_number}]"
|
||||
for global_number in sorted(iteration_citations)
|
||||
)
|
||||
or "No documents retrieved"
|
||||
)
|
||||
)
|
||||
output_strings.append("\n---\n")
|
||||
|
||||
# save global iteration response
|
||||
iteration_response_copy = iteration_response.model_copy()
|
||||
iteration_response_copy.answer = answer_str
|
||||
iteration_response_copy.claims = claims
|
||||
iteration_response_copy.cited_documents = {
|
||||
global_citations[doc.center_chunk.document_id]: doc
|
||||
for doc in iteration_response.cited_documents.values()
|
||||
}
|
||||
global_iteration_responses.append(iteration_response_copy)
|
||||
|
||||
# add document contents if requested
|
||||
if include_documents:
|
||||
if global_documents:
|
||||
output_strings.append("Cited document contents:")
|
||||
for doc in global_documents:
|
||||
output_strings.append(
|
||||
build_document_context(
|
||||
doc, global_citations[doc.center_chunk.document_id]
|
||||
)
|
||||
)
|
||||
output_strings.append("\n---\n")
|
||||
|
||||
return AggregatedDRContext(
|
||||
context="\n".join(output_strings),
|
||||
cited_documents=global_documents,
|
||||
is_internet_marker_dict=is_internet_marker_dict,
|
||||
global_iteration_responses=global_iteration_responses,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
|
||||
"""
|
||||
Get the chat history (up to max_messages) as a string.
|
||||
"""
|
||||
# get past max_messages USER, ASSISTANT message pairs
|
||||
|
||||
past_messages = chat_history[-max_messages * 2 :]
|
||||
filtered_past_messages = copy.deepcopy(past_messages)
|
||||
|
||||
for past_message_number, past_message in enumerate(past_messages):
|
||||
|
||||
if isinstance(past_message.content, list):
|
||||
removal_indices = []
|
||||
for content_piece_number, content_piece in enumerate(past_message.content):
|
||||
if (
|
||||
isinstance(content_piece, dict)
|
||||
and content_piece.get("type") != "text"
|
||||
):
|
||||
removal_indices.append(content_piece_number)
|
||||
|
||||
# Only rebuild the content list if there are items to remove
|
||||
if removal_indices:
|
||||
filtered_past_messages[past_message_number].content = [
|
||||
content_piece
|
||||
for content_piece_number, content_piece in enumerate(
|
||||
past_message.content
|
||||
)
|
||||
if content_piece_number not in removal_indices
|
||||
]
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
return (
|
||||
"...\n" if len(chat_history) > len(filtered_past_messages) else ""
|
||||
) + "\n".join(
|
||||
("user" if isinstance(msg, HumanMessage) else "you")
|
||||
+ f": {str(msg.content).strip()}"
|
||||
for msg in filtered_past_messages
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_question(
|
||||
question: str, clarification: OrchestrationClarificationInfo | None
|
||||
) -> str:
|
||||
if clarification:
|
||||
clarification_question = clarification.clarification_question
|
||||
clarification_response = clarification.clarification_response
|
||||
return (
|
||||
f"Initial User Question: {question}\n"
|
||||
f"(Clarification Question: {clarification_question}\n"
|
||||
f"User Response: {clarification_response})"
|
||||
)
|
||||
|
||||
return question
|
||||
|
||||
|
||||
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
|
||||
"""
|
||||
Create a string representation of the tool call.
|
||||
"""
|
||||
questions_str = "\n - ".join(query_list)
|
||||
return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"
|
||||
|
||||
|
||||
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
|
||||
# Convert plan string to numbered dict format
|
||||
if not plan_text:
|
||||
return {}
|
||||
|
||||
# Split by numbered items (1., 2., 3., etc. or 1), 2), 3), etc.)
|
||||
parts = re.split(r"(\d+[.)])", plan_text)
|
||||
plan_dict = {}
|
||||
|
||||
for i in range(
|
||||
1, len(parts), 2
|
||||
): # Skip empty first part, then take number and text pairs
|
||||
if i + 1 < len(parts):
|
||||
number = parts[i].rstrip(".)") # Remove the dot or parenthesis
|
||||
text = parts[i + 1].strip()
|
||||
if text: # Only add if there's actual content
|
||||
plan_dict[number] = text
|
||||
|
||||
return plan_dict
|
||||
|
||||
|
||||
def convert_inference_sections_to_search_docs(
|
||||
inference_sections: list[InferenceSection],
|
||||
is_internet: bool = False,
|
||||
) -> list[SavedSearchDoc]:
|
||||
# Convert InferenceSections to SavedSearchDocs
|
||||
search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
|
||||
for search_doc in search_docs:
|
||||
search_doc.is_internet = is_internet
|
||||
|
||||
retrieved_saved_search_docs = [
|
||||
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
for search_doc in search_docs
|
||||
]
|
||||
return retrieved_saved_search_docs
|
||||
@@ -12,6 +12,10 @@ from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
|
||||
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
|
||||
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
|
||||
from onyx.agents.agent_search.exploration.graph_builder import exploration_graph_builder
|
||||
from onyx.agents.agent_search.exploration.states import (
|
||||
MainInput as ExplorationMainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
@@ -81,6 +85,16 @@ def run_dr_graph(
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_exploration_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = exploration_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = ExplorationMainInput(log_messages=[])
|
||||
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dc_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TypeVar
|
||||
|
||||
from braintrust import traced
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from pydantic import BaseModel
|
||||
@@ -183,6 +184,32 @@ def invoke_llm_json(
|
||||
return schema.model_validate_json(response_content)
|
||||
|
||||
|
||||
# FOR EXPERIMENTAL USE ONLY
|
||||
def invoke_llm_raw(
|
||||
llm: LLM,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> AIMessage:
|
||||
"""
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
|
||||
response_content = llm.invoke_langchain(
|
||||
prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
**cast(dict, {}),
|
||||
)
|
||||
|
||||
return cast(AIMessage, response_content)
|
||||
|
||||
|
||||
def get_answer_from_llm(
|
||||
llm: LLM,
|
||||
prompt: str,
|
||||
|
||||
@@ -5,12 +5,13 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.exploration.enums import ResearchType as ExpResearchType
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_dr_graph
|
||||
from onyx.agents.agent_search.run_graph import run_exploration_graph
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
@@ -32,6 +33,7 @@ from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import fast_gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import EXPLORATION_TEST_TYPE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -60,7 +62,7 @@ class Answer:
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
use_agentic_search: bool = False,
|
||||
research_type: ResearchType | None = None,
|
||||
research_type: ResearchType | ExpResearchType | None = None,
|
||||
research_plan: dict[str, Any] | None = None,
|
||||
project_instructions: str | None = None,
|
||||
) -> None:
|
||||
@@ -121,6 +123,9 @@ class Answer:
|
||||
else:
|
||||
research_type = ResearchType.THOUGHTFUL
|
||||
|
||||
if EXPLORATION_TEST_TYPE != "not_explicitly_set":
|
||||
research_type = ExpResearchType(research_type)
|
||||
|
||||
self.search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
@@ -144,7 +149,8 @@ class Answer:
|
||||
return
|
||||
|
||||
# TODO: add toggle in UI with customizable TimeBudget
|
||||
stream = run_dr_graph(self.graph_config)
|
||||
# stream = run_dr_graph(self.graph_config)
|
||||
stream = run_exploration_graph(self.graph_config)
|
||||
|
||||
processed_stream: list[AnswerStreamPart] = []
|
||||
for packet in stream:
|
||||
|
||||
20
backend/onyx/configs/exploration_research_configs.py
Normal file
20
backend/onyx/configs/exploration_research_configs.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
EXPLORATION_TEST_USE_DC_DEFAULT = (
|
||||
os.environ.get("EXPLORATION_TEST_USE_DC_DEFAULT") or "false"
|
||||
).lower() == "true"
|
||||
EXPLORATION_TEST_USE_CALRIFIER_DEFAULT = (
|
||||
os.environ.get("EXPLORATION_TEST_USE_CALRIFIER_DEFAULT") or "false"
|
||||
).lower() == "true"
|
||||
EXPLORATION_TEST_USE_PLAN_DEFAULT = (
|
||||
os.environ.get("EXPLORATION_TEST_USE_PLAN_DEFAULT") or "false"
|
||||
).lower() == "true"
|
||||
EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT = (
|
||||
os.environ.get("EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT") or "false"
|
||||
).lower() == "true"
|
||||
EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT = (
|
||||
os.environ.get("EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT") or "false"
|
||||
).lower() == "true"
|
||||
EXPLORATION_TEST_USE_THINKING_DEFAULT = (
|
||||
os.environ.get("EXPLORATION_TEST_USE_THINKING_DEFAULT") or "false"
|
||||
).lower() == "true"
|
||||
85
backend/onyx/db/hackathon_subscriptions.py
Normal file
85
backend/onyx/db/hackathon_subscriptions.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import SubscriptionRegistration
|
||||
from onyx.db.models import SubscriptionResult
|
||||
|
||||
|
||||
def get_subscription_registration(
|
||||
db_session: Session, user_id: str
|
||||
) -> SubscriptionRegistration:
|
||||
return (
|
||||
db_session.query(SubscriptionRegistration)
|
||||
.filter(SubscriptionRegistration.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def get_subscription_result(db_session: Session, user_id: UUID) -> SubscriptionResult:
|
||||
return (
|
||||
db_session.query(SubscriptionResult)
|
||||
.filter(SubscriptionResult.user_id == user_id)
|
||||
.order_by(SubscriptionResult.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def save_subscription_result(
|
||||
db_session: Session, subscription_result: SubscriptionResult
|
||||
) -> None:
|
||||
db_session.add(subscription_result)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_document_ids_by_cc_pair_name(
|
||||
db_session: Session,
|
||||
cc_pair_name: str,
|
||||
date_threshold: datetime | None = None,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""
|
||||
Get all document IDs and links associated with a connector credential pair by its name,
|
||||
optionally filtered by documents updated after a date threshold.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
cc_pair_name: Name of the connector credential pair
|
||||
date_threshold: Optional datetime to filter documents updated after this date
|
||||
|
||||
Returns:
|
||||
List of tuples containing (document_id, document_link)
|
||||
"""
|
||||
# First, get the connector_id and credential_id from the ConnectorCredentialPair by name
|
||||
cc_pair = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.name == cc_pair_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
return []
|
||||
|
||||
# Build query to get document IDs and links associated with this connector/credential pair
|
||||
stmt = (
|
||||
select(DocumentByConnectorCredentialPair.id, Document.link)
|
||||
.join(
|
||||
Document,
|
||||
DocumentByConnectorCredentialPair.id == Document.id,
|
||||
)
|
||||
.where(
|
||||
DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Add date threshold filter if provided
|
||||
if date_threshold is not None:
|
||||
stmt = stmt.where(Document.doc_updated_at >= date_threshold)
|
||||
|
||||
results = db_session.execute(stmt).all()
|
||||
return [(doc_id, link) for doc_id, link in results]
|
||||
@@ -194,6 +194,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
cheat_sheet_context: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
@@ -3952,3 +3955,72 @@ class ExternalGroupPermissionSyncAttempt(Base):
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.status.is_terminal()
|
||||
|
||||
|
||||
# EXPLORATION TESTING
|
||||
class TemporaryUserCheatSheetContext(Base):
|
||||
"""
|
||||
Represents the context of a user's cheat sheet. Replace with column in user table once
|
||||
login is working again.
|
||||
"""
|
||||
|
||||
__tablename__ = "temporary_user_cheat_sheet_context"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
context: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
# HACKATHON CHANGES
|
||||
|
||||
|
||||
class SubscriptionRegistration(Base):
|
||||
"""
|
||||
Represents a user's subscription registration with document extraction contexts
|
||||
and search questions.
|
||||
"""
|
||||
|
||||
__tablename__ = "subscription_registrations"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
doc_extraction_contexts: Mapped[dict[str, str]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False
|
||||
)
|
||||
search_questions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionResult(Base):
|
||||
"""
|
||||
Represents the results of a subscription for a user, including notifications.
|
||||
"""
|
||||
|
||||
__tablename__ = "subscription_results"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
|
||||
type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
notifications: Mapped[dict[str, Any]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
32
backend/onyx/db/temp_exp.py
Normal file
32
backend/onyx/db/temp_exp.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import TemporaryUserCheatSheetContext
|
||||
|
||||
# EXPLORATION TESTING
|
||||
|
||||
|
||||
def get_user_cheat_sheet_context(db_session: Session) -> dict[str, Any] | None:
|
||||
stmt = select(TemporaryUserCheatSheetContext).order_by(
|
||||
TemporaryUserCheatSheetContext.created_at.desc()
|
||||
)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result.context if result else None
|
||||
|
||||
|
||||
def update_user_cheat_sheet_context(
|
||||
db_session: Session, new_cheat_sheet_context: dict[str, Any]
|
||||
) -> None:
|
||||
stmt = select(TemporaryUserCheatSheetContext).order_by(
|
||||
TemporaryUserCheatSheetContext.created_at.desc()
|
||||
)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
if result:
|
||||
result.context = new_cheat_sheet_context
|
||||
db_session.commit()
|
||||
else:
|
||||
new_context = TemporaryUserCheatSheetContext(context=new_cheat_sheet_context)
|
||||
db_session.add(new_context)
|
||||
db_session.commit()
|
||||
@@ -342,3 +342,19 @@ def delete_user_from_db(
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
remove_user_from_invited_users(user_to_delete.email)
|
||||
|
||||
|
||||
# EXPLORATION TESTING
|
||||
|
||||
|
||||
def get_user_cheat_sheet_context(
|
||||
user: User, db_session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
return user.cheat_sheet_context
|
||||
|
||||
|
||||
def update_user_cheat_sheet_context(
|
||||
user: User, new_cheat_sheet_context: dict[str, Any], db_session: Session
|
||||
) -> None:
|
||||
user.cheat_sheet_context = new_cheat_sheet_context
|
||||
db_session.flush() # Make the change visible to subsequent queries in the same transaction
|
||||
|
||||
@@ -207,6 +207,8 @@ def construct_tools(
|
||||
continue
|
||||
|
||||
if db_tool_model.in_code_tool_id:
|
||||
if db_tool_model.in_code_tool_id != "SearchTool":
|
||||
continue
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.in_code_tool_id)
|
||||
|
||||
try:
|
||||
|
||||
@@ -220,3 +220,5 @@ INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
|
||||
)
|
||||
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
|
||||
|
||||
EXPLORATION_TEST_TYPE = os.environ.get("EXPLORATION_TEST_TYPE") or "not_explicitly_set"
|
||||
|
||||
Reference in New Issue
Block a user